diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py
index a9e5bbee98..e15cf94eb5 100644
--- a/invokeai/app/services/config.py
+++ b/invokeai/app/services/config.py
@@ -184,7 +184,7 @@ class InvokeAISettings(BaseSettings):
     initconf: ClassVar[DictConfig] = None
     argparse_groups: ClassVar[Dict] = {}
 
-    def parse_args(self, argv: list = sys.argv[1:]):
+    def parse_args(self, argv: List[str] = sys.argv[1:]):
         parser = self.get_parser()
         opt = parser.parse_args(argv)
         for name in self.__fields__:
@@ -217,7 +217,7 @@ class InvokeAISettings(BaseSettings):
         return OmegaConf.to_yaml(conf)
 
     @classmethod
-    def add_parser_arguments(cls, parser):
+    def add_parser_arguments(cls, parser: argparse.ArgumentParser):
         if "type" in get_type_hints(cls):
             settings_stanza = get_args(get_type_hints(cls)["type"])[0]
         else:
diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py
index 0a98091f4a..17e9b6b38e 100644
--- a/invokeai/backend/model_management/model_search.py
+++ b/invokeai/backend/model_management/model_search.py
@@ -71,7 +71,7 @@ class ModelSearch(ABC):
             if any(
                 [
                     (path / x).exists()
-                    for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
+                    for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
                 ]
             ):
                 try:
diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py
index 762675f602..46d4071503 100644
--- a/invokeai/backend/model_manager/__init__.py
+++ b/invokeai/backend/model_manager/__init__.py
@@ -13,5 +13,6 @@ from .config import (  # noqa F401
     SchedulerPredictionType,
     SubModelType,
 )
-from .model_install import ModelInstall  # noqa F401
+from .install import ModelInstall  # noqa F401
 from .probe import ModelProbe, InvalidModelException  # noqa F401
+from .storage import DuplicateModelException  # noqa F401
diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py
new file mode 100644
index 0000000000..3f11055e7f
--- /dev/null
+++ b/invokeai/backend/model_manager/hash.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
+"""
+Fast hashing of diffusers and checkpoint-style models.
+
+Usage:
+from invokeai.backend.model_manager.hash import FastModelHash
+>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
+'a8e693a126ea5b831c96064dc569956f'
+"""
+
+import os
+import hashlib
+from imohash import hashfile
+from pathlib import Path
+from typing import Dict, Union
+
+
+class FastModelHash(object):
+    """FastModelHash object provides one public class method, hash()."""
+
+    @classmethod
+    def hash(cls, model_location: Union[str, Path]) -> str:
+        """
+        Return hexdigest string for model located at model_location.
+
+        :param model_location: Path to the model
+        """
+        model_location = Path(model_location)
+        if model_location.is_file():
+            return cls._hash_file(model_location)
+        elif model_location.is_dir():
+            return cls._hash_dir(model_location)
+        else:
+            # avoid circular import
+            from .probe import InvalidModelException
+
+            raise InvalidModelException(f"Not a valid file or directory: {model_location}")
+
+    @classmethod
+    def _hash_file(cls, model_location: Union[str, Path]) -> str:
+        """
+        Fasthash a single file and return its hexdigest.
+
+        :param model_location: Path to the model file
+        """
+        # we return the md5 of the fast file hash to shorten it;
+        # cryptographic security is not needed here
+        return hashlib.md5(hashfile(model_location)).hexdigest()
+
+    @classmethod
+    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
+        """Fasthash the files in a directory and return the md5 of the combined file hashes."""
+        components: Dict[str, str] = {}
+
+        for root, dirs, files in os.walk(model_location):
+            for file in files:
+                # Ignore the config files, which change locally,
+                # and just look at the bin files.
+                if file in ['config.json', 'model_index.json']:
+                    continue
+                path = Path(root) / file
+                fast_hash = cls._hash_file(path)
+                components.update({str(path): fast_hash})
+
+        # hash all the model hashes together, using alphabetic file order
+        md5 = hashlib.md5()
+        for path, fast_hash in sorted(components.items()):
+            md5.update(fast_hash.encode("utf-8"))
+        return md5.hexdigest()
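
What _hash_dir computes, restated as a standalone recipe (a minimal sketch; the
model path below is hypothetical):

    import hashlib
    from pathlib import Path
    from imohash import hashfile

    root = Path("/home/models/stable-diffusion-v1.5")  # hypothetical diffusers folder
    digests = {}
    for file in root.rglob("*"):
        # config.json / model_index.json are edited locally, so they are
        # skipped to keep the hash stable across installs
        if file.is_file() and file.name not in {"config.json", "model_index.json"}:
            digests[str(file)] = hashlib.md5(hashfile(file)).hexdigest()

    md5 = hashlib.md5()
    for _path, digest in sorted(digests.items()):  # alphabetic path order
        md5.update(digest.encode("utf-8"))
    print(md5.hexdigest())
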
diff --git a/invokeai/backend/model_manager/model_install.py b/invokeai/backend/model_manager/install.py
similarity index 51%
rename from invokeai/backend/model_manager/model_install.py
rename to invokeai/backend/model_manager/install.py
index 2bc14c0412..fa4142c931 100644
--- a/invokeai/backend/model_manager/model_install.py
+++ b/invokeai/backend/model_manager/install.py
@@ -19,7 +19,7 @@ Typical usage:
-  id: str = installer.install_model('/path/to/model')
+  id: str = installer.install('/path/to/model')
 
   # unregister, don't delete
-  installer.forget(id)
+  installer.unregister(id)
 
   # unregister and delete model from disk
-  installer.delete_model(id)
+  installer.delete(id)
@@ -38,10 +38,22 @@ The following exceptions may be raised:
 """
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, List
+from shutil import rmtree
+from typing import Optional, List, Union
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
-from .storage import ModelConfigStore, UnknownModelException
+from .search import ModelSearch
+from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
+from .hash import FastModelHash
+from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
+from .config import (
+    ModelType,
+    BaseModelType,
+    ModelVariantType,
+    ModelFormat,
+    SchedulerPredictionType,
+)
+
 
 class ModelInstallBase(ABC):
     """Abstract base class for InvokeAI model installation"""
@@ -65,7 +77,7 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def register(self, model_path: Path) -> str:
+    def register(self, model_path: Union[Path, str]) -> str:
         """
         Probe and register the model at model_path.
 
@@ -75,7 +87,7 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def install(self, model_path: Path) -> str:
+    def install(self, model_path: Union[Path, str]) -> str:
         """
         Probe, register and install the model in the models directory.
 
@@ -88,7 +100,7 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def forget(self, id: str):
+    def unregister(self, id: str):
         """
         Unregister the model identified by id.
 
@@ -101,7 +113,7 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def delete(self, id: str) -> str:
+    def delete(self, id: str):
         """
         Unregister and delete the model identified by id.
 
@@ -138,7 +150,7 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def hash(self, model_path: Path) -> str:
+    def hash(self, model_path: Union[Path, str]) -> str:
         """
         Compute and return the fast hash of the model.
 
@@ -155,35 +167,124 @@ class ModelInstall(ModelInstallBase):
     _logger: InvokeAILogger
     _store: ModelConfigStore
 
-    def __init__(self,
+    _legacy_configs = {
+        BaseModelType.StableDiffusion1: {
+            ModelVariantType.Normal: "v1-inference.yaml",
+            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
+        },
+        BaseModelType.StableDiffusion2: {
+            ModelVariantType.Normal: {
+                SchedulerPredictionType.Epsilon: "v2-inference.yaml",
+                SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
+            },
+            ModelVariantType.Inpaint: {
+                SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
+                SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
+            },
+        },
+        BaseModelType.StableDiffusionXL: {
+            ModelVariantType.Normal: "sd_xl_base.yaml",
+        },
+        BaseModelType.StableDiffusionXLRefiner: {
+            ModelVariantType.Normal: "sd_xl_refiner.yaml",
+        },
+    }
+
+    def __init__(self,
                  store: Optional[ModelConfigStore] = None,
                  config: Optional[InvokeAIAppConfig] = None,
                  logger: Optional[InvokeAILogger] = None
                  ):  # noqa D107 - use base class docstrings
         self._config = config or InvokeAIAppConfig.get_config()
         self._logger = logger or InvokeAILogger.getLogger()
         if store is None:
-            from .storage import ModelConfigStoreYAML
-            store = ModelConfigStoreYAML(config.model_conf_path)
+            store = ModelConfigStoreYAML(self._config.model_conf_path)
         self._store = store
 
-    def register(self, model_path: Path) -> str:  # noqa D102
-        pass
+    def register(self, model_path: Union[Path, str]) -> str:  # noqa D102
+        model_path = Path(model_path)
+        info: ModelProbeInfo = ModelProbe.probe(model_path)
+        return self._register(model_path, info)
 
-    def install(self, model_path: Path) -> str:  # noqa D102
-        pass
+    def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
+        id: str = FastModelHash.hash(model_path)
+        registration_data = dict(
+            path=model_path.as_posix(),
+            name=model_path.stem,
+            base_model=info.base_type,
+            model_type=info.model_type,
+            model_format=info.format,
+        )
+        # add 'main' specific fields
+        if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
+            try:
+                config_file = self._legacy_configs[info.base_type][info.variant_type]
+                if isinstance(config_file, dict):  # SD2 configs are subkeyed by prediction type
+                    config_file = config_file[info.prediction_type]
+            except KeyError as exc:
+                raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
+            registration_data.update(
+                config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
+            )
+        self._store.add_model(id, registration_data)
+        return id
 
-    def forget(self, id: str) -> str:  # noqa D102
-        pass
+    def install(self, model_path: Union[Path, str]) -> str:  # noqa D102
+        model_path = Path(model_path)
+        info: ModelProbeInfo = ModelProbe.probe(model_path)
+        dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
 
-    def delete(self, id: str) -> str:  # noqa D102
-        pass
+        # if the path already exists then we jigger the name to make it unique
+        counter: int = 1
+        original_stem = dest_path.stem
+        while dest_path.exists():
+            dest_path = dest_path.with_stem(f"{original_stem}_{counter:02d}")
+            counter += 1
+
+        return self._register(
+            model_path.replace(dest_path),
+            info,
+        )
+
+    def unregister(self, id: str):  # noqa D102
+        self._store.del_model(id)
+
+    def delete(self, id: str):  # noqa D102
+        model = self._store.get_model(id)
+        path = Path(model.path)
+        if path.is_dir():
+            rmtree(path)
+        else:
+            path.unlink()
+        self.unregister(id)
 
     def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
-        pass
+        search = ModelSearch()
+        search.model_found = self._scan_install if install else self._scan_register
+        self._installed = set()
+        search.search([scan_dir])
+        return list(self._installed)
 
     def garbage_collect(self) -> List[str]:  # noqa D102
-        pass
+        unregistered = list()
+        for model in self._store.all_models():
+            path = Path(model.path)
+            if not path.exists():
+                self._store.del_model(model.id)
+                unregistered.append(model.id)
+        return unregistered
 
-    def hash(self, model_path: Path) -> str:  # noqa D102
-        pass
+    def hash(self, model_path: Union[Path, str]) -> str:  # noqa D102
+        return FastModelHash.hash(model_path)
+
+    # the following two methods are callbacks to the ModelSearch object
+    def _scan_register(self, model: Path) -> bool:
+        try:
+            id = self.register(model)
+            self._logger.info(f"Registered {model} with id {id}")
+            self._installed.add(id)
+        except DuplicateModelException:
+            pass
+        return True
+
+    def _scan_install(self, model: Path) -> bool:
+        try:
+            id = self.install(model)
+            self._logger.info(f"Installed {model} with id {id}")
+            self._installed.add(id)
+        except DuplicateModelException:
+            pass
+        return True
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index 69043386f4..a014d5c2d7 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -16,11 +16,7 @@ from picklescan.scanner import scan_file_path
 import torch
 import safetensors.torch
 
-from invokeai.backend.model_management.models.base import (
-    read_checkpoint_meta,
-    InvalidModelException,
-)
-
+from .util import read_checkpoint_meta
 from .config import (
     ModelType,
     BaseModelType,
@@ -31,6 +27,9 @@ from .config import (
 from .util import SilenceWarnings, lora_token_vector_length
 
 
+class InvalidModelException(Exception):
+    """Raised when an invalid model is encountered."""
+
+
 @dataclass
 class ModelProbeInfo(object):
     """Fields describing a probed model."""
@@ -373,7 +372,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
 
     def get_format(self) -> Optional[str]:
         """Return the format of a TextualInversion embedding."""
-        return None
+        return ModelFormat.EmbeddingFile
 
     def get_base_type(self) -> BaseModelType:
         """Return BaseModelType of the checkpoint model."""
@@ -513,7 +512,7 @@ class TextualInversionFolderProbe(FolderProbeBase):
 
     def get_format(self) -> Optional[str]:
         """Return the format of the TextualInversion."""
-        return None
+        return ModelFormat.EmbeddingFolder
 
     def get_base_type(self) -> BaseModelType:
         """Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder."""
diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py
new file mode 100644
index 0000000000..5b170efd01
--- /dev/null
+++ b/invokeai/backend/model_manager/search.py
@@ -0,0 +1,138 @@
+# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
+"""
+Abstract base class for recursive directory search for models.
+"""
+
+import os
+import types
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Callable, List, Set, Union
+
+import invokeai.backend.util.logging as logger
+
+
+class ModelSearchBase(ABC):
+    """Hierarchical directory model search class"""
+
+    def __init__(self, logger: types.ModuleType = logger):
+        """
+        Initialize a recursive model directory search.
+
+        :param logger: Logger to use
+        """
+        self.logger = logger
+        self._items_scanned = 0
+        self._models_found = 0
+        self._scanned_dirs = set()
+        self._scanned_paths = set()
+        self._pruned_paths = set()
+
+    @abstractmethod
+    def on_search_started(self):
+        """
+        Called before the scan starts.
+ """ + pass + + @abstractmethod + def on_model_found(self, model: Path): + """ + Process a found model. Raise an exception if something goes wrong. + :param model: Model to process - could be a directory or checkpoint. + """ + pass + + @abstractmethod + def on_search_completed(self): + """ + Perform some activity when the scan is completed. May use instance + variables, items_scanned and models_found + """ + pass + + def search(self, directories: List[Union[Path, str]]): + self.on_search_started() + for dir in directories: + self.walk_directory(dir) + self.on_search_completed() + + def walk_directory(self, path: Union[Path, str]): + for root, dirs, files in os.walk(path, followlinks=True): + if str(Path(root).name).startswith("."): + self._pruned_paths.add(root) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self._items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path in self._scanned_paths or path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any( + [ + (path / x).exists() + for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"] + ] + ): + try: + self.on_model_found(path) + self._models_found += 1 + self._scanned_dirs.add(path) + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: + try: + self.on_model_found(path) + self._models_found += 1 + except Exception as e: + self.logger.warning(str(e)) + + +class ModelSearch(ModelSearchBase): + """ + Implementation of ModelSearch with callbacks. + Usage: + search = ModelSearch() + search.model_found = lambda path : 'anime' in path.as_posix() + found = search.list_models(['/tmp/models1','/tmp/models2']) + # returns all models that have 'anime' in the path + """ + + _model_set: Set[Path] + search_started: Callable[[Path], None] + search_completed: Callable[[Set[Path]], None] + model_found: Callable[[Path], bool] + + def __init__(self, logger: types.ModuleType = logger): + super().__init__(logger) + self._model_set = set() + self.search_started = None + self.search_completed = None + self.model_found = None + + def on_search_started(self): + self._model_set = set() + if self.search_started: + self.search_started() + + def on_model_found(self, model: Path): + if not self.model_found: + self._model_set.add(model) + return + if self.model_found(model): + self._model_set.add(model) + + def on_search_completed(self): + if self.search_completed: + self.search_completed(self._model_set) + + def list_models(self, directories: List[Union[Path,str]]) -> List[Path]: + """Return list of models found""" + self.search(directories) + return list(self._model_set) diff --git a/invokeai/backend/model_manager/storage/__init__.py b/invokeai/backend/model_manager/storage/__init__.py index 0bda32c9ea..457e3b47ae 100644 --- a/invokeai/backend/model_manager/storage/__init__.py +++ b/invokeai/backend/model_manager/storage/__init__.py @@ -1,6 +1,6 @@ """ Initialization file for invokeai.backend.model_manager.storage """ -from .base import ModelConfigStore, UnknownModelException # noqa F401 +from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401 from .yaml import ModelConfigStoreYAML # noqa F401 from .sql import ModelConfigStoreSQL # noqa F401 diff --git a/invokeai/backend/model_manager/storage/base.py 
diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py
index d3a42d768f..69930dde4d 100644
--- a/invokeai/backend/model_manager/storage/base.py
+++ b/invokeai/backend/model_manager/storage/base.py
@@ -14,10 +14,13 @@ class DuplicateModelException(Exception):
     """Raised on an attempt to add a model with the same key twice."""
 
 
+class InvalidModelException(Exception):
+    """Raised when an invalid model is detected."""
+
+
 class UnknownModelException(Exception):
     """Raised on an attempt to delete a model with a nonexistent key."""
 
-
 class ModelConfigStore(ABC):
     """Abstract base class for storage and retrieval of model configs."""
diff --git a/invokeai/backend/model_manager/util.py b/invokeai/backend/model_manager/util.py
index 04a45b7510..ecc544df93 100644
--- a/invokeai/backend/model_manager/util.py
+++ b/invokeai/backend/model_manager/util.py
@@ -2,11 +2,15 @@
 """
 Various utilities used by the model manager.
 """
-from typing import Optional
+import json
 import warnings
+import torch
+import safetensors
+from pathlib import Path
+from typing import Optional, Union
 from diffusers import logging as diffusers_logging
 from transformers import logging as transformers_logging
-
+from picklescan.scanner import scan_file_path
 
 class SilenceWarnings(object):
     """
@@ -106,3 +110,49 @@ def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
             break
 
     return lora_token_vector_length
+
+
+def _fast_safetensors_reader(path: str):
+    """Read a safetensors header and return a checkpoint of empty 'meta' tensors."""
+    checkpoint = dict()
+    device = torch.device("meta")
+    with open(path, "rb") as f:
+        definition_len = int.from_bytes(f.read(8), "little")
+        definition_json = f.read(definition_len)
+        definition = json.loads(definition_json)
+
+        if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
+            "pt",
+            "torch",
+            "pytorch",
+        }:
+            raise Exception("Only pytorch safetensors files are supported")
+        definition.pop("__metadata__", None)
+
+        for key, info in definition.items():
+            dtype = {
+                "I8": torch.int8,
+                "I16": torch.int16,
+                "I32": torch.int32,
+                "I64": torch.int64,
+                "F16": torch.float16,
+                "F32": torch.float32,
+                "F64": torch.float64,
+            }[info["dtype"]]
+
+            checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
+
+    return checkpoint
+
+
+def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
+    """Return the state dict of the checkpoint at path without loading tensor data."""
+    if str(path).endswith(".safetensors"):
+        try:
+            checkpoint = _fast_safetensors_reader(path)
+        except Exception:
+            # TODO: create issue for support "meta"?
+            checkpoint = safetensors.torch.load_file(path, device="cpu")
+    else:
+        if scan:
+            scan_result = scan_file_path(path)
+            if scan_result.infected_files != 0:
+                raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
+        checkpoint = torch.load(path, map_location=torch.device("meta"))
+    return checkpoint
diff --git a/pyproject.toml b/pyproject.toml
index 02e53f066a..431ca24e54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
   "fastapi-events==0.8.0",
   "fastapi-socketio==0.0.10",
   "huggingface-hub~=0.16.4",
+  "imohash~=1.0.0",
   "invisible-watermark~=0.2.0",  # needed to install SDXL base and refiner using their repo_ids
   "matplotlib",  # needed for plotting of Penner easing functions
   "mediapipe",  # needed for "mediapipeface" controlnet model
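
End to end, the install API added above can be exercised as follows (a minimal
sketch: the model paths are hypothetical, and it assumes a default
InvokeAIAppConfig whose models directory and YAML config store are writable):

    from pathlib import Path

    from invokeai.backend.model_manager import ModelInstall, DuplicateModelException
    from invokeai.backend.model_manager.hash import FastModelHash

    # fast, non-cryptographic content hash; doubles as the registry id
    print(FastModelHash.hash("/home/models/stable-diffusion-v1.5"))

    installer = ModelInstall()

    # register a model in place, without copying it into the models directory
    try:
        key = installer.register("/home/models/my-lora.safetensors")
    except DuplicateModelException:
        pass  # a model with the same hash is already registered

    # or recursively scan a directory, registering every model found
    keys = installer.scan_directory(Path("/home/models"), install=False)

    # drop registrations whose files no longer exist on disk
    installer.garbage_collect()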