* Wrap GGUF loader for context-managed close()

  Wrap gguf.GGUFReader and use a context manager to load memory-mapped GGUF files, so that they close automatically when no longer needed. This should prevent the "file in use in another process" errors on Windows.

* Additional check for cached state_dict

  Add a check for the cached state_dict, since path is now optional. This should stop the model manager from missing the cache and triggering the resultant memory errors.

* Appease ruff

* Further ruff appeasement

* ruff

* loaders.py fix for Linux

  No longer attempting to delete an internal object.

* loaders.py - one more _mmap ref removed

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
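To illustrate the first change, here is a minimal sketch of the context-manager idea, assuming only the public gguf.GGUFReader constructor. The class name and cleanup details are illustrative, not the actual InvokeAI wrapper:

import gguf


class GGUFReaderContext:
    """Illustrative wrapper: opens a gguf.GGUFReader on entry and drops the
    reference on exit so the memory-mapped file can be released. On Windows,
    a live mmap keeps the file locked against other processes."""

    def __init__(self, path: str) -> None:
        self._path = path
        self._reader = None

    def __enter__(self) -> "gguf.GGUFReader":
        self._reader = gguf.GGUFReader(self._path)
        return self._reader

    def __exit__(self, exc_type, exc, tb) -> None:
        # Drop the reference rather than deleting internal attributes such as
        # _mmap (which, per the commit notes, broke on Linux); garbage
        # collection then closes the underlying map and file handle.
        self._reader = None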
138 lines
5.6 KiB
Python
from pathlib import Path
from typing import Any, Optional, TypeAlias

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from safetensors import safe_open

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.silence_warnings import SilenceWarnings

StateDict: TypeAlias = dict[str | int, Any]  # When are the keys int?

logger = InvokeAILogger.get_logger()


class ModelOnDisk:
    """A utility class representing a model stored on disk."""

    def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
        self.path = path
        if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            self.name = path.stem
        else:
            self.name = path.name
        self.hash_algo = hash_algo
        # Having a cache helps users of ModelOnDisk (i.e. configs) to save state.
        # This prevents redundant computations during matching and parsing.
        self._state_dict_cache: dict[Path, Any] = {}
        self._metadata_cache: dict[Path, Any] = {}

    def hash(self) -> str:
        return ModelHash(algorithm=self.hash_algo).hash(self.path)

    def size(self) -> int:
        if self.path.is_file():
            return self.path.stat().st_size
        return sum(file.stat().st_size for file in self.path.rglob("*"))

    def weight_files(self) -> set[Path]:
        if self.path.is_file():
            return {self.path}
        extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
        return {f for f in self.path.rglob("*") if f.suffix in extensions and f.is_file()}

    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
        path = path or self.path
        if path in self._metadata_cache:
            return self._metadata_cache[path]
        try:
            # Read from the resolved path (not self.path) so the metadata comes
            # from the same file the cache entry is keyed on.
            with safe_open(path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                assert isinstance(metadata, dict)
        except Exception:
            metadata = {}

        self._metadata_cache[path] = metadata
        return metadata

    def repo_variant(self) -> Optional[ModelRepoVariant]:
        if self.path.is_file():
            return None

        weight_files = list(self.path.glob("**/*.safetensors"))
        weight_files.extend(list(self.path.glob("**/*.bin")))
        # Also collect ONNX weights so the ONNX check below is reachable.
        weight_files.extend(list(self.path.glob("**/*.onnx")))
        for x in weight_files:
            if ".fp16" in x.suffixes:
                return ModelRepoVariant.FP16
            if "openvino_model" in x.name:
                return ModelRepoVariant.OpenVINO
            if "flax_model" in x.name:
                return ModelRepoVariant.Flax
            if x.suffix == ".onnx":
                return ModelRepoVariant.ONNX
        return ModelRepoVariant.Default

    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
        if path in self._state_dict_cache:
            return self._state_dict_cache[path]

        path = self.resolve_weight_file(path)

        # Check the cache again: when called with path=None, the resolved path
        # may already have a cached state_dict.
        if path in self._state_dict_cache:
            return self._state_dict_cache[path]

        with SilenceWarnings():
            if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                scan_result = scan_file_path(path)
                if scan_result.infected_files != 0:
                    if get_config().unsafe_disable_picklescan:
                        logger.warning(
                            f"The model {path.stem} is potentially infected by malware, but picklescan is disabled. "
                            "Proceeding with caution."
                        )
                    else:
                        raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
                if scan_result.scan_err:
                    if get_config().unsafe_disable_picklescan:
                        logger.warning(
                            f"Error scanning the model at {path.stem} for malware, but picklescan is disabled. "
                            "Proceeding with caution."
                        )
                    else:
                        raise RuntimeError(f"Error scanning the model at {path.stem} for malware. Aborting import.")
                checkpoint = torch.load(path, map_location="cpu")
                assert isinstance(checkpoint, dict)
            elif path.suffix.endswith(".gguf"):
                checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
            elif path.suffix.endswith(".safetensors"):
                checkpoint = safetensors.torch.load_file(path)
            else:
                raise ValueError(f"Unrecognized model extension: {path.suffix}")

        state_dict = checkpoint.get("state_dict", checkpoint)
        self._state_dict_cache[path] = state_dict
        return state_dict

    def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
        if not path:
            weight_files = list(self.weight_files())
            match weight_files:
                case []:
                    raise ValueError("No weight files found for this model")
                case [p]:
                    return p
                case ps if len(ps) >= 2:
                    raise ValueError(
                        f"Multiple weight files found for this model: {ps}. "
                        f"Please specify the intended file using the 'path' argument"
                    )
        return path
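
For reference, a minimal usage sketch of ModelOnDisk (the model path below is hypothetical):

from pathlib import Path

mod = ModelOnDisk(Path("/models/example.safetensors"))
print(mod.name, mod.size(), mod.hash())
state_dict = mod.load_state_dict()  # loaded via safetensors and cached for reuse
metadata = mod.metadata()  # safetensors header metadata, or {} if unreadable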