diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 775873e24c..e578a24006 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,7 +3,7 @@ from typing import Any from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp -from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo +from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo from invokeai.app.models.exceptions import CanceledException class EventServiceBase: @@ -136,7 +136,7 @@ class EventServiceBase: base_model: BaseModelType, model_type: ModelType, submodel: SubModelType, - model_info: SDModelInfo, + model_info: ModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_session_event( diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 155f1e3737..f56d3dbeac 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -13,7 +13,7 @@ from invokeai.backend.model_management.model_manager import ( BaseModelType, ModelType, SubModelType, - SDModelInfo, + ModelInfo, ) from invokeai.app.models.exceptions import CanceledException from .config import InvokeAIAppConfig @@ -49,7 +49,7 @@ class ModelManagerServiceBase(ABC): submodel: Optional[SubModelType] = None, node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, - ) -> SDModelInfo: + ) -> ModelInfo: """Retrieve the indicated model with name and type. 
submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" @@ -302,7 +302,7 @@ class ModelManagerService(ModelManagerServiceBase): submodel: Optional[SubModelType] = None, node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, - ) -> SDModelInfo: + ) -> ModelInfo: """ Retrieve the indicated model. submodel can be used to get a part (such as the vae) of a diffusers mode. @@ -539,7 +539,7 @@ class ModelManagerService(ModelManagerServiceBase): base_model: BaseModelType, model_type: ModelType, submodel: SubModelType, - model_info: Optional[SDModelInfo] = None, + model_info: Optional[ModelInfo] = None, ): if context.services.queue.is_canceled(context.graph_execution_state_id): raise CanceledException() diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 4ab2381109..3393f6e467 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -166,22 +166,13 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import CUDA_DEVICE, download_with_resume from .model_cache import ModelCache, ModelLocker -from .models import BaseModelType, SubModelType, MODEL_CLASSES +from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES # We are only starting to number the config file with release 3. # The config file version doesn't have to start at release version, but it will help # reduce confusion. CONFIG_FILE_VERSION='3.0.0' -# temporary forward definitions to avoid circular import errors. 
-class ModelLocker(object):
-    "Forward declaration"
-    pass
-
-class ModelCache(object):
-    "Forward declaration"
-    pass
-
 @dataclass
 class ModelInfo():
     context: ModelLocker
@@ -744,3 +735,54 @@ class ModelManager(object):
             resolved_path = self.globals.root_dir / source
         return resolved_path
 
+    def scan_models_directory(self):
+        """
+        Reconcile self.models with the files in the models directory.
+
+        Registered models whose path no longer exists are flagged with
+        ModelError.NotFound when their class reports save_to_config(), and are
+        dropped otherwise. Every entry found in
+        <models_path>/<base_model>/<model_type> that is not registered yet
+        gets a config built by its model class and is added to self.models.
+
+        Raises Exception when a discovered file maps to an existing key.
+        """
+        loaded_files = set()
+
+        for model_key, model_config in list(self.models.items()):
+            model_name, base_model, model_type = self.parse_key(model_key)
+            if os.path.exists(model_config.path):
+                loaded_files.add(model_config.path)
+                continue
+
+            # NOTE(review): assumes parse_key() returns values usable as
+            # MODEL_CLASSES keys, and that ModelError is imported in this module.
+            model_class = MODEL_CLASSES[base_model][model_type]
+            # save_to_config is a method, so the bare attribute is always truthy:
+            # it has to be called.
+            if model_class.save_to_config():
+                model_config.error = ModelError.NotFound
+            else:
+                self.models.pop(model_key, None)
+
+        for base_model in BaseModelType:
+            for model_type in ModelType:
+                model_class = MODEL_CLASSES[base_model][model_type]
+                models_dir = os.path.join(self.globals.models_path, base_model, model_type)
+
+                if not os.path.exists(models_dir):
+                    continue # TODO: or create all folders?
+
+                for entry_name in os.listdir(models_dir):
+                    model_path = os.path.join(models_dir, entry_name)
+                    if model_path not in loaded_files: # TODO: check
+                        model_name = Path(model_path).stem
+                        model_key = self.create_key(model_name, base_model, model_type)
+
+                        if model_key in self.models:
+                            raise Exception(f"Model with key {model_key} added twice")
+
+                        model_config: ModelConfigBase = model_class.build_config(
+                            path=model_path,
+                        )
+                        self.models[model_key] = model_config
diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py
index 2fa328a1f7..3160a76408 100644
--- a/invokeai/backend/model_management/models/__init__.py
+++ b/invokeai/backend/model_management/models/__init__.py
@@ -5,26 +5,30 @@ from .lora import LoRAModel
 #from .controlnet import ControlNetModel # TODO:
 from .textual_inversion import TextualInversionModel
 
