diff --git a/tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py index b80577a349..2bdb883faf 100644 --- a/tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py @@ -6,6 +6,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut lora_model_from_flux_diffusers_state_dict, ) from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import ( + state_dict_keys as flux_onetrainer_state_dict_keys, +) from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import ( state_dict_keys as flux_diffusers_state_dict_keys, ) @@ -27,12 +30,13 @@ def test_is_state_dict_likely_in_flux_diffusers_format_true(sd_keys: dict[str, l assert is_state_dict_likely_in_flux_diffusers_format(state_dict) -def test_is_state_dict_likely_in_flux_diffusers_format_false(): +@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_onetrainer_state_dict_keys]) +def test_is_state_dict_likely_in_flux_diffusers_format_false(sd_keys: dict[str, list[int]]): """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya FLUX LoRA format. """ # Construct a state dict that is not in the Kohya FLUX LoRA format. - state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys) + state_dict = keys_to_mock_state_dict(sd_keys) assert not is_state_dict_likely_in_flux_diffusers_format(state_dict) diff --git a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py index 4c58c11586..52b8ecc9c9 100644 --- a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py @@ -13,6 +13,9 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import ( FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX, ) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import ( + state_dict_keys as flux_onetrainer_state_dict_keys, +) from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import ( state_dict_keys as flux_diffusers_state_dict_keys, ) @@ -34,11 +37,12 @@ def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[ assert is_state_dict_likely_in_flux_kohya_format(state_dict) -def test_is_state_dict_likely_in_flux_kohya_format_false(): +@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys]) +def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]): """Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers FLUX LoRA format. """ - state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys) + state_dict = keys_to_mock_state_dict(sd_keys) assert not is_state_dict_likely_in_flux_kohya_format(state_dict)