Fix the _autocast_forward_with_patches() function for CustomConv1d and CustomConv2d.

This commit is contained in:
Ryan Dick 2024-12-29 00:22:37 +00:00
parent 2855bb6b41
commit 0525f967c2
2 changed files with 6 additions and 6 deletions

View File

@ -12,9 +12,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"])
return torch.nn.functional.conv1d(input, weight, bias)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)

View File

@ -12,9 +12,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"])
return torch.nn.functional.conv2d(input, weight, bias)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)