mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
loading works -- web app broken
@@ -1,5 +1,13 @@
 """
 Initialization file for invokeai.backend
 """
-from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo  # noqa: F401
-from .model_management.models import SilenceWarnings  # noqa: F401
+from .model_manager import (  # noqa F401
+    ModelLoader,
+    SilenceWarnings,
+    DuplicateModelException,
+    InvalidModelException,
+    BaseModelType,
+    ModelType,
+    SchedulerPredictionType,
+    ModelVariantType,
+)
@@ -1,7 +1,7 @@
 """
 Initialization file for invokeai.backend.model_manager.config
 """
-from ..model_management.models.base import read_checkpoint_meta  # noqa F401
+from .models.base import read_checkpoint_meta  # noqa F401
 from .config import (  # noqa F401
     BaseModelType,
     InvalidModelConfigException,
@@ -12,7 +12,9 @@ from .config import (  # noqa F401
     ModelVariantType,
     SchedulerPredictionType,
     SubModelType,
+    SilenceWarnings,
 )
+from .loader import ModelLoader  # noqa F401
 from .install import ModelInstall  # noqa F401
 from .probe import ModelProbe, InvalidModelException  # noqa F401
 from .storage import DuplicateModelException  # noqa F401
@@ -19,6 +19,8 @@ Typical usage:
 Validation errors will raise an InvalidModelConfigException error.

 """
+import warnings
+
 from enum import Enum
 from typing import Optional, Literal, List, Union, Type
 from omegaconf.listconfig import ListConfig  # to support the yaml backend
@@ -26,11 +28,13 @@ import pydantic
 from pydantic import BaseModel, Field, Extra
 from pydantic.error_wrappers import ValidationError

+# import these so that we can silence them
+from diffusers import logging as diffusers_logging
+from transformers import logging as transformers_logging

 class InvalidModelConfigException(Exception):
     """Exception for when config parser doesn't recognize this combination of model type and format."""


 class BaseModelType(str, Enum):
     """Base model type."""

@@ -94,6 +98,9 @@ class SchedulerPredictionType(str, Enum):
     VPrediction = "v_prediction"
     Sample = "sample"

+# TODO: use this
+class ModelError(str, Enum):
+    NotFound = "not_found"

 class ModelConfigBase(BaseModel):
     """Base class for model configuration information."""
@@ -114,7 +121,7 @@ class ModelConfigBase(BaseModel):
     class Config:
         """Pydantic configuration hint."""

-        use_enum_values = True
+        use_enum_values = False
         extra = Extra.forbid
         validate_assignment = True

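Aside: flipping `use_enum_values` from `True` to `False` changes what pydantic stores on validated models. With `True`, enum-typed fields are coerced to their raw values (plain strings here); with `False`, the enum members themselves are preserved, so callers can rely on member identity and `.value`. A minimal sketch of the difference (`Demo` is illustrative, not part of this commit):

```python
from enum import Enum
from pydantic import BaseModel


class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"


class Demo(BaseModel):
    base: BaseModelType

    class Config:
        use_enum_values = False  # as set by this commit


m = Demo(base="sd-1")
print(type(m.base))  # <enum 'BaseModelType'>; with use_enum_values = True it would be plain str
print(m.base.value)  # "sd-1"
```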
@@ -267,3 +274,21 @@ class ModelConfigFactory(object):
             ) from exc
         except ValidationError as exc:
             raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc
+
+
+# TO DO: Move this somewhere else
+class SilenceWarnings(object):
+    def __init__(self):
+        self.transformers_verbosity = transformers_logging.get_verbosity()
+        self.diffusers_verbosity = diffusers_logging.get_verbosity()
+
+    def __enter__(self):
+        transformers_logging.set_verbosity_error()
+        diffusers_logging.set_verbosity_error()
+        warnings.simplefilter("ignore")
+
+    def __exit__(self, type, value, traceback):
+        transformers_logging.set_verbosity(self.transformers_verbosity)
+        diffusers_logging.set_verbosity(self.diffusers_verbosity)
+        warnings.simplefilter("default")
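The relocated `SilenceWarnings` is a context manager that saves the current diffusers/transformers verbosity on entry and restores it on exit. A hedged usage sketch (the pipeline and model id are illustrative, not taken from this commit):

```python
from diffusers import StableDiffusionPipeline

with SilenceWarnings():
    # Inside the block, HF library verbosity is error-only and
    # Python warnings are suppressed.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# On exit, the previous verbosity levels and the default warnings filter return.
```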
@@ -3,7 +3,7 @@
 Fast hashing of diffusers and checkpoint-style models.

 Usage:
-   from invokeai.backend.model_management.model_hash import FastModelHash
+   from invokeai.backend.model_manager.model_hash import FastModelHash
 >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
 'a8e693a126ea5b831c96064dc569956f'
 """
@@ -52,7 +52,8 @@ import tempfile
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
-from typing import Optional, List, Union, Dict
+from typing import Optional, List, Union, Dict, Set
+from pydantic import Field
 from pydantic.networks import AnyHttpUrl
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
@@ -236,6 +237,7 @@ class ModelInstall(ModelInstallBase):
     _store: ModelConfigStore
     _download_queue: DownloadQueueBase
     _async_installs: Dict[str, str]
+    _installed: Set[Path] = Field(default=set)
     _tmpdir: Optional[tempfile.TemporaryDirectory]  # used for downloads

     _legacy_configs = {
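A hedged aside on `_installed: Set[Path] = Field(default=set)`: passing `default=set` stores the built-in `set` type itself as the default rather than an empty set instance. It is harmless here only because `__init__` (next hunk) immediately reassigns `self._installed = set()`, but in a true pydantic model the idiomatic declaration uses `default_factory` (sketch; the `Example` class is illustrative):

```python
from pathlib import Path
from typing import Set

from pydantic import BaseModel, Field


class Example(BaseModel):
    # default_factory builds a fresh, per-instance empty set
    installed: Set[Path] = Field(default_factory=set)
```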
@@ -273,6 +275,7 @@ class ModelInstall(ModelInstallBase):
         self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
         self._download_queue = download or DownloadQueue(config=self._config)
         self._async_installs = dict()
+        self._installed = set()
         self._tmpdir = None

     @property
@@ -428,7 +431,7 @@ class ModelInstall(ModelInstallBase):
     # the following two methods are callbacks to the ModelSearch object
     def _scan_register(self, model: Path) -> bool:
         try:
-            id = self.register(model)
+            id = self.register_path(model)
             self._logger.info(f"Registered {model} with id {id}")
             self._installed.add(id)
         except DuplicateModelException:
@@ -437,7 +440,7 @@ class ModelInstall(ModelInstallBase):

     def _scan_install(self, model: Path) -> bool:
         try:
-            id = self.install(model)
+            id = self.install_path(model)
             self._logger.info(f"Installed {model} with id {id}")
             self._installed.add(id)
         except DuplicateModelException:
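Both scan callbacks now route through the installer's path-based public API: judging from the names, `register_path()` records a model's configuration where it sits, while `install_path()` also moves or copies the files under the managed models directory before registering. A hedged sketch of invoking them directly (paths are illustrative; the constructor arguments appear optional per the `__init__` hunk above):

```python
from pathlib import Path

installer = ModelInstall()  # falls back to the default store, logger, and download queue

# Record an in-place model in the config store
installer._scan_register(Path("/data/models/some-model.safetensors"))

# Copy/move into the managed models directory, then register;
# duplicates raise DuplicateModelException (handler not shown in the hunk)
installer._scan_install(Path("/data/models/another-model.ckpt"))
```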
invokeai/backend/model_manager/loader.py (new file, 234 lines)
@@ -0,0 +1,234 @@
+# Copyright (c) 2023, Lincoln D. Stein
+"""Model loader for InvokeAI."""
+
+import hashlib
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union, Optional
+
+import torch
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger
+from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase
+from .install import ModelInstallBase, ModelInstall
+from .storage import ModelConfigStore, ModelConfigStoreYAML, ModelConfigStoreSQL
+from .cache import ModelCache, ModelLocker
+from .models import InvalidModelException, ModelBase, MODEL_CLASSES
+
+
+@dataclass
+class ModelInfo():
+    """This is a context manager object that is used to intermediate access to a model."""
+
+    context: ModelLocker
+    name: str
+    base_model: BaseModelType
+    type: ModelType
+    id: str
+    location: Union[Path, str]
+    precision: torch.dtype
+    _cache: Optional[ModelCache] = None
+
+    def __enter__(self):
+        """Context entry."""
+        return self.context.__enter__()
+
+    def __exit__(self, *args, **kwargs):
+        """Context exit."""
+        self.context.__exit__(*args, **kwargs)
+
+
+class ModelLoaderBase(ABC):
+    """Abstract base class for a model loader which works with the ModelConfigStore backend."""
+
+    @abstractmethod
+    def get_model(self,
+                  key: str,
+                  submodel_type: Optional[SubModelType] = None
+                  ) -> ModelInfo:
+        """
+        Return a model given its key.
+
+        Given a model key identified in the model configuration backend,
+        return a ModelInfo object that can be used to retrieve the model.
+
+        :param key: model key, as known to the config backend
+        :param submodel_type: a ModelType enum indicating the portion of
+               the model to retrieve (e.g. ModelType.Vae)
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def store(self) -> ModelConfigStore:
+        """Return the ModelConfigStore object that supports this loader."""
+        pass
+
+    @property
+    @abstractmethod
+    def installer(self) -> ModelInstallBase:
+        """Return the ModelInstallBase object that supports this loader."""
+        pass
+
+
+class ModelLoader(ModelLoaderBase):
+    """Implementation of ModelLoaderBase."""
+
+    _app_config: InvokeAIAppConfig
+    _store: ModelConfigStore
+    _installer: ModelInstallBase
+    _cache: ModelCache
+    _logger: InvokeAILogger
+    _cache_keys: dict
+
+    def __init__(self,
+                 config: InvokeAIAppConfig,
+                 ):
+        """
+        Initialize ModelLoader object.
+
+        :param config: The app's InvokeAIAppConfig object.
+        """
+        if config.model_conf_path and config.model_conf_path.exists():
+            models_file = config.model_conf_path
+        else:
+            models_file = config.root_path / "configs/models3.yaml"
+        store = ModelConfigStoreYAML(models_file) \
+            if models_file.suffix == '.yaml' \
+            else ModelConfigStoreSQL(models_file) \
+            if models_file.suffix == '.db' \
+            else None
+        if not store:
+            raise ValueError(f"Invalid model configuration file: {models_file}")
+
+        self._app_config = config
+        self._store = store
+        self._logger = InvokeAILogger.getLogger()
+        self._installer = ModelInstall(store=self._store,
+                                       logger=self._logger,
+                                       config=self._app_config
+                                       )
+        self._cache_keys = dict()
+        device = torch.device(choose_torch_device())
+        device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
+        precision = choose_precision(device) if config.precision == "auto" else config.precision
+        dtype = torch.float32 if precision == "float32" else torch.float16
+
+        self._logger.info(f"Using models database {models_file}")
+        self._logger.info(f"Rendering device = {device} ({device_name})")
+        self._logger.info(f"Maximum RAM cache size: {config.ram_cache_size}")
+        self._logger.info(f"Maximum VRAM cache size: {config.vram_cache_size}")
+        self._logger.info(f"Precision: {precision}")
+
+        self._cache = ModelCache(
+            max_cache_size=config.ram_cache_size,
+            max_vram_cache_size=config.vram_cache_size,
+            lazy_offloading=config.lazy_offload,
+            execution_device=device,
+            precision=dtype,
+            sequential_offload=config.sequential_guidance,
+            logger=self._logger,
+        )

+    @property
+    def store(self) -> ModelConfigStore:
+        """Return the ModelConfigStore instance used by this class."""
+        return self._store
+
+    @property
+    def installer(self) -> ModelInstallBase:
+        """Return the ModelInstallBase instance used by this class."""
+        return self._installer
+
+    def get_model(self,
+                  key: str,
+                  submodel_type: Optional[SubModelType] = None
+                  ) -> ModelInfo:
+        """
+        Get the ModelInfo corresponding to the model with key "key".
+
+        Given a model key identified in the model configuration backend,
+        return a ModelInfo object that can be used to retrieve the model.
+
+        :param key: model key, as known to the config backend
+        :param submodel_type: a ModelType enum indicating the portion of
+               the model to retrieve (e.g. ModelType.Vae)
+        """
+        model_config = self.store.get_model(key)  # May raise a UnknownModelException
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+
+        if is_submodel_override:
+            model_type = submodel_type
+            submodel_type = None
+
+        model_class = self._get_implementation(model_config.base_model, model_config.model_type)
+        if not model_path.exists():
+            raise InvalidModelException(f"Files for model '{key}' not found at {model_path}")
+
+        dst_convert_path = self._get_model_cache_path(model_path)
+        model_path = model_class.convert_if_required(
+            base_model=model_config.base_model,
+            model_path=model_path.as_posix(),
+            output_path=dst_convert_path,
+            config=model_config,
+        )
+
+        model_context = self._cache.get_model(
+            model_path=model_path,
+            model_class=model_class,
+            base_model=model_config.base_model,
+            model_type=model_config.model_type,
+            submodel=SubModelType(submodel_type),
+        )
+
+        if key not in self._cache_keys:
+            self._cache_keys[key] = set()
+        self._cache_keys[key].add(model_context.key)
+
+        return ModelInfo(
+            context=model_context,
+            name=model_config.name,
+            base_model=model_config.base_model,
+            type=submodel_type or model_type,
+            id=model_config.id,
+            location=model_path,
+            precision=self._cache.precision,
+            _cache=self._cache,
+        )
+
+    def _get_implementation(self,
+                            base_model: BaseModelType,
+                            model_type: ModelType
+                            ) -> type[ModelBase]:
+        """Get the concrete implementation class for a specific model type."""
+        model_class = MODEL_CLASSES[base_model][model_type]
+        return model_class
+
+    def _get_model_cache_path(self, model_path):
+        return self._resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
+
+    def _resolve_model_path(self, path: Union[Path, str]) -> Path:
+        """Return relative paths based on configured models_path."""
+        return self._app_config.models_path / path
+
+    def _get_model_path(
+        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
+    ) -> (Path, bool):
+        """Extract a model's filesystem path from its config.
+
+        :return: The fully qualified Path of the module (or submodule).
+        """
+        model_path = model_config.path
+        is_submodel_override = False
+
+        # Does the config explicitly override the submodel?
+        if submodel_type is not None and hasattr(model_config, submodel_type):
+            submodel_path = getattr(model_config, submodel_type)
+            if submodel_path is not None and len(submodel_path) > 0:
+                model_path = getattr(model_config, submodel_type)
+                is_submodel_override = True
+
+        model_path = self._resolve_model_path(model_path)
+        return model_path, is_submodel_override
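Taken together, a hedged end-to-end sketch of the new loader (the key string is a placeholder for whatever identifier the config backend assigned; `InvokeAIAppConfig.get_config()` is assumed from the app's config service):

```python
import torch

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ModelLoader
from invokeai.backend.model_manager.config import SubModelType

config = InvokeAIAppConfig.get_config()
loader = ModelLoader(config)

key = "0123456789abcdef"  # hypothetical key known to the ModelConfigStore
info = loader.get_model(key, submodel_type=SubModelType.Vae)

# ModelInfo is a context manager: entering locks the model into the RAM/VRAM
# cache and yields the loaded model object; exiting releases the lock.
with info as vae:
    assert isinstance(info.precision, torch.dtype)
    ...  # run the submodel here
```

Note that in this revision a call with `submodel_type=None` appears to break: `SubModelType(submodel_type)` rejects `None`, and `model_type` is bound only on the `is_submodel_override` branch, which fits the commit message's "web app broken" caveat.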
@@ -10,11 +10,7 @@ from .base import (  # noqa: F401
     ModelConfigBase,
     ModelVariantType,
     SchedulerPredictionType,
-    ModelError,
-    SilenceWarnings,
-    ModelNotFoundException,
-    InvalidModelException,
-    DuplicateModelException,
+    InvalidModelException
 )
 from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
 from .sdxl import StableDiffusionXLModel
@@ -3,7 +3,6 @@ import os
 import sys
 import typing
 import inspect
-import warnings
 from abc import ABCMeta, abstractmethod
 from contextlib import suppress
 from enum import Enum
@@ -21,84 +20,29 @@ from onnxruntime import (
     SessionOptions,
     get_available_providers,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
-from diffusers import logging as diffusers_logging
-from transformers import logging as transformers_logging
-
-
-class DuplicateModelException(Exception):
-    pass
-
-
-class InvalidModelException(Exception):
-    pass
-
+from ..config import (  # noqa F401
+    BaseModelType,
+    ModelType,
+    SubModelType,
+    ModelVariantType,
+    ModelFormat,
+    SchedulerPredictionType,
+    ModelConfigBase,
+)


 class ModelNotFoundException(Exception):
-    pass
-
-
-class BaseModelType(str, Enum):
-    StableDiffusion1 = "sd-1"
-    StableDiffusion2 = "sd-2"
-    StableDiffusionXL = "sdxl"
-    StableDiffusionXLRefiner = "sdxl-refiner"
-    # Kandinsky2_1 = "kandinsky-2.1"
-
-
-class ModelType(str, Enum):
-    ONNX = "onnx"
-    Main = "main"
-    Vae = "vae"
-    Lora = "lora"
-    ControlNet = "controlnet"  # used by model_probe
-    TextualInversion = "embedding"
-
-
-class SubModelType(str, Enum):
-    UNet = "unet"
-    TextEncoder = "text_encoder"
-    TextEncoder2 = "text_encoder_2"
-    Tokenizer = "tokenizer"
-    Tokenizer2 = "tokenizer_2"
-    Vae = "vae"
-    VaeDecoder = "vae_decoder"
-    VaeEncoder = "vae_encoder"
-    Scheduler = "scheduler"
-    SafetyChecker = "safety_checker"
-    # MoVQ = "movq"
-
-
-class ModelVariantType(str, Enum):
-    Normal = "normal"
-    Inpaint = "inpaint"
-    Depth = "depth"
-
-
-class SchedulerPredictionType(str, Enum):
-    Epsilon = "epsilon"
-    VPrediction = "v_prediction"
-    Sample = "sample"
-
-
-class ModelError(str, Enum):
-    NotFound = "not_found"
-
-
-class ModelConfigBase(BaseModel):
-    path: str  # or Path
-    description: Optional[str] = Field(None)
-    model_format: Optional[str] = Field(None)
-    error: Optional[ModelError] = Field(None)
-
-    class Config:
-        use_enum_values = True
+    """Exception for when a model is not found on the expected path."""


+class InvalidModelException(Exception):
+    """Exception for when a model is corrupted in some way; for example missing files."""


 class EmptyConfigLoader(ConfigMixin):
-
     @classmethod
     def load_config(cls, *args, **kwargs):
         """Load empty configuration."""
         cls.config_name = kwargs.pop("config_name")
         return super().load_config(*args, **kwargs)
@@ -453,25 +397,8 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
     return checkpoint


-class SilenceWarnings(object):
-    def __init__(self):
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-
-    def __enter__(self):
-        transformers_logging.set_verbosity_error()
-        diffusers_logging.set_verbosity_error()
-        warnings.simplefilter("ignore")
-
-    def __exit__(self, type, value, traceback):
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
-        warnings.simplefilter("default")
-
-
 ONNX_WEIGHTS_NAME = "model.onnx"


 class IAIOnnxRuntimeModel:
     class _tensor_access:
         def __init__(self, model):
@@ -11,12 +11,11 @@ from .base import (
     ModelType,
     ModelVariantType,
     DiffusersModel,
-    SilenceWarnings,
     read_checkpoint_meta,
     classproperty,
     InvalidModelException,
     ModelNotFoundException,
 )
+from ..config import SilenceWarnings
 from .sdxl import StableDiffusionXLModel
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -15,7 +15,6 @@ from .base import (
 # TODO: naming
 from ..lora import TextualInversionModel as TextualInversionModelRaw

-
 class TextualInversionModel(ModelBase):
     # model_size: int

@@ -4,7 +4,7 @@ Implementation of ModelConfigStore using a YAML file.

 Typical usage:

-  from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreYAML
+  from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML
   store = ModelConfigStoreYAML("./configs/models.yaml")
   config = dict(
       path='/tmp/pokemon.bin',
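The usage example in this docstring is cut off by the hunk window. A hedged completion under assumed field values (only `path` appears above; `description` and `model_format` mirror the `ModelConfigBase` fields seen earlier in this commit):

```python
from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML

store = ModelConfigStoreYAML("./configs/models.yaml")
config = dict(
    path="/tmp/pokemon.bin",
    description="an illustrative textual-inversion embedding",  # assumed value
    model_format="embedding_file",                              # assumed value
)
# The dict would then be registered under a key via the store's add/update
# API (not shown in this hunk).
```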
@@ -18,3 +18,4 @@ from .util import (  # noqa: F401
     Chdir,
 )
 from .attention import auto_detect_slice_size  # noqa: F401
+from .logging import InvokeAILogger