From b583bddeb125e50e7c240de0aa68bed6163c881d Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 10 Sep 2023 22:59:58 -0400
Subject: [PATCH] loading works -- web app broken

---
 invokeai/backend/__init__.py | 12 +-
 invokeai/backend/model_manager/__init__.py | 4 +-
 .../model_cache.py => model_manager/cache.py} | 0
 invokeai/backend/model_manager/config.py | 29 ++-
 invokeai/backend/model_manager/hash.py | 2 +-
 invokeai/backend/model_manager/install.py | 9 +-
 invokeai/backend/model_manager/loader.py | 234 ++++++++++++++++++
 .../lora.py | 0
 .../models/__init__.py | 6 +-
 .../models/base.py | 103 ++------
 .../models/controlnet.py | 0
 .../models/lora.py | 0
 .../models/sdxl.py | 0
 .../models/stable_diffusion.py | 3 +-
 .../models/stable_diffusion_onnx.py | 0
 .../models/textual_inversion.py | 1 -
 .../models/vae.py | 0
 .../backend/model_manager/storage/yaml.py | 2 +-
 invokeai/backend/util/__init__.py | 1 +
 19 files changed, 300 insertions(+), 106 deletions(-)
 rename invokeai/backend/{model_management/model_cache.py => model_manager/cache.py} (100%)
 create mode 100644 invokeai/backend/model_manager/loader.py
 rename invokeai/backend/{model_management => model_manager}/lora.py (100%)
 rename invokeai/backend/{model_management => model_manager}/models/__init__.py (97%)
 rename invokeai/backend/{model_management => model_manager}/models/base.py (90%)
 rename invokeai/backend/{model_management => model_manager}/models/controlnet.py (100%)
 rename invokeai/backend/{model_management => model_manager}/models/lora.py (100%)
 rename invokeai/backend/{model_management => model_manager}/models/sdxl.py (100%)
 rename invokeai/backend/{model_management => model_manager}/models/stable_diffusion.py (99%)
 rename invokeai/backend/{model_management => model_manager}/models/stable_diffusion_onnx.py (100%)
 rename invokeai/backend/{model_management => model_manager}/models/textual_inversion.py (99%)
 rename invokeai/backend/{model_management => model_manager}/models/vae.py (100%)

diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py
index 2e77d12eca..3dc7eb0532 100644
--- a/invokeai/backend/__init__.py
+++ b/invokeai/backend/__init__.py
@@ -1,5 +1,13 @@
 """
 Initialization file for invokeai.backend
 """
-from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo # noqa: F401
-from .model_management.models import SilenceWarnings # noqa: F401
+from .model_manager import ( # noqa F401
+    ModelLoader,
+    SilenceWarnings,
+    DuplicateModelException,
+    InvalidModelException,
+    BaseModelType,
+    ModelType,
+    SchedulerPredictionType,
+    ModelVariantType,
+)
diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py
index 0d1cdd8de1..312be808c8 100644
--- a/invokeai/backend/model_manager/__init__.py
+++ b/invokeai/backend/model_manager/__init__.py
@@ -1,7 +1,7 @@
 """
 Initialization file for invokeai.backend.model_manager.config
 """
-from ..model_management.models.base import read_checkpoint_meta # noqa F401
+from .models.base import read_checkpoint_meta # noqa F401
 from .config import ( # noqa F401
     BaseModelType,
     InvalidModelConfigException,
@@ -12,7 +12,9 @@ from .config import ( # noqa F401
     ModelVariantType,
     SchedulerPredictionType,
     SubModelType,
+    SilenceWarnings,
 )
+from .loader import ModelLoader # noqa F401
 from .install import ModelInstall # noqa F401
 from .probe import ModelProbe, InvalidModelException # noqa F401
 from .storage import DuplicateModelException # noqa F401
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_manager/cache.py
similarity index 100%
rename from invokeai/backend/model_management/model_cache.py
rename to invokeai/backend/model_manager/cache.py
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 1f4186732f..119fb0e12d 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -19,6 +19,8 @@ Typical usage:
 Validation errors will raise an InvalidModelConfigException error.
 """
+import warnings
+
 from enum import Enum
 from typing import Optional, Literal, List, Union, Type
 from omegaconf.listconfig import ListConfig # to support the yaml backend
 
@@ -26,11 +28,13 @@ import pydantic
 from pydantic import BaseModel, Field, Extra
 from pydantic.error_wrappers import ValidationError
 
+# import these so that we can silence them
+from diffusers import logging as diffusers_logging
+from transformers import logging as transformers_logging
 
 class InvalidModelConfigException(Exception):
     """Exception for when config parser doesn't recognized this combination of model type and format."""
-
 
 class BaseModelType(str, Enum):
     """Base model type."""
 
@@ -94,6 +98,9 @@ class SchedulerPredictionType(str, Enum):
     VPrediction = "v_prediction"
     Sample = "sample"
 
+# TODO: use this
+class ModelError(str, Enum):
+    NotFound = "not_found"
 
 class ModelConfigBase(BaseModel):
     """Base class for model configuration information."""
@@ -114,7 +121,7 @@ class ModelConfigBase(BaseModel):
     class Config:
         """Pydantic configuration hint."""
 
-        use_enum_values = True
+        use_enum_values = False
         extra = Extra.forbid
         validate_assignment = True
 
@@ -267,3 +274,21 @@ class ModelConfigFactory(object):
             ) from exc
         except ValidationError as exc:
             raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc
+
+# TODO: Move this somewhere else
+class SilenceWarnings(object):
+    def __init__(self):
+        self.transformers_verbosity = transformers_logging.get_verbosity()
+        self.diffusers_verbosity = diffusers_logging.get_verbosity()
+
+    def __enter__(self):
+        transformers_logging.set_verbosity_error()
+        diffusers_logging.set_verbosity_error()
+        warnings.simplefilter("ignore")
+
+    def __exit__(self, type, value, traceback):
+        transformers_logging.set_verbosity(self.transformers_verbosity)
+        diffusers_logging.set_verbosity(self.diffusers_verbosity)
+        warnings.simplefilter("default")
+
+
diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py
index 342c9c798b..873d1b87b2 100644
--- a/invokeai/backend/model_manager/hash.py
+++ b/invokeai/backend/model_manager/hash.py
@@ -3,7 +3,7 @@ Fast hashing of diffusers and checkpoint-style models.
 
 Usage:
-from invokeai.backend.model_management.model_hash import FastModelHash
+from invokeai.backend.model_manager.hash import FastModelHash
 >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
 'a8e693a126ea5b831c96064dc569956f'
 """
diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py
index 740eae0da0..5b14cef3f5 100644
--- a/invokeai/backend/model_manager/install.py
+++ b/invokeai/backend/model_manager/install.py
@@ -52,7 +52,8 @@ import tempfile
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
-from typing import Optional, List, Union, Dict
+from typing import Optional, List, Union, Dict, Set
+from pydantic import Field
 from pydantic.networks import AnyHttpUrl
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
@@ -236,6 +237,7 @@ class ModelInstall(ModelInstallBase):
     _store: ModelConfigStore
     _download_queue: DownloadQueueBase
     _async_installs: Dict[str, str]
+    _installed: Set[str] = Field(default_factory=set)
     _tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
 
     _legacy_configs = {
@@ -273,6 +275,7 @@ class ModelInstall(ModelInstallBase):
         self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
         self._download_queue = download or DownloadQueue(config=self._config)
         self._async_installs = dict()
+        self._installed = set()
         self._tmpdir = None
 
     @property
@@ -428,7 +431,7 @@ class ModelInstall(ModelInstallBase):
     # the following two methods are callbacks to the ModelSearch object
     def _scan_register(self, model: Path) -> bool:
         try:
-            id = self.register(model)
+            id = self.register_path(model)
             self._logger.info(f"Registered {model} with id {id}")
             self._installed.add(id)
         except DuplicateModelException:
@@ -437,7 +440,7 @@ class ModelInstall(ModelInstallBase):
 
     def _scan_install(self, model: Path) -> bool:
         try:
-            id = self.install(model)
+            id = self.install_path(model)
             self._logger.info(f"Installed {model} with id {id}")
             self._installed.add(id)
         except DuplicateModelException:
diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py
new file mode 100644
index 0000000000..f93cd67f04
--- /dev/null
+++ b/invokeai/backend/model_manager/loader.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2023, Lincoln D. Stein
Stein +"""Model loader for InvokeAI.""" + +import hashlib +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Union, Optional + +import torch + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger +from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase +from .install import ModelInstallBase, ModelInstall +from .storage import ModelConfigStore, ModelConfigStoreYAML, ModelConfigStoreSQL +from .cache import ModelCache, ModelLocker +from .models import InvalidModelException, ModelBase, MODEL_CLASSES + + +@dataclass +class ModelInfo(): + """This is a context manager object that is used to intermediate access to a model.""" + + context: ModelLocker + name: str + base_model: BaseModelType + type: ModelType + id: str + location: Union[Path, str] + precision: torch.dtype + _cache: Optional[ModelCache] = None + + def __enter__(self): + """Context entry.""" + return self.context.__enter__() + + def __exit__(self, *args, **kwargs): + """Context exit.""" + self.context.__exit__(*args, **kwargs) + + +class ModelLoaderBase(ABC): + """Abstract base class for a model loader which works with the ModelConfigStore backend.""" + + @abstractmethod + def get_model(self, + key: str, + submodel_type: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Return a model given its key. + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param key: model key, as known to the config backend + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + pass + + @property + @abstractmethod + def store(self) -> ModelConfigStore: + """Return the ModelConfigStore object that supports this loader.""" + pass + + @property + @abstractmethod + def installer(self) -> ModelInstallBase: + """Return the ModelInstallBase object that supports this loader.""" + pass + + +class ModelLoader(ModelLoaderBase): + """Implementation of ModelLoaderBase.""" + + _app_config: InvokeAIAppConfig + _store: ModelConfigStore + _installer: ModelInstallBase + _cache: ModelCache + _logger: InvokeAILogger + _cache_keys: dict + + def __init__(self, + config: InvokeAIAppConfig, + ): + """ + Initialize ModelLoader object. + + :param config: The app's InvokeAIAppConfig object. 
+ """ + if config.model_conf_path and config.model_conf_path.exists(): + models_file = config.model_conf_path + else: + models_file = config.root_path / "configs/models3.yaml" + store = ModelConfigStoreYAML(models_file) \ + if models_file.suffix == '.yaml' \ + else ModelConfigStoreSQL(models_file) \ + if models_file.suffix == '.db' \ + else None + if not store: + raise ValueError(f"Invalid model configuration file: {models_file}") + + self._app_config = config + self._store = store + self._logger = InvokeAILogger.getLogger() + self._installer = ModelInstall(store=self._store, + logger=self._logger, + config=self._app_config + ) + self._cache_keys = dict() + device = torch.device(choose_torch_device()) + device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" + precision = choose_precision(device) if config.precision == "auto" else config.precision + dtype = torch.float32 if precision == "float32" else torch.float16 + + self._logger.info(f"Using models database {models_file}") + self._logger.info(f"Rendering device = {device} ({device_name})") + self._logger.info(f"Maximum RAM cache size: {config.ram_cache_size}") + self._logger.info(f"Maximum VRAM cache size: {config.vram_cache_size}") + self._logger.info(f"Precision: {precision}") + + self._cache = ModelCache( + max_cache_size=config.ram_cache_size, + max_vram_cache_size=config.vram_cache_size, + lazy_offloading=config.lazy_offload, + execution_device=device, + precision=dtype, + sequential_offload=config.sequential_guidance, + logger=self._logger, + ) + + @property + def store(self) -> ModelConfigStore: + """Return the ModelConfigStore instance used by this class.""" + return self._store + + @property + def installer(self) -> ModelInstallBase: + """Return the ModelInstallBase instance used by this class.""" + return self._installer + + def get_model(self, + key: str, + submodel_type: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Get the ModelInfo corresponding to the model with key "key". + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param key: model key, as known to the config backend + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. 
+        """
+        model_config = self.store.get_model(key)  # May raise an UnknownModelException
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+
+        model_type = model_config.model_type
+        if is_submodel_override:
+            model_type = submodel_type
+            submodel_type = None
+
+        model_class = self._get_implementation(model_config.base_model, model_config.model_type)
+        if not model_path.exists():
+            raise InvalidModelException(f"Files for model '{key}' not found at {model_path}")
+
+        dst_convert_path = self._get_model_cache_path(model_path)
+        model_path = model_class.convert_if_required(
+            base_model=model_config.base_model,
+            model_path=model_path.as_posix(),
+            output_path=dst_convert_path,
+            config=model_config,
+        )
+
+        model_context = self._cache.get_model(
+            model_path=model_path,
+            model_class=model_class,
+            base_model=model_config.base_model,
+            model_type=model_config.model_type,
+            submodel=SubModelType(submodel_type) if submodel_type else None,
+        )
+
+        if key not in self._cache_keys:
+            self._cache_keys[key] = set()
+        self._cache_keys[key].add(model_context.key)
+
+        return ModelInfo(
+            context=model_context,
+            name=model_config.name,
+            base_model=model_config.base_model,
+            type=submodel_type or model_type,
+            id=model_config.id,
+            location=model_path,
+            precision=self._cache.precision,
+            _cache=self._cache,
+        )
+
+    def _get_implementation(self,
+                            base_model: BaseModelType,
+                            model_type: ModelType
+                            ) -> type[ModelBase]:
+        """Get the concrete implementation class for a specific model type."""
+        model_class = MODEL_CLASSES[base_model][model_type]
+        return model_class
+
+    def _get_model_cache_path(self, model_path: Path) -> Path:
+        return self._resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
+
+    def _resolve_model_path(self, path: Union[Path, str]) -> Path:
+        """Return a path fully qualified against the configured models_path."""
+        return self._app_config.models_path / path
+
+    def _get_model_path(
+        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
+    ) -> Tuple[Path, bool]:
+        """Extract a model's filesystem path from its config.
+
+        :return: The fully qualified Path of the model (or submodel).
+        """
+        model_path = model_config.path
+        is_submodel_override = False
+
+        # Does the config explicitly override the submodel?
+        if submodel_type is not None and hasattr(model_config, submodel_type):
+            submodel_path = getattr(model_config, submodel_type)
+            if submodel_path is not None and len(submodel_path) > 0:
+                model_path = submodel_path
+                is_submodel_override = True
+
+        model_path = self._resolve_model_path(model_path)
+        return model_path, is_submodel_override
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_manager/lora.py
similarity index 100%
rename from invokeai/backend/model_management/lora.py
rename to invokeai/backend/model_manager/lora.py
diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_manager/models/__init__.py
similarity index 97%
rename from invokeai/backend/model_management/models/__init__.py
rename to invokeai/backend/model_manager/models/__init__.py
index 2de206257b..565ed9dd43 100644
--- a/invokeai/backend/model_management/models/__init__.py
+++ b/invokeai/backend/model_manager/models/__init__.py
@@ -10,11 +10,7 @@ from .base import ( # noqa: F401
     ModelConfigBase,
     ModelVariantType,
     SchedulerPredictionType,
-    ModelError,
-    SilenceWarnings,
-    ModelNotFoundException,
-    InvalidModelException,
-    DuplicateModelException,
+    InvalidModelException
 )
 from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
 from .sdxl import StableDiffusionXLModel
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_manager/models/base.py
similarity index 90%
rename from invokeai/backend/model_management/models/base.py
rename to invokeai/backend/model_manager/models/base.py
index ed1c2c6098..e112e9cd7e 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_manager/models/base.py
@@ -3,7 +3,6 @@ import os
 import sys
 import typing
 import inspect
-import warnings
 from abc import ABCMeta, abstractmethod
 from contextlib import suppress
 from enum import Enum
@@ -21,84 +20,29 @@ from onnxruntime import (
     SessionOptions,
     get_available_providers,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
-from diffusers import logging as diffusers_logging
-from transformers import logging as transformers_logging
-
-
-class DuplicateModelException(Exception):
-    pass
-
-
-class InvalidModelException(Exception):
-    pass
-
+from ..config import ( # noqa F401
+    BaseModelType,
+    ModelType,
+    SubModelType,
+    ModelVariantType,
+    ModelFormat,
+    SchedulerPredictionType,
+    ModelConfigBase,
+)
 
 class ModelNotFoundException(Exception):
-    pass
-
-
-class BaseModelType(str, Enum):
-    StableDiffusion1 = "sd-1"
-    StableDiffusion2 = "sd-2"
-    StableDiffusionXL = "sdxl"
-    StableDiffusionXLRefiner = "sdxl-refiner"
-    # Kandinsky2_1 = "kandinsky-2.1"
-
-
-class ModelType(str, Enum):
-    ONNX = "onnx"
-    Main = "main"
-    Vae = "vae"
-    Lora = "lora"
-    ControlNet = "controlnet" # used by model_probe
-    TextualInversion = "embedding"
-
-
-class SubModelType(str, Enum):
-    UNet = "unet"
-    TextEncoder = "text_encoder"
-    TextEncoder2 = "text_encoder_2"
-    Tokenizer = "tokenizer"
-    Tokenizer2 = "tokenizer_2"
-    Vae = "vae"
-    VaeDecoder = "vae_decoder"
-    VaeEncoder = "vae_encoder"
-    Scheduler = "scheduler"
-    SafetyChecker = "safety_checker"
-    # MoVQ = "movq"
-
-
-class ModelVariantType(str, Enum):
-    Normal = "normal"
-    Inpaint = "inpaint"
-    Depth = "depth"
-
-
-class SchedulerPredictionType(str, Enum):
-    Epsilon = "epsilon"
-    VPrediction = "v_prediction"
-    Sample = "sample"
-
-
-class ModelError(str, Enum):
-    NotFound = "not_found"
-
-
-class ModelConfigBase(BaseModel):
-    path: str # or Path
-    description: Optional[str] = Field(None)
-    model_format: Optional[str] = Field(None)
-    error: Optional[ModelError] = Field(None)
-
-    class Config:
-        use_enum_values = True
+    """Exception for when a model is not found at the expected path."""
+class InvalidModelException(Exception):
+    """Exception for when a model is corrupted in some way; for example, missing files."""
 
 class EmptyConfigLoader(ConfigMixin):
+    @classmethod
     def load_config(cls, *args, **kwargs):
+        """Load empty configuration."""
         cls.config_name = kwargs.pop("config_name")
         return super().load_config(*args, **kwargs)
@@ -453,25 +397,8 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
     return checkpoint
 
 
-class SilenceWarnings(object):
-    def __init__(self):
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-
-    def __enter__(self):
-        transformers_logging.set_verbosity_error()
-        diffusers_logging.set_verbosity_error()
-        warnings.simplefilter("ignore")
-
-    def __exit__(self, type, value, traceback):
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
-        warnings.simplefilter("default")
-
-
 ONNX_WEIGHTS_NAME = "model.onnx"
-
 
 class IAIOnnxRuntimeModel:
     class _tensor_access:
         def __init__(self, model):
diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_manager/models/controlnet.py
similarity index 100%
rename from invokeai/backend/model_management/models/controlnet.py
rename to invokeai/backend/model_manager/models/controlnet.py
diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_manager/models/lora.py
similarity index 100%
rename from invokeai/backend/model_management/models/lora.py
rename to invokeai/backend/model_manager/models/lora.py
diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_manager/models/sdxl.py
similarity index 100%
rename from invokeai/backend/model_management/models/sdxl.py
rename to invokeai/backend/model_manager/models/sdxl.py
diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_manager/models/stable_diffusion.py
similarity index 99%
rename from invokeai/backend/model_management/models/stable_diffusion.py
rename to invokeai/backend/model_manager/models/stable_diffusion.py
index cc34f14b9c..622a707d07 100644
--- a/invokeai/backend/model_management/models/stable_diffusion.py
+++ b/invokeai/backend/model_manager/models/stable_diffusion.py
@@ -11,12 +11,11 @@ from .base import (
     ModelType,
     ModelVariantType,
     DiffusersModel,
-    SilenceWarnings,
     read_checkpoint_meta,
     classproperty,
     InvalidModelException,
-    ModelNotFoundException,
 )
+from ..config import SilenceWarnings
 from .sdxl import StableDiffusionXLModel
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_manager/models/stable_diffusion_onnx.py
similarity index 100%
rename from invokeai/backend/model_management/models/stable_diffusion_onnx.py
rename to invokeai/backend/model_manager/models/stable_diffusion_onnx.py
diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_manager/models/textual_inversion.py
similarity index 99%
rename from invokeai/backend/model_management/models/textual_inversion.py
rename to invokeai/backend/model_manager/models/textual_inversion.py
index a949a15be1..04c9b37980 100644
--- a/invokeai/backend/model_management/models/textual_inversion.py
+++ b/invokeai/backend/model_manager/models/textual_inversion.py
@@ -15,7 +15,6 @@ from .base import (
 # TODO: naming
 from ..lora import TextualInversionModel as TextualInversionModelRaw
 
-
 class TextualInversionModel(ModelBase):
     # model_size: int
 
diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_manager/models/vae.py
similarity index 100%
rename from invokeai/backend/model_management/models/vae.py
rename to invokeai/backend/model_manager/models/vae.py
diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py
index 66098c0f5d..0b1b686695 100644
--- a/invokeai/backend/model_manager/storage/yaml.py
+++ b/invokeai/backend/model_manager/storage/yaml.py
@@ -4,7 +4,7 @@ Implementation of ModelConfigStore using a YAML file.
 
 Typical usage:
 
-  from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreYAML
+  from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML
   store = ModelConfigStoreYAML("./configs/models.yaml")
   config = dict(
       path='/tmp/pokemon.bin',
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index a8d53f54a4..f7b5d36d77 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -18,3 +18,4 @@ from .util import ( # noqa: F401
     Chdir,
 )
 from .attention import auto_detect_slice_size # noqa: F401
+from .logging import InvokeAILogger
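
Usage sketch for reviewers: how the new loader API is meant to be driven end to end. This is a hedged sketch, not code from the patch -- the model key below is hypothetical (real keys are whatever the ModelConfigStore backend assigned at install time), and InvokeAIAppConfig.get_config() is assumed to be the app's usual config accessor.

    # Hedged sketch (not part of the patch); the model key is hypothetical.
    from invokeai.app.services.config import InvokeAIAppConfig
    from invokeai.backend.model_manager import ModelLoader, SubModelType

    config = InvokeAIAppConfig.get_config()  # assumed standard accessor
    loader = ModelLoader(config)

    # get_model() returns a ModelInfo dataclass; entering it as a context
    # manager locks the model into the RAM/VRAM cache and yields the loaded
    # (sub)model object.
    info = loader.get_model("sd15-pokemon-key", submodel_type=SubModelType.Vae)
    with info as vae:
        print(info.name, info.location, info.precision, type(vae))

The loader.store and loader.installer properties expose the backing ModelConfigStore and ModelInstall objects, so callers that need to register or search for models do not construct those separately.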
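
SilenceWarnings, which this patch moves from models/base.py into model_manager/config.py and re-exports from invokeai.backend.model_manager, temporarily forces diffusers and transformers logging down to error level and suppresses Python warnings, restoring the saved verbosity on exit. A minimal sketch of its use (the pipeline class and repo id are illustrative, not taken from the patch):

    from diffusers import StableDiffusionPipeline
    from invokeai.backend.model_manager import SilenceWarnings

    # Warning-level chatter from diffusers/transformers is suppressed inside
    # the block and verbosity is restored afterwards.
    with SilenceWarnings():
        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")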