diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
index 875b95da07..789a83f9f0 100644
--- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
+++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py
@@ -15,6 +15,7 @@ from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
+from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
@@ -282,6 +283,7 @@ PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
         "multiple_loras",
         "concatenated_lora",
         "flux_control_lora",
+        "single_lokr",
     ]
 )
 def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
@@ -350,6 +352,20 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
         input = torch.randn(1, patched_in_features)
         return ([(lora_layer, 0.7)], input)
+    elif layer_type == "single_lokr":
+        lokr_layer = LoKRLayer(
+            w1=torch.randn(rank, rank),
+            w1_a=None,
+            w1_b=None,
+            w2=torch.randn(out_features // rank, in_features // rank),
+            w2_a=None,
+            w2_b=None,
+            t2=None,
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        input = torch.randn(1, in_features)
+        return ([(lokr_layer, 0.7)], input)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
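
Reviewer note on the chosen shapes: the `single_lokr` fixture gives `w1` shape `(rank, rank)` and `w2` shape `(out_features // rank, in_features // rank)`, which only works if LoKR composes its weight delta as a Kronecker product of the two factors. Below is a minimal sketch of that shape reasoning, assuming the standard LyCORIS-style composition `delta = torch.kron(w1, w2)`; the concrete dimensions and variable names here are illustrative and not taken from the test file.

```python
import torch

# Illustrative dimensions; the real test derives these from the layer under test.
out_features, in_features, rank = 64, 32, 8

# Factors shaped as in the fixture above.
w1 = torch.randn(rank, rank)                                 # (rank, rank)
w2 = torch.randn(out_features // rank, in_features // rank)  # (out/rank, in/rank)

# Assumed LoKR composition: the full weight delta is the Kronecker product of the factors.
delta = torch.kron(w1, w2)

# Kronecker product multiplies shapes, so the delta matches the base layer's weight.
assert delta.shape == (out_features, in_features)
```

If this assumption holds for `LoKRLayer`, an input of shape `(1, in_features)` is the right test vector for the patched linear layer, which is what the fixture returns.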