diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index dcacf8d492..1619c9d6f0 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -239,6 +239,52 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool: if in_dim is not None: return in_dim in _FLUX2_VEC_IN_DIMS + # Kohya format: check transformer block dimensions (hidden_size and MLP ratio). + # This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder). + # Klein 9B has hidden_size=4096 (vs 3072 for FLUX.1 and Klein 4B). + # Klein 4B has same hidden_size as FLUX.1 (3072) but different mlp_ratio (6 vs 4). + kohya_hidden_size: int | None = None + for key in state_dict: + if not isinstance(key, str): + continue + if not key.startswith("lora_unet_"): + continue + + # Check img_attn_proj hidden_size + if "_img_attn_proj." in key and key.endswith("lora_down.weight"): + kohya_hidden_size = state_dict[key].shape[1] + if kohya_hidden_size != _FLUX1_HIDDEN_SIZE: + return True + break + # LoKR variant + elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim != _FLUX1_HIDDEN_SIZE: + return True + kohya_hidden_size = in_dim + break + + # Kohya format: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B. + # Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288). + if kohya_hidden_size == _FLUX1_HIDDEN_SIZE: + for key in state_dict: + if not isinstance(key, str): + continue + if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith("lora_up.weight"): + ffn_dim = state_dict[key].shape[0] + if ffn_dim != kohya_hidden_size * _FLUX1_MLP_RATIO: + return True + break + # LoKR variant + if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith((".lokr_w1", ".lokr_w1_a")): + layer_prefix = key.rsplit(".", 1)[0] + out_dim = _lokr_out_dim(state_dict, layer_prefix) + if out_dim is not None and out_dim != kohya_hidden_size * _FLUX1_MLP_RATIO: + return True + break + return False @@ -421,6 +467,33 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp return Flux2VariantType.Klein9B return None + # Kohya format: check transformer block dimensions (hidden_size from img_attn_proj). + # This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder). + for key in state_dict: + if not isinstance(key, str): + continue + if not key.startswith("lora_unet_"): + continue + + # Check img_attn_proj hidden_size + if "_img_attn_proj." in key and key.endswith("lora_down.weight"): + dim = state_dict[key].shape[1] + if dim == KLEIN_4B_HIDDEN_SIZE: + return Flux2VariantType.Klein4B + if dim == KLEIN_9B_HIDDEN_SIZE: + return Flux2VariantType.Klein9B + return None + # LoKR variant + elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim == KLEIN_4B_HIDDEN_SIZE: + return Flux2VariantType.Klein4B + if in_dim == KLEIN_9B_HIDDEN_SIZE: + return Flux2VariantType.Klein9B + return None + return None