From eb33303e79823caaac6a2d30e79cc5f87b9b76d3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 27 Mar 2024 19:01:04 +1100 Subject: [PATCH] fix(mm): handle depth and inpainting models when converting to diffusers "Normal" models have 4 in-channels, while "Depth" models have 5 and "Inpaint" models have 9. We need to explicitly tell diffusers the channel count when converting models. Closes #6058 --- .../load/model_loaders/stable_diffusion.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index e401a10eb9..c3260957c8 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -14,12 +14,18 @@ from invokeai.backend.model_manager import ( SchedulerPredictionType, SubModelType, ) -from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig +from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from .. import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader +VARIANT_TO_IN_CHANNEL_MAP = { + ModelVariantType.Normal: 4, + ModelVariantType.Depth: 5, + ModelVariantType.Inpaint: 9, +} + @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) @@ -87,6 +93,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): ) self._logger.info(f"Converting {model_path} to diffusers format") + convert_ckpt_to_diffusers( model_path, output_path, @@ -99,5 +106,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): image_size=image_size, upcast_attention=upcast_attention, load_safety_checker=False, + num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant], ) return output_path