diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py
new file mode 100644
index 0000000000..719a559dd0
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py
@@ -0,0 +1,93 @@
+from typing import Any
+
+import torch
+
+
+class CachedModelOnlyFullLoad:
+    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
+    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
+    MPS memory, etc.
+    """
+
+    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
+        """Initialize a CachedModelOnlyFullLoad.
+        Args:
+            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
+            compute_device (torch.device): The compute device to move the model to.
+            total_bytes (int): The total size (in bytes) of all the weights in the model.
+        """
+        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
+        self._model = model
+        self._compute_device = compute_device
+        self._offload_device = torch.device("cpu")
+
+        # A CPU read-only copy of the model's state dict.
+        self._cpu_state_dict: dict[str, torch.Tensor] | None = None
+        if isinstance(model, torch.nn.Module):
+            self._cpu_state_dict = model.state_dict()
+
+        self._total_bytes = total_bytes
+        self._is_in_vram = False
+
+    @property
+    def model(self) -> torch.nn.Module:
+        return self._model
+
+    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
+        """Get a read-only copy of the model's state dict in RAM."""
+        # TODO(ryand): Document this better.
+        return self._cpu_state_dict
+
+    def total_bytes(self) -> int:
+        """Get the total size (in bytes) of all the weights in the model."""
+        return self._total_bytes
+
+    def cur_vram_bytes(self) -> int:
+        """Get the size (in bytes) of the weights that are currently in VRAM."""
+        if self._is_in_vram:
+            return self._total_bytes
+        else:
+            return 0
+
+    def is_in_vram(self) -> bool:
+        """Return true if the model is currently in VRAM."""
+        return self._is_in_vram
+
+    def full_load_to_vram(self) -> int:
+        """Load all weights into VRAM (if supported by the model).
+        Returns:
+            The number of bytes loaded into VRAM.
+        """
+        if self._is_in_vram:
+            # Already in VRAM.
+            return 0
+
+        if not hasattr(self._model, "to"):
+            # Model doesn't support moving to a device.
+            return 0
+
+        if self._cpu_state_dict is not None:
+            new_state_dict: dict[str, torch.Tensor] = {}
+            for k, v in self._cpu_state_dict.items():
+                new_state_dict[k] = v.to(self._compute_device, copy=True)
+            self._model.load_state_dict(new_state_dict, assign=True)
+        self._model.to(self._compute_device)
+
+        self._is_in_vram = True
+        return self._total_bytes
+
+    def full_unload_from_vram(self) -> int:
+        """Unload all weights from VRAM.
+        Returns:
+            The number of bytes unloaded from VRAM.
+        """
+        if not self._is_in_vram:
+            # Already in RAM.
+            return 0
+
+        if self._cpu_state_dict is not None:
+            self._model.load_state_dict(self._cpu_state_dict, assign=True)
+        self._model.to(self._offload_device)
+
+        self._is_in_vram = False
+        return self._total_bytes
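For readers skimming the patch, here is a minimal usage sketch of the new wrapper (not part of this diff; it assumes a CUDA device, and the `total_bytes` estimate below is just one simple way to compute the value the caller is responsible for supplying):

```python
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)

model = torch.nn.Linear(10, 10)
# A simple estimate of the total weight size in bytes.
total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

cached = CachedModelOnlyFullLoad(model=model, compute_device=torch.device("cuda"), total_bytes=total_bytes)

cached.full_load_to_vram()  # returns total_bytes; a second call would return 0
y = cached.model(torch.randn(1, 10, device="cuda"))
cached.full_unload_from_vram()  # reassigns the retained CPU tensors; nothing is copied back from VRAM
```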
diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py
new file mode 100644
index 0000000000..76a3774288
--- /dev/null
+++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py
@@ -0,0 +1,122 @@
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
+    CachedModelOnlyFullLoad,
+)
+from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
+
+
+class NonTorchModel:
+    """A model that does not sub-class torch.nn.Module."""
+
+    def __init__(self):
+        self.linear = torch.nn.Linear(10, 32)
+
+    def run_inference(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_total_bytes(device: str):
+    model = DummyModule()
+    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    assert cached_model.total_bytes() == 100
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_is_in_vram(device: str):
+    model = DummyModule()
+    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    assert not cached_model.is_in_vram()
+    assert cached_model.cur_vram_bytes() == 0
+
+    cached_model.full_load_to_vram()
+    assert cached_model.is_in_vram()
+    assert cached_model.cur_vram_bytes() == 100
+
+    cached_model.full_unload_from_vram()
+    assert not cached_model.is_in_vram()
+    assert cached_model.cur_vram_bytes() == 0
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_full_load_and_unload(device: str):
+    model = DummyModule()
+    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    assert cached_model.full_load_to_vram() == 100
+    assert cached_model.is_in_vram()
+    assert all(p.device.type == device for p in cached_model.model.parameters())
+
+    assert cached_model.full_unload_from_vram() == 100
+    assert not cached_model.is_in_vram()
+    assert all(p.device.type == "cpu" for p in cached_model.model.parameters())
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_get_cpu_state_dict(device: str):
+    model = DummyModule()
+    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    assert not cached_model.is_in_vram()
+
+    # The CPU state dict can be accessed and has the expected properties.
+    cpu_state_dict = cached_model.get_cpu_state_dict()
+    assert cpu_state_dict is not None
+    assert len(cpu_state_dict) == len(model.state_dict())
+    assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
+
+    # Full load the model into VRAM.
+    cached_model.full_load_to_vram()
+    assert cached_model.is_in_vram()
+
+    # The CPU state dict is still available, and still on the CPU.
+    cpu_state_dict = cached_model.get_cpu_state_dict()
+    assert cpu_state_dict is not None
+    assert len(cpu_state_dict) == len(model.state_dict())
+    assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
+
+
+@parameterize_mps_and_cuda
+def test_cached_model_full_load_and_inference(device: str):
+    model = DummyModule()
+    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    assert not cached_model.is_in_vram()
+
+    # Run inference on the CPU.
+    x = torch.randn(1, 10)
+    output1 = model(x)
+    assert output1.device.type == "cpu"
+
+    # Full load the model into VRAM.
+    cached_model.full_load_to_vram()
+    assert cached_model.is_in_vram()
+
+    # Run inference on the GPU.
+    output2 = model(x.to(device))
+    assert output2.device.type == device
+
+    # The outputs should be the same for both runs.
+    assert torch.allclose(output1, output2.to("cpu"))
+
+
+@parameterize_mps_and_cuda
+def test_non_torch_model(device: str):
+    model = NonTorchModel()
+    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    assert not cached_model.is_in_vram()
+
+    # The model does not have a CPU state dict.
+    assert cached_model.get_cpu_state_dict() is None
+
+    # Attempting to load the model into VRAM should have no effect.
+    cached_model.full_load_to_vram()
+    assert not cached_model.is_in_vram()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Attempting to unload the model from VRAM should have no effect.
+    cached_model.full_unload_from_vram()
+    assert not cached_model.is_in_vram()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Running inference on the CPU should work.
+    output1 = model.run_inference(torch.randn(1, 10))
+    assert output1.device.type == "cpu"
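The load/unload round trip these tests exercise leans on `load_state_dict(..., assign=True)` (PyTorch 2.1+), which swaps the module's parameters for the provided tensors instead of copying into them. A standalone sketch of the pattern in plain PyTorch (not part of this diff; assumes a CUDA device):

```python
import torch

m = torch.nn.Linear(4, 4)
cpu_sd = m.state_dict()  # holds references to the original CPU tensors

# "Load": copy each tensor to the device and assign; cpu_sd is left untouched.
m.load_state_dict({k: v.to("cuda", copy=True) for k, v in cpu_sd.items()}, assign=True)
assert all(p.device.type == "cuda" for p in m.parameters())
assert all(t.device.type == "cpu" for t in cpu_sd.values())

# "Unload": assign the original CPU tensors back; nothing is copied off the GPU.
m.load_state_dict(cpu_sd, assign=True)
assert all(p.device.type == "cpu" for p in m.parameters())
```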
diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py
index 6a8140d379..e3c99d0c34 100644
--- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py
+++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py
@@ -1,6 +1,5 @@
 import itertools
 
-import pytest
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
@@ -8,35 +7,7 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_w
 )
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
-
-
-class DummyModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(10, 32)
-        self.linear2 = torch.nn.Linear(32, 64)
-        self.register_buffer("buffer1", torch.ones(64))
-        # Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled
-        # correctly by the partial loading code.
-        self.register_buffer("buffer2", torch.ones(64), persistent=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = x + self.buffer1
-        x = x + self.buffer2
-        return x
-
-
-parameterize_mps_and_cuda = pytest.mark.parametrize(
-    ("device"),
-    [
-        pytest.param(
-            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
-        ),
-        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
-    ],
-)
+from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
 
 
 @parameterize_mps_and_cuda
diff --git a/tests/backend/model_manager/load/model_cache/cached_model/utils.py b/tests/backend/model_manager/load/model_cache/cached_model/utils.py
new file mode 100644
index 0000000000..9554299e06
--- /dev/null
+++ b/tests/backend/model_manager/load/model_cache/cached_model/utils.py
@@ -0,0 +1,31 @@
+import pytest
+import torch
+
+
+class DummyModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(10, 32)
+        self.linear2 = torch.nn.Linear(32, 64)
+        self.register_buffer("buffer1", torch.ones(64))
+        # Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled
+        # correctly by the partial loading code.
+        self.register_buffer("buffer2", torch.ones(64), persistent=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = x + self.buffer1
+        x = x + self.buffer2
+        return x
+
+
+parameterize_mps_and_cuda = pytest.mark.parametrize(
+    ("device"),
+    [
+        pytest.param(
+            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
+        ),
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
+    ],
+)
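The comment in `DummyModule` about non-persistent buffers is the subtle case these shared fixtures exist to exercise. A quick standalone sketch of the behavior in plain PyTorch (not part of this diff):

```python
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b1", torch.ones(2))                    # persistent: included in state_dict
        self.register_buffer("b2", torch.ones(2), persistent=False)  # excluded from state_dict


m = M()
assert "b1" in m.state_dict()
assert "b2" not in m.state_dict()
# This is why CachedModelOnlyFullLoad still calls model.to(device) after the
# state-dict assignment: buffers like b2 would otherwise be left behind on the old device.
```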