diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py index 215da8ed3b..8a1bacf683 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py @@ -1,13 +1,6 @@ -import copy -from typing import TypeVar - -import bitsandbytes as bnb import torch -from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt -from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 - -T = TypeVar("T", torch.Tensor, None, torch.Tensor | None) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device # This file contains custom torch.nn.Module classes that support streaming of weights to the target device. # Each class sub-classes the original module type that is is replacing, so the following properties are preserved: @@ -15,15 +8,6 @@ T = TypeVar("T", torch.Tensor, None, torch.Tensor | None) # - Patching the weights (e.g. for LoRA) should still work if non-quantized. 
-def cast_to_device(t: T, to_device: torch.device) -> T: - if t is None: - return t - - if t.device.type != to_device.type: - return t.to(to_device) - return t - - class CustomLinear(torch.nn.Linear): def forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) @@ -64,59 +48,3 @@ class CustomEmbedding(torch.nn.Embedding): self.scale_grad_by_freq, self.sparse, ) - - -class CustomInvokeLinear8bitLt(InvokeLinear8bitLt): - def forward(self, x: torch.Tensor) -> torch.Tensor: - matmul_state = bnb.MatmulLtState() - matmul_state.threshold = self.state.threshold - matmul_state.has_fp16_weights = self.state.has_fp16_weights - matmul_state.use_pool = self.state.use_pool - matmul_state.is_training = self.training - # The underlying InvokeInt8Params weight must already be quantized. - assert self.weight.CB is not None - matmul_state.CB = cast_to_device(self.weight.CB, x.device) - matmul_state.SCB = cast_to_device(self.weight.SCB, x.device) - - # weights are cast automatically as Int8Params, but the bias has to be cast manually. - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but - # it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be - # on the wrong device. 
- return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) - - -class CustomInvokeLinearNF4(InvokeLinearNF4): - def forward(self, x: torch.Tensor) -> torch.Tensor: - bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self) - - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - if not self.compute_type_is_set: - self.set_compute_type(x) - self.compute_type_is_set = True - - inp_dtype = x.dtype - if self.compute_dtype is not None: - x = x.to(self.compute_dtype) - - bias = None if self.bias is None else self.bias.to(self.compute_dtype) - - # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it - # does not follow the tensor semantics of returning a new copy when converting to a different device). This - # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To - # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing - # this properly would require more invasive changes to the bitsandbytes library. - - # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting - # to a new device. 
- old_quant_state = copy.copy(self.weight.quant_state) - weight = cast_to_device(self.weight, x.device) - self.weight.quant_state = old_quant_state - - bias = cast_to_device(self.bias, x.device) - return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py new file mode 100644 index 0000000000..7a50a19953 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py @@ -0,0 +1,15 @@ +from typing import TypeVar + +import torch + +T = TypeVar("T", torch.Tensor, None, torch.Tensor | None) + + +def cast_to_device(t: T, to_device: torch.device) -> T: + """Helper function to cast an optional tensor to a target device.""" + if t is None: + return t + + if t.device.type != to_device.type: + return t.to(to_device) + return t diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py new file mode 100644 index 0000000000..3941a2af6b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py @@ -0,0 +1,27 @@ +import bitsandbytes as bnb +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + + +class CustomInvokeLinear8bitLt(InvokeLinear8bitLt): + def forward(self, x: torch.Tensor) -> torch.Tensor: + matmul_state = bnb.MatmulLtState() + matmul_state.threshold = self.state.threshold + matmul_state.has_fp16_weights = self.state.has_fp16_weights + matmul_state.use_pool = self.state.use_pool + matmul_state.is_training = self.training + # The 
underlying InvokeInt8Params weight must already be quantized. + assert self.weight.CB is not None + matmul_state.CB = cast_to_device(self.weight.CB, x.device) + matmul_state.SCB = cast_to_device(self.weight.SCB, x.device) + + # weights are cast automatically as Int8Params, but the bias has to be cast manually. + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but + # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be + # on the wrong device. + return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py new file mode 100644 index 0000000000..82e1050e99 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py @@ -0,0 +1,41 @@ +import copy + +import bitsandbytes as bnb +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + + +class CustomInvokeLinearNF4(InvokeLinearNF4): + def forward(self, x: torch.Tensor) -> torch.Tensor: + bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self) + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if not self.compute_type_is_set: + self.set_compute_type(x) + self.compute_type_is_set = True + + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else
self.bias.to(self.compute_dtype) + + # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it + # does not follow the tensor semantics of returning a new copy when converting to a different device). This + # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To + # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing + # this properly would require more invasive changes to the bitsandbytes library. + + # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting + # to a new device. + old_quant_state = copy.copy(self.weight.quant_state) + weight = cast_to_device(self.weight, x.device) + self.weight.quant_state = old_quant_state + + bias = cast_to_device(self.bias, x.device) + return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 59c99ab411..825eebf64e 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -5,10 +5,8 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autoc CustomConv2d, CustomEmbedding, CustomGroupNorm, - CustomInvokeLinear8bitLt, CustomLinear, ) -from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = { torch.nn.Linear: CustomLinear, @@ -16,9 +14,24 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] torch.nn.Conv2d: CustomConv2d, torch.nn.GroupNorm: CustomGroupNorm, torch.nn.Embedding: 
CustomEmbedding, - InvokeLinear8bitLt: CustomInvokeLinear8bitLt, } +try: + # These dependencies are not expected to be present on MacOS. + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + CustomInvokeLinear8bitLt, + ) + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( + CustomInvokeLinearNF4, + ) + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + + AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinear8bitLt] = CustomInvokeLinear8bitLt + AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinearNF4] = CustomInvokeLinearNF4 +except ImportError: + pass + def apply_custom_layers_to_model(model: torch.nn.Module): def apply_custom_layers(module: torch.nn.Module): diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py index 7f8c5cbbfe..e2200acb03 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py @@ -1,12 +1,17 @@ import pytest import torch -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import ( - CustomInvokeLinear8bitLt, - CustomInvokeLinearNF4, -) -from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt -from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 +if not torch.cuda.is_available(): + pytest.skip("CUDA is not available", allow_module_level=True) +else: + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + CustomInvokeLinear8bitLt, + ) + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( 
+ CustomInvokeLinearNF4, + ) + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 @pytest.fixture diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py index 1c47d297cb..91ec79d738 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py @@ -6,9 +6,14 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch apply_custom_layers_to_model, remove_custom_layers_from_model, ) -from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8 from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor +try: + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8 +except ImportError: + # This is expected to fail on MacOS + pass + cuda_and_mps = pytest.mark.parametrize( "device", [ diff --git a/tests/backend/quantization/test_bnb_llm_int8.py b/tests/backend/quantization/test_bnb_llm_int8.py index ca42e3498e..9dbed6f3a6 100644 --- a/tests/backend/quantization/test_bnb_llm_int8.py +++ b/tests/backend/quantization/test_bnb_llm_int8.py @@ -1,7 +1,10 @@ import pytest import torch -from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt +try: + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt +except ImportError: + pass def test_invoke_linear_8bit_lt_quantization():