From dbbf28925b7fb8ae06d5e89b407b86aa8dc6a5c4 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 7 Apr 2026 04:31:33 +0200 Subject: [PATCH] fix: detect FLUX.2 Klein 9B Base variant via filename heuristic (#9011) Klein 9B Base (undistilled) and Klein 9B (distilled) have identical architectures and cannot be distinguished from the state dict alone. Use a filename heuristic ("base" in the name) to detect the Base variant for checkpoint, GGUF, and diffusers format models. Also fixes the incorrect guidance_embeds-based detection for diffusers format, since both variants have guidance_embeds=False. --- .../backend/model_manager/configs/main.py | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 6f737ceb92..dff887f7d0 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool: return False +def _filename_suggests_base(name: str) -> bool: + """Check if a model name/filename suggests it is a Base (undistilled) variant. + + Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished + from the state dict. We use the filename as a heuristic: filenames containing "base" + (e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model. + """ + return "base" in name.lower() + + def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None: """Determine FLUX.2 variant from state dict. @@ -330,9 +340,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560) - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096) - Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare. - We default to Klein9B (distilled) for all 9B models since GGUF models may not - include guidance embedding keys needed to distinguish them. + Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and cannot be distinguished from the state dict alone. This function defaults to Klein9B + for all 9B models. Callers should use filename heuristics to detect Klein9BBase. Supports both BFL format (checkpoint) and diffusers format keys: - BFL format: txt_in.weight (context embedder) @@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N context_in_dim = shape[1] # Determine variant based on context dimension if context_in_dim == KLEIN_9B_CONTEXT_DIM: - # Default to Klein9B (distilled) - the official/common 9B model + # Default to Klein9B - callers use filename heuristics to detect Klein9BBase return Flux2VariantType.Klein9B elif context_in_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B @@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size) - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size) - To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled), - we check guidance_embeds: - - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation) - - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference) - - Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False. + Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and both have guidance_embeds=False. We use a filename heuristic to detect Base models. """ KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560 KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096 @@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json") joint_attention_dim = transformer_config.get("joint_attention_dim", 4096) - guidance_embeds = transformer_config.get("guidance_embeds", False) # Determine variant based on joint_attention_dim if joint_attention_dim == KLEIN_9B_CONTEXT_DIM: - # Check guidance_embeds to distinguish distilled from undistilled - # Klein 9B (distilled): guidance_embeds = False (guidance is baked in) - # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance) - if guidance_embeds: + if _filename_suggests_base(mod.name): return Flux2VariantType.Klein9BBase - else: - return Flux2VariantType.Klein9B + return Flux2VariantType.Klein9B elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B elif joint_attention_dim > 4096: