Expand unit tests to test for confusion between FLUX LoRA formats.

This commit is contained in:
Ryan Dick 2025-01-21 16:48:26 +00:00
parent 5bd6428fdd
commit faa4fa02c0
2 changed files with 12 additions and 4 deletions

View File

@ -6,6 +6,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@ -27,12 +30,13 @@ def test_is_state_dict_likely_in_flux_diffusers_format_true(sd_keys: dict[str, l
assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
def test_is_state_dict_likely_in_flux_diffusers_format_false():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_diffusers_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
FLUX LoRA format.
"""
# Construct a state dict that is not in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)

View File

@ -13,6 +13,9 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@ -34,11 +37,12 @@ def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_is_state_dict_likely_in_flux_kohya_format_false():
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)