Add state dict tensor shapes for existing LoRA unit tests.

This commit is contained in:
Ryan Dick 2024-11-13 21:59:48 +00:00
parent 50897ba066
commit 8ef8bd4261
5 changed files with 3039 additions and 3039 deletions

View File

@ -1,8 +1,8 @@
import torch
def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
state_dict: dict[str, torch.Tensor] = {}
for k in keys:
state_dict[k] = torch.empty(1)
for k, shape in keys.items():
state_dict[k] = torch.empty(shape)
return state_dict

View File

@ -23,7 +23,7 @@ from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_s
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
@ -83,7 +83,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)