mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-18 12:45:56 +02:00
fix: detect FLUX.2 Klein 9B Base variant via filename heuristic (#9011)
Klein 9B Base (undistilled) and Klein 9B (distilled) have identical
architectures and cannot be distinguished from the state dict alone.
Use a filename heuristic ("base" in the name) to detect the Base
variant for checkpoint, GGUF, and diffusers format models.
Also fixes the incorrect guidance_embeds-based detection for diffusers
format, since both variants have guidance_embeds=False.
This commit is contained in:
parent
f08b802968
commit
dbbf28925b
@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _filename_suggests_base(name: str) -> bool:
|
||||
"""Check if a model name/filename suggests it is a Base (undistilled) variant.
|
||||
|
||||
Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished
|
||||
from the state dict. We use the filename as a heuristic: filenames containing "base"
|
||||
(e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model.
|
||||
"""
|
||||
return "base" in name.lower()
|
||||
|
||||
|
||||
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
|
||||
"""Determine FLUX.2 variant from state dict.
|
||||
|
||||
@ -330,9 +340,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
|
||||
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
|
||||
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
|
||||
|
||||
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
|
||||
We default to Klein9B (distilled) for all 9B models since GGUF models may not
|
||||
include guidance embedding keys needed to distinguish them.
|
||||
Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
|
||||
and cannot be distinguished from the state dict alone. This function defaults to Klein9B
|
||||
for all 9B models. Callers should use filename heuristics to detect Klein9BBase.
|
||||
|
||||
Supports both BFL format (checkpoint) and diffusers format keys:
|
||||
- BFL format: txt_in.weight (context embedder)
|
||||
@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
|
||||
context_in_dim = shape[1]
|
||||
# Determine variant based on context dimension
|
||||
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
|
||||
# Default to Klein9B (distilled) - the official/common 9B model
|
||||
# Default to Klein9B - callers use filename heuristics to detect Klein9BBase
|
||||
return Flux2VariantType.Klein9B
|
||||
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
|
||||
return Flux2VariantType.Klein4B
|
||||
@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con
|
||||
if variant is None:
|
||||
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
|
||||
|
||||
# Klein 9B Base and Klein 9B have identical architectures.
|
||||
# Use filename heuristic to detect the Base (undistilled) variant.
|
||||
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
|
||||
return Flux2VariantType.Klein9BBase
|
||||
|
||||
return variant
|
||||
|
||||
@classmethod
|
||||
@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba
|
||||
if variant is None:
|
||||
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
|
||||
|
||||
# Klein 9B Base and Klein 9B have identical architectures.
|
||||
# Use filename heuristic to detect the Base (undistilled) variant.
|
||||
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
|
||||
return Flux2VariantType.Klein9BBase
|
||||
|
||||
return variant
|
||||
|
||||
@classmethod
|
||||
@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
|
||||
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
|
||||
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
|
||||
|
||||
To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
|
||||
we check guidance_embeds:
|
||||
- Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
|
||||
- Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
|
||||
|
||||
Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
|
||||
Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
|
||||
and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
|
||||
"""
|
||||
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
|
||||
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
|
||||
@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
|
||||
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
|
||||
|
||||
joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
|
||||
guidance_embeds = transformer_config.get("guidance_embeds", False)
|
||||
|
||||
# Determine variant based on joint_attention_dim
|
||||
if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
|
||||
# Check guidance_embeds to distinguish distilled from undistilled
|
||||
# Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
|
||||
# Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
|
||||
if guidance_embeds:
|
||||
if _filename_suggests_base(mod.name):
|
||||
return Flux2VariantType.Klein9BBase
|
||||
else:
|
||||
return Flux2VariantType.Klein9B
|
||||
return Flux2VariantType.Klein9B
|
||||
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
|
||||
return Flux2VariantType.Klein4B
|
||||
elif joint_attention_dim > 4096:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user