diff --git a/ldm/dream/restoration/realesrgan.py b/ldm/dream/restoration/realesrgan.py index 0836021031..dc3eebd912 100644 --- a/ldm/dream/restoration/realesrgan.py +++ b/ldm/dream/restoration/realesrgan.py @@ -1,7 +1,6 @@ import torch import warnings import numpy as np -from ldm.dream.devices import choose_precision, choose_torch_device from PIL import Image @@ -9,12 +8,17 @@ from PIL import Image class ESRGAN(): def __init__(self, bg_tile_size=400) -> None: self.bg_tile_size = bg_tile_size - device = torch.device(choose_torch_device()) - precision = choose_precision(device) - use_half_precision = precision == 'float16' - def load_esrgan_bg_upsampler(self, precision): - use_half_precision = precision == 'float16' + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False + else: + use_half_precision = True + + def load_esrgan_bg_upsampler(self): + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False + else: + use_half_precision = True from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan import RealESRGANer @@ -35,13 +39,13 @@ class ESRGAN(): return bg_upsampler - def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'): + def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2): with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=UserWarning) try: - upsampler = self.load_esrgan_bg_upsampler(precision) + upsampler = self.load_esrgan_bg_upsampler() except Exception: import traceback import sys diff --git a/ldm/generate.py b/ldm/generate.py index 504314ef91..fc40fa6152 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -599,8 +599,7 @@ class Generate: opt, args, image_callback = callback, - prefix = prefix, - precision = self.precision, + prefix = prefix ) elif tool is None: @@ -771,7 +770,7 @@ class Generate: if len(upscale) < 2: upscale.append(0.75) image = self.esrgan.process( - image, upscale[1], seed, int(upscale[0]), precision=self.precision) + image, upscale[1], seed, int(upscale[0])) else: print(">> ESRGAN is disabled. Image not upscaled.") except Exception as e: