From 75165957c9d64137ed1735bd8cf7517202d6bfca Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 6 Oct 2022 20:52:38 -0400 Subject: [PATCH] Revert "realesrgan inherits precision setting from main program" This reverts commit 5f42d0894521fd2d9559decd523f21fc12ab1164. This fix was intended to solve issue #939, in which ESRGAN generates dark images when upscaling 4X on certain GTX cards. However, the fix apparently causes conflicts with some versions of the ESRGAN library, and this fix will have to wait until after release of 2.0. --- ldm/dream/restoration/realesrgan.py | 20 ++++++++++++-------- ldm/generate.py | 5 ++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/ldm/dream/restoration/realesrgan.py b/ldm/dream/restoration/realesrgan.py index 0836021031..dc3eebd912 100644 --- a/ldm/dream/restoration/realesrgan.py +++ b/ldm/dream/restoration/realesrgan.py @@ -1,7 +1,6 @@ import torch import warnings import numpy as np -from ldm.dream.devices import choose_precision, choose_torch_device from PIL import Image @@ -9,12 +8,17 @@ from PIL import Image class ESRGAN(): def __init__(self, bg_tile_size=400) -> None: self.bg_tile_size = bg_tile_size - device = torch.device(choose_torch_device()) - precision = choose_precision(device) - use_half_precision = precision == 'float16' - def load_esrgan_bg_upsampler(self, precision): - use_half_precision = precision == 'float16' + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False + else: + use_half_precision = True + + def load_esrgan_bg_upsampler(self): + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False + else: + use_half_precision = True from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan import RealESRGANer @@ -35,13 +39,13 @@ class ESRGAN(): return bg_upsampler - def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'): + def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2): with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=UserWarning) try: - upsampler = self.load_esrgan_bg_upsampler(precision) + upsampler = self.load_esrgan_bg_upsampler() except Exception: import traceback import sys diff --git a/ldm/generate.py b/ldm/generate.py index 504314ef91..fc40fa6152 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -599,8 +599,7 @@ class Generate: opt, args, image_callback = callback, - prefix = prefix, - precision = self.precision, + prefix = prefix ) elif tool is None: @@ -771,7 +770,7 @@ class Generate: if len(upscale) < 2: upscale.append(0.75) image = self.esrgan.process( - image, upscale[1], seed, int(upscale[0]), precision=self.precision) + image, upscale[1], seed, int(upscale[0])) else: print(">> ESRGAN is disabled. Image not upscaled.") except Exception as e: