fix(gguf): ensure dequantized tensors are on correct device for MPS (#8713)

When using GGUF-quantized models on MPS (Apple Silicon), the
dequantized tensors could end up on a different device than the
other operands in math operations, causing "Expected all tensors
to be on the same device" errors.

This fix ensures that after dequantization, tensors are moved to
the same device as the other tensors in the operation.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
Alexander Eichhorn 2026-01-02 01:45:50 +01:00 committed by GitHub
parent 66974841f1
commit 3b2d2ef10a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -17,21 +17,32 @@ def dequantize_and_run(func, args, kwargs):
     Also casts other floating point tensors to match the compute_dtype of GGMLTensors
     to avoid dtype mismatches in matrix operations.
     """
-    # Find the compute_dtype from any GGMLTensor in the args
+    # Find the compute_dtype and target_device from any GGMLTensor in the args
     compute_dtype = None
+    target_device = None
     for a in args:
         if hasattr(a, "compute_dtype"):
             compute_dtype = a.compute_dtype
-            break
+        if isinstance(a, torch.Tensor) and target_device is None:
+            target_device = a.device
+        if compute_dtype is not None and target_device is not None:
+            break
-    if compute_dtype is None:
+    if compute_dtype is None or target_device is None:
         for v in kwargs.values():
-            if hasattr(v, "compute_dtype"):
+            if hasattr(v, "compute_dtype") and compute_dtype is None:
                 compute_dtype = v.compute_dtype
-                break
+            if isinstance(v, torch.Tensor) and target_device is None:
+                target_device = v.device
+            if compute_dtype is not None and target_device is not None:
+                break
     def process_tensor(t):
         if hasattr(t, "get_dequantized_tensor"):
-            return t.get_dequantized_tensor()
+            result = t.get_dequantized_tensor()
+            # Ensure the dequantized tensor is on the target device
+            if target_device is not None and result.device != target_device:
+                result = result.to(target_device)
+            return result
         elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
             # Cast other floating point tensors to match the GGUF compute_dtype
             return t.to(compute_dtype)