mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-03-12 09:50:16 +01:00
fix(model_manager): detect Flux 2 Klein LoRAs in Kohya format with transformer-only keys (#8938)
LoRAs trained with musubi-tuner (and potentially other trainers) that only target transformer blocks (double_blocks/single_blocks) without embedding layers (txt_in/vector_in/context_embedder) were incorrectly classified as FLUX.1. Add fallback detection using the attention projection hidden_size and the MLP ratio derived from transformer block tensors. Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
parent
3d81edac61
commit
274d9b3a74
@@ -239,6 +239,52 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
|
||||
if in_dim is not None:
|
||||
return in_dim in _FLUX2_VEC_IN_DIMS
|
||||
|
||||
# Kohya format: check transformer block dimensions (hidden_size and MLP ratio).
|
||||
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
|
||||
# Klein 9B has hidden_size=4096 (vs 3072 for FLUX.1 and Klein 4B).
|
||||
# Klein 4B has same hidden_size as FLUX.1 (3072) but different mlp_ratio (6 vs 4).
|
||||
kohya_hidden_size: int | None = None
|
||||
for key in state_dict:
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if not key.startswith("lora_unet_"):
|
||||
continue
|
||||
|
||||
# Check img_attn_proj hidden_size
|
||||
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
|
||||
kohya_hidden_size = state_dict[key].shape[1]
|
||||
if kohya_hidden_size != _FLUX1_HIDDEN_SIZE:
|
||||
return True
|
||||
break
|
||||
# LoKR variant
|
||||
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
|
||||
layer_prefix = key.rsplit(".", 1)[0]
|
||||
in_dim = _lokr_in_dim(state_dict, layer_prefix)
|
||||
if in_dim is not None:
|
||||
if in_dim != _FLUX1_HIDDEN_SIZE:
|
||||
return True
|
||||
kohya_hidden_size = in_dim
|
||||
break
|
||||
|
||||
# Kohya format: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
|
||||
# Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
|
||||
if kohya_hidden_size == _FLUX1_HIDDEN_SIZE:
|
||||
for key in state_dict:
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith("lora_up.weight"):
|
||||
ffn_dim = state_dict[key].shape[0]
|
||||
if ffn_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
|
||||
return True
|
||||
break
|
||||
# LoKR variant
|
||||
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith((".lokr_w1", ".lokr_w1_a")):
|
||||
layer_prefix = key.rsplit(".", 1)[0]
|
||||
out_dim = _lokr_out_dim(state_dict, layer_prefix)
|
||||
if out_dim is not None and out_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
|
||||
return True
|
||||
break
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -421,6 +467,33 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
|
||||
return Flux2VariantType.Klein9B
|
||||
return None
|
||||
|
||||
# Kohya format: check transformer block dimensions (hidden_size from img_attn_proj).
|
||||
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
|
||||
for key in state_dict:
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if not key.startswith("lora_unet_"):
|
||||
continue
|
||||
|
||||
# Check img_attn_proj hidden_size
|
||||
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
|
||||
dim = state_dict[key].shape[1]
|
||||
if dim == KLEIN_4B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein4B
|
||||
if dim == KLEIN_9B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein9B
|
||||
return None
|
||||
# LoKR variant
|
||||
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
|
||||
layer_prefix = key.rsplit(".", 1)[0]
|
||||
in_dim = _lokr_in_dim(state_dict, layer_prefix)
|
||||
if in_dim is not None:
|
||||
if in_dim == KLEIN_4B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein4B
|
||||
if in_dim == KLEIN_9B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein9B
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user