mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-27 01:02:11 +02:00
Add state dict tensor shapes for existing LoRA unit tests.
This commit is contained in:
parent
50897ba066
commit
8ef8bd4261
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
|
||||
|
||||
def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
|
||||
def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
for k, shape in keys.items():
|
||||
state_dict[k] = torch.empty(shape)
|
||||
return state_dict
|
||||
|
||||
@ -23,7 +23,7 @@ from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_s
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
|
||||
# Construct a state dict that is in the Kohya FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
@ -83,7 +83,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
|
||||
def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
|
||||
def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
|
||||
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
|
||||
# Construct a state dict that is in the Kohya FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user