fix: detect FLUX.2 Klein 9B Base variant via filename heuristic (#9011)

Klein 9B Base (undistilled) and Klein 9B (distilled) have identical
architectures and cannot be distinguished from the state dict alone.
Use a filename heuristic (a case-insensitive "base" substring in the name)
to detect the Base variant for checkpoint, GGUF, and diffusers-format models.

Also fixes the incorrect guidance_embeds-based detection for diffusers
format, since both variants have guidance_embeds=False.
This commit is contained in:
Alexander Eichhorn 2026-04-07 04:31:33 +02:00 committed by GitHub
parent f08b802968
commit dbbf28925b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
return False
def _filename_suggests_base(name: str) -> bool:
"""Check if a model name/filename suggests it is a Base (undistilled) variant.
Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished
from the state dict. We use the filename as a heuristic: filenames containing "base"
(e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model.
"""
return "base" in name.lower()
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
"""Determine FLUX.2 variant from state dict.
@ -330,9 +340,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
We default to Klein9B (distilled) for all 9B models since GGUF models may not
include guidance embedding keys needed to distinguish them.
Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
and cannot be distinguished from the state dict alone. This function defaults to Klein9B
for all 9B models. Callers should use filename heuristics to detect Klein9BBase.
Supports both BFL format (checkpoint) and diffusers format keys:
- BFL format: txt_in.weight (context embedder)
@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
context_in_dim = shape[1]
# Determine variant based on context dimension
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
# Default to Klein9B (distilled) - the official/common 9B model
# Default to Klein9B - callers use filename heuristics to detect Klein9BBase
return Flux2VariantType.Klein9B
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
# Klein 9B Base and Klein 9B have identical architectures.
# Use filename heuristic to detect the Base (undistilled) variant.
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
return variant
@classmethod
@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
# Klein 9B Base and Klein 9B have identical architectures.
# Use filename heuristic to detect the Base (undistilled) variant.
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
return variant
@classmethod
@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
we check guidance_embeds:
- Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
- Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
"""
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
guidance_embeds = transformer_config.get("guidance_embeds", False)
# Determine variant based on joint_attention_dim
if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
# Check guidance_embeds to distinguish distilled from undistilled
# Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
# Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
if guidance_embeds:
if _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
else:
return Flux2VariantType.Klein9B
return Flux2VariantType.Klein9B
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
elif joint_attention_dim > 4096: