Add is_state_dict_likely_in_flux_onetrainer_format() util function.

This commit is contained in:
Ryan Dick 2025-01-21 16:39:59 +00:00
parent 8b4f411f7b
commit 5bd6428fdd
3 changed files with 92 additions and 1 deletions

View File

@ -26,6 +26,14 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.alpha
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
@ -34,7 +42,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
)

View File

@ -0,0 +1,37 @@
import re
from typing import Any, Dict
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
FLUX_KOHYA_CLIP_KEY_REGEX,
FLUX_KOHYA_T5_KEY_REGEX,
)
# A regex pattern that matches all of the transformer keys in the OneTrainer FLUX LoRA format.
# The OneTrainer format uses a mix of the Kohya and Diffusers formats:
# - The base model keys are in Diffusers format.
# - Periods are replaced with underscores, to match Kohya.
# - The LoRA key suffixes (e.g. .alpha, .lora_down.weight, .lora_up.weight) match Kohya.
# Example keys:
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.alpha"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.dora_scale"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_up.weight"
FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
r"lora_transformer_(single_transformer_blocks|transformer_blocks)_(\d+)_(\w+)\.(.*)"
)
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
Note that OneTrainer matches the Kohya format for the CLIP and T5 models.
"""
return all(
re.match(FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
)

View File

@ -0,0 +1,44 @@
import pytest
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
state_dict_keys as flux_kohya_te1_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_onetrainer_format_true():
"""Test that is_state_dict_likely_in_flux_onetrainer_format() can identify a state dict in the OneTrainer
FLUX LoRA format.
"""
# Construct a state dict that is in the OneTrainer FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys)
assert is_state_dict_likely_in_flux_onetrainer_format(state_dict)
@pytest.mark.parametrize(
"sd_keys",
[
flux_kohya_state_dict_keys,
flux_kohya_te1_state_dict_keys,
flux_diffusers_state_dict_keys,
],
)
def test_is_state_dict_likely_in_flux_onetrainer_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_onetrainer_format() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_onetrainer_format(state_dict)