+# TODO:
+class ControlNetModel:
+    pass
+
 MODEL_CLASSES = {
     BaseModelType.StableDiffusion1_5: {
         ModelType.Pipeline: StableDiffusion15Model,
ModelType.Vae: VaeModel, ModelType.Lora: LoRAModel, - #ModelType.ControlNet: ControlNetModel, + ModelType.ControlNet: ControlNetModel, ModelType.TextualInversion: TextualInversionModel, }, BaseModelType.StableDiffusion2: { ModelType.Pipeline: StableDiffusion2Model, ModelType.Vae: VaeModel, ModelType.Lora: LoRAModel, - #ModelType.ControlNet: ControlNetModel, + ModelType.ControlNet: ControlNetModel, ModelType.TextualInversion: TextualInversionModel, }, BaseModelType.StableDiffusion2Base: { ModelType.Pipeline: StableDiffusion2BaseModel, ModelType.Vae: VaeModel, ModelType.Lora: LoRAModel, - #ModelType.ControlNet: ControlNetModel, + ModelType.ControlNet: ControlNetModel, ModelType.TextualInversion: TextualInversionModel, }, #BaseModelType.Kandinsky2_1: { @@ -35,3 +39,11 @@ MODEL_CLASSES = { # ModelType.TextualInversion: TextualInversionModel, #}, } + +# TODO: check with openapi annotation +def get_all_model_configs(): + configs = [] + for models in MODEL_CLASSES.values(): + for model in models.values(): + configs.extend(model._get_configs()) + return configs diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 6299e28db9..476fb6bf86 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -1,5 +1,7 @@ import os +import json import torch +import safetensors.torch from pydantic import Field from typing import Literal, Optional from .base import ( @@ -8,10 +10,13 @@ from .base import ( BaseModelType, ModelType, SubModelType, + VariantType, DiffusersModel, ) from invokeai.app.services.config import InvokeAIAppConfig +ModelVariantType = VariantType # TODO: + # TODO: how to name properly class StableDiffusion15Model(DiffusersModel): @@ -20,11 +25,13 @@ class StableDiffusion15Model(DiffusersModel): class DiffusersConfig(ModelConfigBase): format: Literal["diffusers"] vae: Optional[str] = Field(None) + 
variant: ModelVariantType class CheckpointConfig(ModelConfigBase): format: Literal["checkpoint"] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) + variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): @@ -36,6 +43,86 @@ class StableDiffusion15Model(DiffusersModel): model_type=ModelType.Pipeline, ) + @staticmethod + def _fast_safetensors_reader(path: str): + checkpoint = dict() + device = torch.device("meta") + with open(path, "rb") as f: + definition_len = int.from_bytes(f.read(8), 'little') + definition_json = f.read(definition_len) + definition = json.loads(definition_json) + + if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}: + raise Exception("Supported only pytorch safetensors files") + definition.pop("__metadata__", None) + + for key, info in definition.items(): + dtype = { + "I8": torch.int8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + }[info["dtype"]] + + checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) + + return checkpoint + + + @classmethod + def read_checkpoint_meta(cls, path: str): + if path.endswith(".safetensors"): + try: + checkpoint = cls._fast_safetensors_reader(path) + except: + checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"? 
+ else: + checkpoint = torch.load(path, map_location=torch.device("meta")) + return checkpoint + + @classmethod + def build_config(cls, **kwargs): + if "format" not in kwargs: + kwargs["format"] = cls.detect_format(kwargs["path"]) + + if "variant" not in kwargs: + if kwargs["format"] == "checkpoint": + if "config" in kwargs: + ckpt_config = OmegaConf.load(kwargs["config"]) + in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] + + else: + checkpoint = cls.read_checkpoint_meta(kwargs["path"]) + checkpoint = checkpoint.get('state_dict', checkpoint) + in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + elif kwargs["format"] == "diffusers": + unet_config_path = os.path.join(kwargs["path"], "unet", "config.json") + if os.path.exists(unet_config_path): + unet_config = json.loads(unet_config_path) + in_channels = unet_config['in_channels'] + + else: + raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") + + else: + raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}") + + if in_channels == 9: + kwargs["variant"] = ModelVariantType.Inpaint + elif in_channels == 5: + kwargs["variant"] = ModelVariantType.Depth + elif in_channels == 4: + kwargs["variant"] = ModelVariantType.Normal + else: + raise Exception("Unkown stable diffusion model format") + + + return super().build_config(**kwargs) + @classmethod def save_to_config(cls) -> bool: return True