mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-21 06:06:02 +02:00
Handle DoRA layer device casting when model is partially-loaded.
This commit is contained in:
parent
5357d6e08e
commit
6c919e1bca
@ -2,6 +2,7 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
|
||||
|
||||
@ -58,6 +59,8 @@ class DoRALayer(LoRALayerBase):
|
||||
return self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_weight = cast_to_device(orig_weight, self.up.device)
|
||||
|
||||
# Note: Variable names (e.g. delta_v) are based on the paper.
|
||||
delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
delta_v = delta_v.reshape(orig_weight.shape)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user