diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index c8e02a10a3..fb9751b549 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -14,7 +14,7 @@ from invokeai.backend.model_manager import (
 )
 from invokeai.backend.model_manager.config import DiffusersConfigBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
-from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
 from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
@@ -70,7 +70,7 @@ class ModelLoader(ModelLoaderBase):
     def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLocker:
         stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
         try:
-            return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
+            return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
         except IndexError:
             pass
 
@@ -79,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
         loaded_model = self._load_model(config, submodel_type)
 
         self._ram_cache.put(
-            config.key,
-            submodel_type=submodel_type,
+            get_model_cache_key(config.key, submodel_type),
             model=loaded_model,
         )
 
-        return self._ram_cache.get(
-            key=config.key,
-            submodel_type=submodel_type,
-            stats_name=stats_name,
-        )
+        return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
 
     def get_size_fs(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py
index 28c1d64865..de9e31917c 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py
@@ -26,6 +26,13 @@ GB = 2**30
 MB = 2**20
 
 
+def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
+    if submodel_type:
+        return f"{model_key}:{submodel_type.value}"
+    else:
+        return model_key
+
+
 class ModelCache:
     """A cache for managing models in memory.
 
@@ -159,10 +166,8 @@ class ModelCache:
         self,
         key: str,
         model: AnyModel,
-        submodel_type: Optional[SubModelType] = None,
     ) -> None:
-        """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
+        """Insert model into the cache."""
         if key in self._cached_models:
             return
         size = calc_model_size_by_data(self.logger, model)
@@ -177,20 +182,15 @@ class ModelCache:
     def get(
         self,
         key: str,
-        submodel_type: Optional[SubModelType] = None,
         stats_name: Optional[str] = None,
     ) -> ModelLocker:
-        """
-        Retrieve model using key and optional submodel_type.
+        """Retrieve a model from the cache.
 
-        :param key: Opaque model key
-        :param submodel_type: Type of the submodel to fetch
-        :param stats_name: A human-readable id for the model for the purposes of
-            stats reporting.
+        :param key: Model key
+        :param stats_name: A human-readable id for the model for the purposes of stats reporting.
 
-        This may raise an IndexError if the model is not in the cache.
+        Raises IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             if self.stats:
                 self.stats.hits += 1
diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
index 1f57d5c199..92b80d2c7a 100644
--- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
+++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
     DiffusersConfigBase,
     MainCheckpointConfig,
 )
+from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             if subtype == submodel_type:
                 continue
             if submodel := getattr(pipeline, subtype.value, None):
-                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
+                self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
         return getattr(pipeline, submodel_type.value)
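
A minimal usage sketch of the new helper (illustrative only, not part of the patch; it assumes SubModelType.VAE resolves to the string value "vae" and that the imports below match this PR's module paths):

    from invokeai.backend.model_manager import SubModelType
    from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key

    # Whole-model keys pass through unchanged; submodel keys gain a ":<value>" suffix.
    assert get_model_cache_key("abc123") == "abc123"
    assert get_model_cache_key("abc123", SubModelType.VAE) == "abc123:vae"

    # Callers now compose the key themselves before hitting the cache, e.g.:
    #     self._ram_cache.put(get_model_cache_key(config.key, submodel_type), model=loaded_model)
    #     self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)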