From 2855bb6b41ca267a87a87ce06d2c425f72ad3b7d Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 28 Dec 2024 21:12:53 +0000
Subject: [PATCH] Update BaseLayerPatch.get_parameters(...) to accept a dict
 of orig_parameters rather than orig_module. This will enable compatibility
 between patching and cpu->gpu streaming.

---
 .../custom_modules/custom_module_mixin.py               |  6 +++---
 invokeai/backend/patches/layer_patcher.py               |  4 +++-
 invokeai/backend/patches/layers/base_layer_patch.py     |  2 +-
 .../backend/patches/layers/concatenated_lora_layer.py   |  2 +-
 .../backend/patches/layers/flux_control_lora_layer.py   |  6 +++---
 invokeai/backend/patches/layers/lora_layer_base.py      | 10 +++++-----
 invokeai/backend/patches/layers/set_parameter_layer.py  |  4 ++--
 .../patches/sidecar_wrappers/base_sidecar_wrapper.py    |  6 ++++--
 .../patches/layers/test_flux_control_lora_layer.py      |  2 +-
 tests/backend/patches/layers/test_lora_layer.py         |  2 +-
 .../backend/patches/layers/test_set_parameter_layer.py  |  2 +-
 11 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py
index 03c6d81e2a..58b3a610a0 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py
@@ -32,9 +32,9 @@ class CustomModuleMixin:
         params: dict[str, torch.Tensor] = {}
 
         for patch, patch_weight in patches_and_weights:
-            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
-            # module, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(self, weight=patch_weight)
+            # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
+            # parameters, this might fail or return incorrect results.
+            layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight)  # type: ignore
 
             for param_name, param_weight in layer_params.items():
                 if param_name not in params:
diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py
index d7f6bea166..0eaad184e2 100644
--- a/invokeai/backend/patches/layer_patcher.py
+++ b/invokeai/backend/patches/layer_patcher.py
@@ -166,7 +166,9 @@ class LayerPatcher:
 
         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-        for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
+        for param_name, param_weight in patch.get_parameters(
+            dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
+        ).items():
             param_key = module_to_patch_key + "." + param_name
             module_param = module_to_patch.get_parameter(param_name)
diff --git a/invokeai/backend/patches/layers/base_layer_patch.py b/invokeai/backend/patches/layers/base_layer_patch.py
index 5eb04864c8..f6f0289a90 100644
--- a/invokeai/backend/patches/layers/base_layer_patch.py
+++ b/invokeai/backend/patches/layers/base_layer_patch.py
@@ -5,7 +5,7 @@ import torch
 
 class BaseLayerPatch(ABC):
     @abstractmethod
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         """Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
         from the returned dict are not updated.
         """
diff --git a/invokeai/backend/patches/layers/concatenated_lora_layer.py b/invokeai/backend/patches/layers/concatenated_lora_layer.py
index a098a9e61b..a699a47433 100644
--- a/invokeai/backend/patches/layers/concatenated_lora_layer.py
+++ b/invokeai/backend/patches/layers/concatenated_lora_layer.py
@@ -30,7 +30,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
         layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers]  # pyright: ignore[reportArgumentType]
         return torch.cat(layer_weights, dim=self.concat_axis)
 
-    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
         # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
         # require this value, we will need to implement chunking of the original bias tensor here.
         # Note that we must apply the sub-layer scales here.
diff --git a/invokeai/backend/patches/layers/flux_control_lora_layer.py b/invokeai/backend/patches/layers/flux_control_lora_layer.py
index 142336a00a..ad592456a9 100644
--- a/invokeai/backend/patches/layers/flux_control_lora_layer.py
+++ b/invokeai/backend/patches/layers/flux_control_lora_layer.py
@@ -8,11 +8,11 @@ class FluxControlLoRALayer(LoRALayer):
     shapes don't match.
     """
 
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         """This overrides the base class behavior to skip the reshaping step."""
         scale = self.scale()
-        params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
-        bias = self.get_bias(orig_module.bias)
+        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
+        bias = self.get_bias(orig_parameters.get("bias", None))
         if bias is not None:
             params["bias"] = bias * (weight * scale)
diff --git a/invokeai/backend/patches/layers/lora_layer_base.py b/invokeai/backend/patches/layers/lora_layer_base.py
index 13669ad5d3..123e5afa2c 100644
--- a/invokeai/backend/patches/layers/lora_layer_base.py
+++ b/invokeai/backend/patches/layers/lora_layer_base.py
@@ -54,19 +54,19 @@ class LoRALayerBase(BaseLayerPatch):
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
 
-    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
         return self.bias
 
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         scale = self.scale()
-        params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
-        bias = self.get_bias(orig_module.bias)
+        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
+        bias = self.get_bias(orig_parameters.get("bias", None))
         if bias is not None:
             params["bias"] = bias * (weight * scale)
 
         # Reshape all params to match the original module's shape.
         for param_name, param_weight in params.items():
-            orig_param = orig_module.get_parameter(param_name)
+            orig_param = orig_parameters[param_name]
             if param_weight.shape != orig_param.shape:
                 params[param_name] = param_weight.reshape(orig_param.shape)
diff --git a/invokeai/backend/patches/layers/set_parameter_layer.py b/invokeai/backend/patches/layers/set_parameter_layer.py
index f0ae461f4d..1b7fe94d36 100644
--- a/invokeai/backend/patches/layers/set_parameter_layer.py
+++ b/invokeai/backend/patches/layers/set_parameter_layer.py
@@ -14,10 +14,10 @@ class SetParameterLayer(BaseLayerPatch):
         self.weight = weight
         self.param_name = param_name
 
-    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
         # Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
         # Control LoRA implementation.
-        diff = self.weight - orig_module.get_parameter(self.param_name)
+        diff = self.weight - orig_parameters[self.param_name]
         return {self.param_name: diff}
 
     def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
index c22525bc95..46d69bbe91 100644
--- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
+++ b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
@@ -39,8 +39,10 @@ class BaseSidecarWrapper(torch.nn.Module):
 
         for patch, patch_weight in patches_and_weights:
             # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
-            # module, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
+            # parameters, this might fail or return incorrect results.
+            layer_params = patch.get_parameters(
+                dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight
+            )
 
             for param_name, param_weight in layer_params.items():
                 if param_name not in params:
diff --git a/tests/backend/patches/layers/test_flux_control_lora_layer.py b/tests/backend/patches/layers/test_flux_control_lora_layer.py
index 00590c3514..129fcfcb4e 100644
--- a/tests/backend/patches/layers/test_flux_control_lora_layer.py
+++ b/tests/backend/patches/layers/test_flux_control_lora_layer.py
@@ -18,7 +18,7 @@ def test_flux_control_lora_layer_get_parameters():
     orig_module = torch.nn.Linear(small_in_features, out_features)
 
     # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == (out_features, big_in_features)
     assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)
diff --git a/tests/backend/patches/layers/test_lora_layer.py b/tests/backend/patches/layers/test_lora_layer.py
index 34f62c3bcf..c0971fb9a1 100644
--- a/tests/backend/patches/layers/test_lora_layer.py
+++ b/tests/backend/patches/layers/test_lora_layer.py
@@ -107,7 +107,7 @@ def test_lora_layer_get_parameters():
     # Create mock original module
     orig_module = torch.nn.Linear(in_features, out_features)
 
-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == orig_module.weight.shape
     assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)
diff --git a/tests/backend/patches/layers/test_set_parameter_layer.py b/tests/backend/patches/layers/test_set_parameter_layer.py
index 0bca0293f5..bdf8e33749 100644
--- a/tests/backend/patches/layers/test_set_parameter_layer.py
+++ b/tests/backend/patches/layers/test_set_parameter_layer.py
@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
    target_weight = torch.randn(8, 4)
    layer = SetParameterLayer(param_name="weight", weight=target_weight)

-    params = layer.get_parameters(orig_module, weight=1.0)
+    params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0)
     assert len(params) == 1
     new_weight = orig_module.weight + params["weight"]
     assert torch.allclose(new_weight, target_weight)
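
Note (not part of the patch): a minimal sketch of the new calling convention, mirroring test_set_parameter_layer.py above. Callers now pass a plain dict of the module's own (non-recursive) parameters instead of the module itself, which is what allows patching to work while weights are being streamed between CPU and GPU.

    # Illustrative sketch only; assumes the SetParameterLayer import path shown in the diff above.
    import torch
    from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer

    orig_module = torch.nn.Linear(4, 8)
    target_weight = torch.randn(8, 4)
    layer = SetParameterLayer(param_name="weight", weight=target_weight)

    # New signature: pass the module's immediate parameters as a dict rather than the module,
    # so the patch never has to touch the (possibly offloaded or quantized) module directly.
    orig_parameters = dict(orig_module.named_parameters(recurse=False))
    params = layer.get_parameters(orig_parameters, weight=1.0)

    # get_parameters() returns residual updates; adding them to the original
    # parameters yields the patched values.
    assert torch.allclose(orig_module.weight + params["weight"], target_weight)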