diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index b2ce766f37..cd4a580802 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -126,9 +126,12 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool: if isinstance(key, int): continue # Check for Z-Image specific key prefixes - prefix = key.split(".")[0] - if prefix in z_image_specific_keys: - return True + # Handle both direct keys (cap_embedder.0.weight) and + # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight) + key_parts = key.split(".") + for part in key_parts: + if part in z_image_specific_keys: + return True return False diff --git a/invokeai/backend/model_manager/load/model_loaders/z_image.py b/invokeai/backend/model_manager/load/model_loaders/z_image.py index e7e0813073..2668607a7d 100644 --- a/invokeai/backend/model_manager/load/model_loaders/z_image.py +++ b/invokeai/backend/model_manager/load/model_loaders/z_image.py @@ -41,6 +41,7 @@ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]: - k_norm.weight -> norm_k.weight - x_embedder.* -> all_x_embedder.2-1.* - final_layer.* -> all_final_layer.2-1.* + - norm_final.* -> skipped (diffusers uses non-learnable LayerNorm) """ new_sd: dict[str, Any] = {} @@ -63,6 +64,11 @@ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]: new_sd[new_key] = value continue + # Skip norm_final keys - the diffusers model uses LayerNorm with elementwise_affine=False + # (no learnable weight/bias), but some checkpoints (e.g., FP8) include these as all-zeros + if key.startswith("norm_final."): + continue + # Handle fused QKV weights - need to split if ".attention.qkv." in key: # Get the layer prefix and suffix @@ -185,16 +191,19 @@ class ZImageCheckpointModel(ModelLoader): # Load the state dict from safetensors/checkpoint file sd = load_file(model_path) - # Some Z-Image checkpoint files have keys prefixed with "diffusion_model." - # Check if we need to strip this prefix - has_prefix = any(k.startswith("diffusion_model.") for k in sd.keys() if isinstance(k, str)) + # Some Z-Image checkpoint files have keys prefixed with "diffusion_model." or + # "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix. + prefix_to_strip = None + for prefix in ["model.diffusion_model.", "diffusion_model."]: + if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)): + prefix_to_strip = prefix + break - if has_prefix: + if prefix_to_strip: stripped_sd = {} - prefix = "diffusion_model." for key, value in sd.items(): - if isinstance(key, str) and key.startswith(prefix): - stripped_sd[key[len(prefix) :]] = value + if isinstance(key, str) and key.startswith(prefix_to_strip): + stripped_sd[key[len(prefix_to_strip) :]] = value else: stripped_sd[key] = value sd = stripped_sd @@ -236,6 +245,12 @@ class ZImageCheckpointModel(ModelLoader): new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()]) self._ram_cache.make_room(new_sd_size) + # Filter out FP8 scale_weight and scaled_fp8 metadata keys + # These are quantization metadata that shouldn't be loaded into the model + keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"] + for k in keys_to_remove: + del sd[k] + # Convert to target dtype for k in sd.keys(): sd[k] = sd[k].to(model_dtype) @@ -284,16 +299,19 @@ class ZImageGGUFCheckpointModel(ModelLoader): # Load the GGUF state dict sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype) - # Some Z-Image GGUF models have keys prefixed with "diffusion_model." - # Check if we need to strip this prefix - has_prefix = any(k.startswith("diffusion_model.") for k in sd.keys() if isinstance(k, str)) + # Some Z-Image GGUF models have keys prefixed with "diffusion_model." or + # "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix. + prefix_to_strip = None + for prefix in ["model.diffusion_model.", "diffusion_model."]: + if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)): + prefix_to_strip = prefix + break - if has_prefix: + if prefix_to_strip: stripped_sd = {} - prefix = "diffusion_model." for key, value in sd.items(): - if isinstance(key, str) and key.startswith(prefix): - stripped_sd[key[len(prefix) :]] = value + if isinstance(key, str) and key.startswith(prefix_to_strip): + stripped_sd[key[len(prefix_to_strip) :]] = value else: stripped_sd[key] = value sd = stripped_sd