fix(model_manager): detect Flux 2 Klein LoRAs in Kohya format with transformer-only keys (#8938)

LoRAs trained with musubi-tuner (and potentially other trainers) that
only target transformer blocks (double_blocks/single_blocks) without
embedding layers (txt_in/vector_in/context_embedder) were incorrectly
classified as Flux 1. Add fallback detection using attention projection
hidden_size and MLP ratio from transformer block tensors.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
Alexander Eichhorn 2026-03-07 02:52:25 +01:00 committed by GitHub
parent 3d81edac61
commit 274d9b3a74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -239,6 +239,52 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
if in_dim is not None:
return in_dim in _FLUX2_VEC_IN_DIMS
# Kohya format: check transformer block dimensions (hidden_size and MLP ratio).
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
# Klein 9B has hidden_size=4096 (vs 3072 for FLUX.1 and Klein 4B).
# Klein 4B has same hidden_size as FLUX.1 (3072) but different mlp_ratio (6 vs 4).
kohya_hidden_size: int | None = None
for key in state_dict:
if not isinstance(key, str):
continue
if not key.startswith("lora_unet_"):
continue
# Check img_attn_proj hidden_size
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
kohya_hidden_size = state_dict[key].shape[1]
if kohya_hidden_size != _FLUX1_HIDDEN_SIZE:
return True
break
# LoKR variant
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim != _FLUX1_HIDDEN_SIZE:
return True
kohya_hidden_size = in_dim
break
# Kohya format: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
# Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
if kohya_hidden_size == _FLUX1_HIDDEN_SIZE:
for key in state_dict:
if not isinstance(key, str):
continue
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith("lora_up.weight"):
ffn_dim = state_dict[key].shape[0]
if ffn_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
return True
break
# LoKR variant
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith((".lokr_w1", ".lokr_w1_a")):
layer_prefix = key.rsplit(".", 1)[0]
out_dim = _lokr_out_dim(state_dict, layer_prefix)
if out_dim is not None and out_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
return True
break
return False
@@ -421,6 +467,33 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
return Flux2VariantType.Klein9B
return None
# Kohya format: check transformer block dimensions (hidden_size from img_attn_proj).
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
for key in state_dict:
if not isinstance(key, str):
continue
if not key.startswith("lora_unet_"):
continue
# Check img_attn_proj hidden_size
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
dim = state_dict[key].shape[1]
if dim == KLEIN_4B_HIDDEN_SIZE:
return Flux2VariantType.Klein4B
if dim == KLEIN_9B_HIDDEN_SIZE:
return Flux2VariantType.Klein9B
return None
# LoKR variant
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_HIDDEN_SIZE:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_HIDDEN_SIZE:
return Flux2VariantType.Klein9B
return None
return None