mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-27 17:22:31 +02:00
Fix the _autocast_forward_with_patches() function for CustomConv1d and CustomConv2d.
This commit is contained in:
parent
2855bb6b41
commit
0525f967c2
@ -12,9 +12,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
|
||||
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
|
||||
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
|
||||
weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
|
||||
bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"])
|
||||
return torch.nn.functional.conv1d(input, weight, bias)
|
||||
weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
|
||||
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
|
||||
@ -12,9 +12,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
|
||||
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
|
||||
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
|
||||
weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
|
||||
bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"])
|
||||
return torch.nn.functional.conv2d(input, weight, bias)
|
||||
weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
|
||||
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user