Handle DoRA layer device casting when the model is partially loaded.

This commit is contained in:
Ryan Dick 2025-01-24 20:24:22 +00:00
parent 5357d6e08e
commit 6c919e1bca

View File

@@ -2,6 +2,7 @@ from typing import Dict, Optional
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
@@ -58,6 +59,8 @@ class DoRALayer(LoRALayerBase):
return self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
orig_weight = cast_to_device(orig_weight, self.up.device)
# Note: Variable names (e.g. delta_v) are based on the paper.
delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
delta_v = delta_v.reshape(orig_weight.shape)