diff --git a/ldm/gfpgan/gfpgan_tools.py b/ldm/gfpgan/gfpgan_tools.py index 0de706ae42..72df72455a 100644 --- a/ldm/gfpgan/gfpgan_tools.py +++ b/ldm/gfpgan/gfpgan_tools.py @@ -75,52 +75,50 @@ def run_gfpgan(image, strength, seed, upsampler_scale=4): def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400): if bg_upsampler == 'realesrgan': - if not torch.cuda.is_available(): # CPU - warnings.warn( - 'The unoptimized RealESRGAN is slow on CPU. We do not use it. ' - 'If you really want to use it, please modify the corresponding codes.' - ) - bg_upsampler = None + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False else: - model_path = { - 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', - 4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', - } + use_half_precision = True - if upsampler_scale not in model_path: - return None + model_path = { + 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + 4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', + } - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer + if upsampler_scale not in model_path: + return None - if upsampler_scale == 4: - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=4, - ) - if upsampler_scale == 2: - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=2, - ) + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer - bg_upsampler = RealESRGANer( - scale=upsampler_scale, - model_path=model_path[upsampler_scale], - model=model, - tile=bg_tile, - tile_pad=10, - pre_pad=0, - half=True, - ) # need to set False in CPU mode + if upsampler_scale == 4: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + if upsampler_scale == 2: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + + bg_upsampler = RealESRGANer( + scale=upsampler_scale, + model_path=model_path[upsampler_scale], + model=model, + tile=bg_tile, + tile_pad=10, + pre_pad=0, + half=use_half_precision, + ) else: bg_upsampler = None