diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 9c643d13bc..89b3da5a37 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -6,7 +6,13 @@ from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalControlnetMixin from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor -from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.embeddings import ( + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -22,6 +28,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control # TODO: create PR to diffusers # Modified ControlNetModel with encoder_attention_mask argument added + class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model. @@ -736,9 +743,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] + down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: @@ -749,6 +754,5 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): ) - diffusers.ControlNetModel = ControlNetModel diffusers.models.controlnet.ControlNetModel = ControlNetModel