z-image-turbo-fp8-e5m2 works. The z-image-turbo_fp8_scaled_e4m3fn_KJ doesn't.

This commit is contained in:
Alexander Eichhorn 2025-12-10 17:15:54 +01:00
parent 8551ff8569
commit f9605e18a0
2 changed files with 38 additions and 17 deletions

View File

@ -126,9 +126,12 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
if isinstance(key, int):
continue
# Check for Z-Image specific key prefixes
prefix = key.split(".")[0]
if prefix in z_image_specific_keys:
return True
# Handle both direct keys (cap_embedder.0.weight) and
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
key_parts = key.split(".")
for part in key_parts:
if part in z_image_specific_keys:
return True
return False

View File

@ -41,6 +41,7 @@ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
- k_norm.weight -> norm_k.weight
- x_embedder.* -> all_x_embedder.2-1.*
- final_layer.* -> all_final_layer.2-1.*
- norm_final.* -> skipped (diffusers uses non-learnable LayerNorm)
"""
new_sd: dict[str, Any] = {}
@ -63,6 +64,11 @@ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
new_sd[new_key] = value
continue
# Skip norm_final keys - the diffusers model uses LayerNorm with elementwise_affine=False
# (no learnable weight/bias), but some checkpoints (e.g., FP8) include these as all-zeros
if key.startswith("norm_final."):
continue
# Handle fused QKV weights - need to split
if ".attention.qkv." in key:
# Get the layer prefix and suffix
@ -185,16 +191,19 @@ class ZImageCheckpointModel(ModelLoader):
# Load the state dict from safetensors/checkpoint file
sd = load_file(model_path)
# Some Z-Image checkpoint files have keys prefixed with "diffusion_model."
# Check if we need to strip this prefix
has_prefix = any(k.startswith("diffusion_model.") for k in sd.keys() if isinstance(k, str))
# Some Z-Image checkpoint files have keys prefixed with "diffusion_model." or
# "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
prefix_to_strip = None
for prefix in ["model.diffusion_model.", "diffusion_model."]:
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
prefix_to_strip = prefix
break
if has_prefix:
if prefix_to_strip:
stripped_sd = {}
prefix = "diffusion_model."
for key, value in sd.items():
if isinstance(key, str) and key.startswith(prefix):
stripped_sd[key[len(prefix) :]] = value
if isinstance(key, str) and key.startswith(prefix_to_strip):
stripped_sd[key[len(prefix_to_strip) :]] = value
else:
stripped_sd[key] = value
sd = stripped_sd
@ -236,6 +245,12 @@ class ZImageCheckpointModel(ModelLoader):
new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
# Filter out FP8 scale_weight and scaled_fp8 metadata keys
# These are quantization metadata that shouldn't be loaded into the model
keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
for k in keys_to_remove:
del sd[k]
# Convert to target dtype
for k in sd.keys():
sd[k] = sd[k].to(model_dtype)
@ -284,16 +299,19 @@ class ZImageGGUFCheckpointModel(ModelLoader):
# Load the GGUF state dict
sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)
# Some Z-Image GGUF models have keys prefixed with "diffusion_model."
# Check if we need to strip this prefix
has_prefix = any(k.startswith("diffusion_model.") for k in sd.keys() if isinstance(k, str))
# Some Z-Image GGUF models have keys prefixed with "diffusion_model." or
# "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
prefix_to_strip = None
for prefix in ["model.diffusion_model.", "diffusion_model."]:
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
prefix_to_strip = prefix
break
if has_prefix:
if prefix_to_strip:
stripped_sd = {}
prefix = "diffusion_model."
for key, value in sd.items():
if isinstance(key, str) and key.startswith(prefix):
stripped_sd[key[len(prefix) :]] = value
if isinstance(key, str) and key.startswith(prefix_to_strip):
stripped_sd[key[len(prefix_to_strip) :]] = value
else:
stripped_sd[key] = value
sd = stripped_sd