added install module

This commit is contained in:
Lincoln Stein
2023-08-23 19:41:25 -04:00
parent 4b3d54dbc0
commit 9adc897302
11 changed files with 399 additions and 38 deletions

View File

@@ -184,7 +184,7 @@ class InvokeAISettings(BaseSettings):
initconf: ClassVar[DictConfig] = None
argparse_groups: ClassVar[Dict] = {}
def parse_args(self, argv: list = sys.argv[1:]):
def parse_args(self, argv: List[str] = sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
for name in self.__fields__:
@@ -217,7 +217,7 @@ class InvokeAISettings(BaseSettings):
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
def add_parser_arguments(cls, parser: argparse.ArgumentParser):
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else:

View File

@@ -71,7 +71,7 @@ class ModelSearch(ABC):
if any(
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
]
):
try:

View File

@@ -13,5 +13,6 @@ from .config import ( # noqa F401
SchedulerPredictionType,
SubModelType,
)
from .model_install import ModelInstall # noqa F401
from .install import ModelInstall # noqa F401
from .probe import ModelProbe, InvalidModelException # noqa F401
from .storage import DuplicateModelException # noqa F401

View File

@@ -0,0 +1,68 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
    from invokeai.backend.model_manager.hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import os
import hashlib
from imohash import hashfile
from pathlib import Path
from typing import Dict, Union
class FastModelHash:
    """FastModelHash provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
# avoid circular import
from .models import InvalidModelException
raise InvalidModelException(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
        # we return the md5 of imohash's sampled file hash to keep the digest short;
        # cryptographic security is not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash all files in a directory and return the combined hexdigest.
        :param model_location: Path to the model directory
        """
        components: Dict[str, str] = {}
for root, dirs, files in os.walk(model_location):
for file in files:
# Ignore the config files, which change locally,
# and just look at the bin files.
if file in ['config.json', 'model_index.json']:
continue
path = Path(root) / file
fast_hash = cls._hash_file(path)
components.update({str(path): fast_hash})
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()
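
The directory digest is thus a hash of hashes: each file is sampled by imohash, shortened to an md5 hexdigest, and the per-file digests are combined in sorted path order so the result does not depend on filesystem traversal order. A minimal standalone sketch of the combining step, using stand-in digests computed from dummy bytes rather than real model weights:

    import hashlib

    # Stand-in per-file digests, keyed by path, as _hash_dir collects them.
    # The real code derives each digest from imohash's sampled file hash.
    components = {
        "sd15/unet/diffusion_pytorch_model.bin": hashlib.md5(b"unet weights").hexdigest(),
        "sd15/vae/diffusion_pytorch_model.bin": hashlib.md5(b"vae weights").hexdigest(),
    }

    # Combine in sorted path order for a traversal-order-independent digest.
    md5 = hashlib.md5()
    for _path, digest in sorted(components.items()):
        md5.update(digest.encode("utf-8"))
    print(md5.hexdigest())  # the directory-level model id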

View File

@@ -19,7 +19,7 @@ Typical usage:
id: str = installer.install_model('/path/to/model')
# unregister, don't delete
installer.forget(id)
installer.unregister(id)
# unregister and delete model from disk
installer.delete_model(id)
@@ -38,10 +38,22 @@ The following exceptions may be raised:
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, List
from shutil import rmtree
from typing import Optional, List, Union
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .storage import ModelConfigStore, UnknownModelException
from .search import ModelSearch
from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
from .hash import FastModelHash
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
from .config import (
ModelType,
BaseModelType,
ModelVariantType,
ModelFormat,
SchedulerPredictionType,
)
class ModelInstallBase(ABC):
"""Abstract base class for InvokeAI model installation"""
@@ -65,7 +77,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def register(self, model_path: Path) -> str:
def register(self, model_path: Union[Path, str]) -> str:
"""
Probe and register the model at model_path.
@@ -75,7 +87,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def install(self, model_path: Path) -> str:
def install(self, model_path: Union[Path, str]) -> str:
"""
Probe, register and install the model in the models directory.
@@ -88,7 +100,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def forget(self, id: str):
def unregister(self, id: str):
"""
Unregister the model identified by id.
@@ -101,7 +113,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def delete(self, id: str) -> str:
def delete(self, id: str):
"""
Unregister and delete the model identified by id.
@@ -138,7 +150,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def hash(self, model_path: Path) -> str:
def hash(self, model_path: Union[Path, str]) -> str:
"""
Compute and return the fast hash of the model.
@@ -155,35 +167,124 @@ class ModelInstall(ModelInstallBase):
_logger: InvokeAILogger
_store: ModelConfigStore
def __init__(self,
_legacy_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
def __init__(self,
store: Optional[ModelConfigStore] = None,
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[InvokeAILogger] = None
): # noqa D107 - use base class docstrings
): # noqa D107 - use base class docstrings
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger()
if store is None:
from .storage import ModelConfigStoreYAML
            store = ModelConfigStoreYAML(self._config.model_conf_path)
self._store = store
def register(self, model_path: Path) -> str: # noqa D102
pass
def register(self, model_path: Union[Path, str]) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = ModelProbe.probe(model_path)
return self._register(model_path, info)
def install(self, model_path: Path) -> str: # noqa D102
pass
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
id: str = FastModelHash.hash(model_path)
registration_data = dict(
path=model_path.as_posix(),
name=model_path.stem,
base_model=info.base_type,
model_type=info.model_type,
model_format=info.format
)
# add 'main' specific fields
if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
except KeyError as exc:
raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
registration_data.update(
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
)
self._store.add_model(id, registration_data)
return id
def forget(self, id: str) -> str: # noqa D102
pass
def install(self, model_path: Union[Path, str]) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = ModelProbe.probe(model_path)
        dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
def delete(self, id: str) -> str: # noqa D102
pass
        # if the path already exists, adjust the name until it is unique
        original_stem = dest_path.stem
        counter: int = 1
        while dest_path.exists():
            dest_path = dest_path.with_stem(f"{original_stem}_{counter:02d}")
            counter += 1
        return self._register(
            model_path.replace(dest_path),
            info,
        )
def unregister(self, id: str): # noqa D102
self._store.del_model(id)
    def delete(self, id: str):  # noqa D102
        model = self._store.get_model(id)
        path = Path(model.path)
        if path.is_dir():
            rmtree(path)
        else:
            path.unlink()  # checkpoints are single files; rmtree() requires a directory
        self.unregister(id)
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
pass
search = ModelSearch()
search.model_found = self._scan_install if install else self._scan_register
self._installed = set()
search.search([scan_dir])
return list(self._installed)
def garbage_collect(self) -> List[str]: # noqa D102
pass
unregistered = list()
for model in self._store.all_models():
path = Path(model.path)
if not path.exists():
self._store.del_model(model.id)
unregistered.append(model.id)
return unregistered
def hash(self, model_path: Path) -> str: # noqa D102
pass
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
return FastModelHash.hash(model_path)
# the following two methods are callbacks to the ModelSearch object
def _scan_register(self, model: Path) -> bool:
try:
id = self.register(model)
self._logger.info(f"Registered {model} with id {id}")
self._installed.add(id)
        except DuplicateModelException:
            pass  # model already registered; skip it and continue scanning
return True
def _scan_install(self, model: Path) -> bool:
try:
id = self.install(model)
self._logger.info(f"Installed {model} with id {id}")
self._installed.add(id)
        except DuplicateModelException:
            pass  # model already installed; skip it and continue scanning
return True
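
Taken together, the new installer can be driven end to end. A hedged usage sketch, assuming the default app config and YAML store; the paths below are hypothetical, and the import follows the package __init__ change above:

    from pathlib import Path

    from invokeai.backend.model_manager import ModelInstall

    installer = ModelInstall()  # defaults to InvokeAIAppConfig.get_config() + YAML store

    # Probe, move into models/<base>/<type>/, and register.
    # The returned id is the model's fast hash.
    id = installer.install("/tmp/downloads/my-model.safetensors")  # hypothetical path

    # Register a model in place, without moving it.
    id2 = installer.register("/tmp/downloads/other-model")  # hypothetical path

    # Recursively register (install=False) or install (install=True)
    # every model found under a directory.
    ids = installer.scan_directory(Path("/tmp/model-drop"))

    # Drop store entries whose files have vanished from disk.
    removed = installer.garbage_collect()

    installer.unregister(id2)  # remove the record, keep the files
    installer.delete(id)       # remove the record and the files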

View File

@@ -16,11 +16,7 @@ from picklescan.scanner import scan_file_path
import torch
import safetensors.torch
from invokeai.backend.model_management.models.base import (
read_checkpoint_meta,
InvalidModelException,
)
from .util import read_checkpoint_meta
from .config import (
ModelType,
BaseModelType,
@@ -31,6 +27,9 @@ from .config import (
from .util import SilenceWarnings, lora_token_vector_length
class InvalidModelException(Exception):
"""Raised when an invalid model is encountered."""
@dataclass
class ModelProbeInfo(object):
"""Fields describing a probed model."""
@@ -373,7 +372,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self) -> Optional[str]:
"""Return the format of a TextualInversion emedding."""
return None
return ModelFormat.EmbeddingFile
def get_base_type(self) -> BaseModelType:
"""Return BaseModelType of the checkpoint model."""
@@ -513,7 +512,7 @@ class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self) -> Optional[str]:
"""Return the format of the TextualInversion."""
return None
return ModelFormat.EmbeddingFolder
def get_base_type(self) -> BaseModelType:
"""Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder."""

View File

@@ -0,0 +1,138 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
import types
from typing import List, Set, Optional, Callable, Union
from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearchBase(ABC):
"""Hierarchical directory model search class"""
def __init__(self, logger: types.ModuleType = logger):
"""
Initialize a recursive model directory search.
:param logger: Logger to use
"""
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
    def search(self, directories: List[Union[Path, str]]):
        """Walk the given directories, invoking the on_* callbacks as models are found."""
self.on_search_started()
for dir in directories:
self.walk_directory(dir)
self.on_search_completed()
    def walk_directory(self, path: Union[Path, str]):
        """Recursively walk one directory tree, pruning hidden and already-scanned paths."""
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
]
):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class ModelSearch(ModelSearchBase):
"""
Implementation of ModelSearch with callbacks.
Usage:
search = ModelSearch()
        search.model_found = lambda path: 'anime' in path.as_posix()
found = search.list_models(['/tmp/models1','/tmp/models2'])
# returns all models that have 'anime' in the path
"""
_model_set: Set[Path]
    search_started: Optional[Callable[[], None]]
    search_completed: Optional[Callable[[Set[Path]], None]]
    model_found: Optional[Callable[[Path], bool]]
def __init__(self, logger: types.ModuleType = logger):
super().__init__(logger)
self._model_set = set()
self.search_started = None
self.search_completed = None
self.model_found = None
def on_search_started(self):
self._model_set = set()
if self.search_started:
self.search_started()
def on_model_found(self, model: Path):
if not self.model_found:
self._model_set.add(model)
return
if self.model_found(model):
self._model_set.add(model)
def on_search_completed(self):
if self.search_completed:
self.search_completed(self._model_set)
    def list_models(self, directories: List[Union[Path, str]]) -> List[Path]:
        """Return the list of models found."""
self.search(directories)
return list(self._model_set)
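
A short sketch of the callback protocol, with hypothetical directories and the module path inferred from the relative imports above; model_found doubles as a filter, returning True for paths that should be kept:

    from pathlib import Path

    from invokeai.backend.model_manager.search import ModelSearch

    search = ModelSearch()
    search.search_started = lambda: print("scan started")
    search.model_found = lambda path: path.suffix == ".safetensors"  # keep safetensors only
    search.search_completed = lambda models: print(f"kept {len(models)} model(s)")

    found = search.list_models([Path("/tmp/models1"), "/tmp/models2"])  # hypothetical dirs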

View File

@@ -1,6 +1,6 @@
"""
Initialization file for invokeai.backend.model_manager.storage
"""
from .base import ModelConfigStore, UnknownModelException # noqa F401
from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401
from .yaml import ModelConfigStoreYAML # noqa F401
from .sql import ModelConfigStoreSQL # noqa F401

View File

@@ -14,10 +14,13 @@ class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to delete a model with a nonexistent key."""
class ModelConfigStore(ABC):
"""Abstract base class for storage and retrieval of model configs."""

View File

@@ -2,11 +2,15 @@
"""
Various utilities used by the model manager.
"""
from typing import Optional
import json
import warnings
import torch
import safetensors
from pathlib import Path
from typing import Optional, Union
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
from picklescan.scanner import scan_file_path
class SilenceWarnings(object):
"""
@@ -106,3 +110,49 @@ def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
break
return lora_token_vector_length
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except Exception:
            # TODO: open an issue about supporting the "meta" device here?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
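
The fast reader exploits the safetensors on-disk layout: an unsigned 64-bit little-endian header length, followed by that many bytes of JSON describing each tensor's dtype and shape, then the raw tensor payload, which is never read here. A self-contained sketch of reading just the header (scratch path is hypothetical):

    import json

    import safetensors.torch
    import torch

    # Write a one-tensor safetensors file to a scratch location.
    path = "/tmp/tiny.safetensors"
    safetensors.torch.save_file({"w": torch.zeros(2, 3)}, path)

    # Read the 8-byte length prefix, then only the JSON header.
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), "little")
        header = json.loads(f.read(header_len))

    header.pop("__metadata__", None)
    for name, info in header.items():
        print(name, info["dtype"], info["shape"])  # w F32 [2, 3]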

View File

@@ -50,6 +50,7 @@ dependencies = [
"fastapi-events==0.8.0",
"fastapi-socketio==0.0.10",
"huggingface-hub~=0.16.4",
"imohash~=1.0.0",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model