mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-03-12 09:50:16 +01:00
z-image-turbo-fp8-e5m2 works; z-image-turbo_fp8_scaled_e4m3fn_KJ does not.
This commit is contained in:
parent
8551ff8569
commit
f9605e18a0
@ -126,9 +126,12 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
# Check for Z-Image specific key prefixes
|
||||
prefix = key.split(".")[0]
|
||||
if prefix in z_image_specific_keys:
|
||||
return True
|
||||
# Handle both direct keys (cap_embedder.0.weight) and
|
||||
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
|
||||
key_parts = key.split(".")
|
||||
for part in key_parts:
|
||||
if part in z_image_specific_keys:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
|
||||
- k_norm.weight -> norm_k.weight
|
||||
- x_embedder.* -> all_x_embedder.2-1.*
|
||||
- final_layer.* -> all_final_layer.2-1.*
|
||||
- norm_final.* -> skipped (diffusers uses non-learnable LayerNorm)
|
||||
"""
|
||||
new_sd: dict[str, Any] = {}
|
||||
|
||||
@ -63,6 +64,11 @@ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
|
||||
new_sd[new_key] = value
|
||||
continue
|
||||
|
||||
# Skip norm_final keys - the diffusers model uses LayerNorm with elementwise_affine=False
|
||||
# (no learnable weight/bias), but some checkpoints (e.g., FP8) include these as all-zeros
|
||||
if key.startswith("norm_final."):
|
||||
continue
|
||||
|
||||
# Handle fused QKV weights - need to split
|
||||
if ".attention.qkv." in key:
|
||||
# Get the layer prefix and suffix
|
||||
@ -185,16 +191,19 @@ class ZImageCheckpointModel(ModelLoader):
|
||||
# Load the state dict from safetensors/checkpoint file
|
||||
sd = load_file(model_path)
|
||||
|
||||
# Some Z-Image checkpoint files have keys prefixed with "diffusion_model."
|
||||
# Check if we need to strip this prefix
|
||||
has_prefix = any(k.startswith("diffusion_model.") for k in sd.keys() if isinstance(k, str))
|
||||
# Some Z-Image checkpoint files have keys prefixed with "diffusion_model." or
|
||||
# "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
|
||||
prefix_to_strip = None
|
||||
for prefix in ["model.diffusion_model.", "diffusion_model."]:
|
||||
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
|
||||
prefix_to_strip = prefix
|
||||
break
|
||||
|
||||
if has_prefix:
|
||||
if prefix_to_strip:
|
||||
stripped_sd = {}
|
||||
prefix = "diffusion_model."
|
||||
for key, value in sd.items():
|
||||
if isinstance(key, str) and key.startswith(prefix):
|
||||
stripped_sd[key[len(prefix) :]] = value
|
||||
if isinstance(key, str) and key.startswith(prefix_to_strip):
|
||||
stripped_sd[key[len(prefix_to_strip) :]] = value
|
||||
else:
|
||||
stripped_sd[key] = value
|
||||
sd = stripped_sd
|
||||
@ -236,6 +245,12 @@ class ZImageCheckpointModel(ModelLoader):
|
||||
new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
|
||||
self._ram_cache.make_room(new_sd_size)
|
||||
|
||||
# Filter out FP8 scale_weight and scaled_fp8 metadata keys
|
||||
# These are quantization metadata that shouldn't be loaded into the model
|
||||
keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
|
||||
for k in keys_to_remove:
|
||||
del sd[k]
|
||||
|
||||
# Convert to target dtype
|
||||
for k in sd.keys():
|
||||
sd[k] = sd[k].to(model_dtype)
|
||||
@ -284,16 +299,19 @@ class ZImageGGUFCheckpointModel(ModelLoader):
|
||||
# Load the GGUF state dict
|
||||
sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)
|
||||
|
||||
# Some Z-Image GGUF models have keys prefixed with "diffusion_model."
|
||||
# Check if we need to strip this prefix
|
||||
has_prefix = any(k.startswith("diffusion_model.") for k in sd.keys() if isinstance(k, str))
|
||||
# Some Z-Image GGUF models have keys prefixed with "diffusion_model." or
|
||||
# "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
|
||||
prefix_to_strip = None
|
||||
for prefix in ["model.diffusion_model.", "diffusion_model."]:
|
||||
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
|
||||
prefix_to_strip = prefix
|
||||
break
|
||||
|
||||
if has_prefix:
|
||||
if prefix_to_strip:
|
||||
stripped_sd = {}
|
||||
prefix = "diffusion_model."
|
||||
for key, value in sd.items():
|
||||
if isinstance(key, str) and key.startswith(prefix):
|
||||
stripped_sd[key[len(prefix) :]] = value
|
||||
if isinstance(key, str) and key.startswith(prefix_to_strip):
|
||||
stripped_sd[key[len(prefix_to_strip) :]] = value
|
||||
else:
|
||||
stripped_sd[key] = value
|
||||
sd = stripped_sd
|
||||
|
||||
Loading…
Reference in New Issue
Block a user