From e046e60e1cfb3552e27d97bed0acfd6345adc0fe Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Sun, 21 Jul 2024 18:31:10 +0300
Subject: [PATCH] Add FreeU support to denoise

---
 invokeai/app/invocations/denoise_latents.py |  9 +++-
 .../stable_diffusion/extensions/freeu.py    | 42 +++++++++++++++++++
 .../stable_diffusion/extensions_manager.py  | 10 +++--
 3 files changed, 56 insertions(+), 5 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/extensions/freeu.py

diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index ccacc3303c..e043e884f9 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -790,18 +791,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         ext_manager.add_extension(PreviewExt(step_callback))
 
+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
         # ext: t2i/ip adapter
         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
 
         unet_info = context.models.load(self.unet.unet)
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
             ext_manager.patch_extensions(unet),
             # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
+            ext_manager.patch_unet(unet, cached_weights),
         ):
             sd_backend = StableDiffusionBackend(unet, scheduler)
             denoise_ctx.unet = unet
diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py
new file mode 100644
index 0000000000..c723aaee0b
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+if TYPE_CHECKING:
+    from invokeai.app.shared.models import FreeUConfig
+
+
+class FreeUExt(ExtensionBase):
+    def __init__(
+        self,
+        freeu_config: Optional[FreeUConfig],
+    ):
+        super().__init__()
+        self.freeu_config = freeu_config
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        did_apply_freeu = False
+        try:
+            assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
+            if self.freeu_config is not None:
+                unet.enable_freeu(
+                    b1=self.freeu_config.b1,
+                    b2=self.freeu_config.b2,
+                    s1=self.freeu_config.s1,
+                    s2=self.freeu_config.s2,
+                )
+                did_apply_freeu = True
+
+            yield
+
+        finally:
+            assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?
+            if did_apply_freeu:
+                unet.disable_freeu()
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index 1cae2e4219..9c4347a56c 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -63,9 +63,13 @@ class ExtensionsManager:
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create logic in PR with extension which uses it
-        yield None
+        # TODO: create weight patch logic in PR with extension which uses it
+        with ExitStack() as exit_stack:
+            for ext in self._extensions:
+                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+
+            yield None
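
A minimal, self-contained sketch (not part of the patch) of the pattern the diff introduces:
ExtensionsManager.patch_unet enters each extension's patch_unet context manager on an
ExitStack, so every patch (here, FreeU) is applied before the denoise loop and unwound
afterwards in reverse order. ToyUNet and ToyFreeUExt below are hypothetical stand-ins for
UNet2DConditionModel and FreeUExt, and the b1/b2/s1/s2 values are example settings only.

from contextlib import ExitStack, contextmanager


class ToyUNet:
    """Hypothetical stand-in for diffusers' UNet2DConditionModel."""

    def enable_freeu(self, b1: float, b2: float, s1: float, s2: float) -> None:
        print(f"FreeU enabled: b1={b1}, b2={b2}, s1={s1}, s2={s2}")

    def disable_freeu(self) -> None:
        print("FreeU disabled")


class ToyFreeUExt:
    """Hypothetical stand-in for FreeUExt: patch on enter, restore on exit."""

    def __init__(self, b1: float, b2: float, s1: float, s2: float) -> None:
        self._cfg = (b1, b2, s1, s2)

    @contextmanager
    def patch_unet(self, unet: ToyUNet, cached_weights=None):
        unet.enable_freeu(*self._cfg)
        try:
            yield
        finally:
            unet.disable_freeu()


@contextmanager
def patch_unet(extensions, unet: ToyUNet, cached_weights=None):
    # Same shape as ExtensionsManager.patch_unet in the diff: enter every
    # extension's patch before denoising, unwind them all afterwards.
    with ExitStack() as exit_stack:
        for ext in extensions:
            exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
        yield


if __name__ == "__main__":
    unet = ToyUNet()
    # Example FreeU values only.
    with patch_unet([ToyFreeUExt(b1=1.5, b2=1.6, s1=0.9, s2=0.2)], unet):
        print("denoising with the patched UNet")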