From 3b2d2ef10a870e0b00ed738cf11f37ab0fd8cf5c Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Fri, 2 Jan 2026 01:45:50 +0100 Subject: [PATCH] fix(gguf): ensure dequantized tensors are on correct device for MPS (#8713) When using GGUF-quantized models on MPS (Apple Silicon), the dequantized tensors could end up on a different device than the other operands in math operations, causing "Expected all tensors to be on the same device" errors. This fix ensures that after dequantization, tensors are moved to the same device as the other tensors in the operation. Co-authored-by: Lincoln Stein --- .../backend/quantization/gguf/ggml_tensor.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index f9cf67c0ee..af895fb3ee 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -17,21 +17,32 @@ def dequantize_and_run(func, args, kwargs): Also casts other floating point tensors to match the compute_dtype of GGMLTensors to avoid dtype mismatches in matrix operations. """ - # Find the compute_dtype from any GGMLTensor in the args + # Find the compute_dtype and target_device from any GGMLTensor in the args compute_dtype = None + target_device = None for a in args: if hasattr(a, "compute_dtype"): compute_dtype = a.compute_dtype + if isinstance(a, torch.Tensor) and target_device is None: + target_device = a.device + if compute_dtype is not None and target_device is not None: break - if compute_dtype is None: + if compute_dtype is None or target_device is None: for v in kwargs.values(): - if hasattr(v, "compute_dtype"): + if hasattr(v, "compute_dtype") and compute_dtype is None: compute_dtype = v.compute_dtype + if isinstance(v, torch.Tensor) and target_device is None: + target_device = v.device + if compute_dtype is not None and target_device is not None: break def process_tensor(t): if hasattr(t, "get_dequantized_tensor"): - return t.get_dequantized_tensor() + result = t.get_dequantized_tensor() + # Ensure the dequantized tensor is on the target device + if target_device is not None and result.device != target_device: + result = result.to(target_device) + return result elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point(): # Cast other floating point tensors to match the GGUF compute_dtype return t.to(compute_dtype)