mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-03-01 20:49:10 +01:00
fix(gguf): ensure dequantized tensors are on correct device for MPS (#8713)
When using GGUF-quantized models on MPS (Apple Silicon), the dequantized tensors could end up on a different device than the other operands in math operations, causing "Expected all tensors to be on the same device" errors. This fix ensures that after dequantization, tensors are moved to the same device as the other tensors in the operation. Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
parent
66974841f1
commit
3b2d2ef10a
@ -17,21 +17,32 @@ def dequantize_and_run(func, args, kwargs):
|
||||
Also casts other floating point tensors to match the compute_dtype of GGMLTensors
|
||||
to avoid dtype mismatches in matrix operations.
|
||||
"""
|
||||
# Find the compute_dtype from any GGMLTensor in the args
|
||||
# Find the compute_dtype and target_device from any GGMLTensor in the args
|
||||
compute_dtype = None
|
||||
target_device = None
|
||||
for a in args:
|
||||
if hasattr(a, "compute_dtype"):
|
||||
compute_dtype = a.compute_dtype
|
||||
if isinstance(a, torch.Tensor) and target_device is None:
|
||||
target_device = a.device
|
||||
if compute_dtype is not None and target_device is not None:
|
||||
break
|
||||
if compute_dtype is None:
|
||||
if compute_dtype is None or target_device is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "compute_dtype"):
|
||||
if hasattr(v, "compute_dtype") and compute_dtype is None:
|
||||
compute_dtype = v.compute_dtype
|
||||
if isinstance(v, torch.Tensor) and target_device is None:
|
||||
target_device = v.device
|
||||
if compute_dtype is not None and target_device is not None:
|
||||
break
|
||||
|
||||
def process_tensor(t):
|
||||
if hasattr(t, "get_dequantized_tensor"):
|
||||
return t.get_dequantized_tensor()
|
||||
result = t.get_dequantized_tensor()
|
||||
# Ensure the dequantized tensor is on the target device
|
||||
if target_device is not None and result.device != target_device:
|
||||
result = result.to(target_device)
|
||||
return result
|
||||
elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
|
||||
# Cast other floating point tensors to match the GGUF compute_dtype
|
||||
return t.to(compute_dtype)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user