From c367b21c71e53a86e1d085cd752019142e699f21 Mon Sep 17 00:00:00 2001 From: Fabio 'MrWHO' Torchetti Date: Sun, 12 Mar 2023 15:40:33 -0500 Subject: [PATCH 1/2] Fix issue #2932 --- .../convert_ckpt_to_diffusers.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index ae5550880a..bb600af61f 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -378,18 +378,24 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( - flat_ema_key - ) + flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:]) + if flat_ema_key in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) + elif flat_ema_key_alt in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key_alt + ) + else: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + key + ) else: print( " | Extracting only the non-EMA weights (usually better for fine-tuning)" ) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ From 5c5106c14a382ef21e3c56755d7c3c3afa1c3048 Mon Sep 17 00:00:00 2001 From: Fabio 'MrWHO' Torchetti Date: Sun, 12 Mar 2023 16:22:22 -0500 Subject: [PATCH 2/2] Add keys when non EMA --- .../backend/model_management/convert_ckpt_to_diffusers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index bb600af61f..979cd82743 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -396,6 +396,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False " | Extracting only the non-EMA weights (usually better for fine-tuning)" ) + for key in keys: + if key.startswith("model.diffusion_model") and key in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[