mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
consolidate model manager parts into a single class
This commit is contained in:
parent
a6508d1391
commit
1d724bca4a
6
invokeai/app/services/model_load/__init__.py
Normal file
6
invokeai/app/services/model_load/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
"""Initialization file for model load service module."""
|
||||||
|
|
||||||
|
from .model_load_base import ModelLoadServiceBase
|
||||||
|
from .model_load_default import ModelLoadService
|
||||||
|
|
||||||
|
__all__ = ["ModelLoadServiceBase", "ModelLoadService"]
|
22
invokeai/app/services/model_load/model_load_base.py
Normal file
22
invokeai/app/services/model_load/model_load_base.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
"""Base class for model loader."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||||
|
from invokeai.backend.model_manager.load import LoadedModel
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoadServiceBase(ABC):
|
||||||
|
"""Wrapper around AnyModelLoader."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
|
"""Given a model's key, load it and return the LoadedModel object."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
|
"""Given a model's configuration, load it and return the LoadedModel object."""
|
||||||
|
pass
|
54
invokeai/app/services/model_load/model_load_default.py
Normal file
54
invokeai/app/services/model_load/model_load_default.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
"""Implementation of model loader service."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
|
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||||
|
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
|
||||||
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
|
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
from .model_load_base import ModelLoadServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoadService(ModelLoadServiceBase):
|
||||||
|
"""Wrapper around AnyModelLoader."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app_config: InvokeAIAppConfig,
|
||||||
|
record_store: ModelRecordServiceBase,
|
||||||
|
ram_cache: Optional[ModelCacheBase] = None,
|
||||||
|
convert_cache: Optional[ModelConvertCacheBase] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the model load service."""
|
||||||
|
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||||
|
logger.setLevel(app_config.log_level.upper())
|
||||||
|
self._store = record_store
|
||||||
|
self._any_loader = AnyModelLoader(
|
||||||
|
app_config=app_config,
|
||||||
|
logger=logger,
|
||||||
|
ram_cache=ram_cache
|
||||||
|
or ModelCache(
|
||||||
|
max_cache_size=app_config.ram_cache_size,
|
||||||
|
max_vram_cache_size=app_config.vram_cache_size,
|
||||||
|
logger=logger,
|
||||||
|
),
|
||||||
|
convert_cache=convert_cache
|
||||||
|
or ModelConvertCache(
|
||||||
|
cache_path=app_config.models_convert_cache_path,
|
||||||
|
max_size=app_config.convert_cache_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
|
"""Given a model's key, load it and return the LoadedModel object."""
|
||||||
|
config = self._store.get_model(key)
|
||||||
|
return self.load_model_by_config(config, submodel_type)
|
||||||
|
|
||||||
|
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
|
"""Given a model's configuration, load it and return the LoadedModel object."""
|
||||||
|
return self._any_loader.load_model(config, submodel_type)
|
@ -1 +1,16 @@
|
|||||||
from .model_manager_default import ModelManagerService # noqa F401
|
"""Initialization file for model manager service."""
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||||
|
from invokeai.backend.model_manager.load import LoadedModel
|
||||||
|
|
||||||
|
from .model_manager_default import ModelManagerService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelManagerService",
|
||||||
|
"AnyModel",
|
||||||
|
"AnyModelConfig",
|
||||||
|
"BaseModelType",
|
||||||
|
"ModelType",
|
||||||
|
"SubModelType",
|
||||||
|
"LoadedModel",
|
||||||
|
]
|
||||||
|
@ -1,286 +1,39 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import (
|
from ..events.events_base import EventServiceBase
|
||||||
AddModelResult,
|
from ..download import DownloadQueueServiceBase
|
||||||
BaseModelType,
|
from ..model_install import ModelInstallServiceBase
|
||||||
MergeInterpolationMethod,
|
from ..model_load import ModelLoadServiceBase
|
||||||
ModelInfo,
|
from ..model_records import ModelRecordServiceBase
|
||||||
ModelType,
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
SchedulerPredictionType,
|
|
||||||
SubModelType,
|
|
||||||
)
|
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManagerServiceBase(ABC):
|
class ModelManagerServiceBase(BaseModel, ABC):
|
||||||
"""Responsible for managing models on disk and in memory"""
|
"""Abstract base class for the model manager service."""
|
||||||
|
|
||||||
|
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||||
|
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||||
|
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def build_model_manager(
|
||||||
self,
|
cls,
|
||||||
config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
logger: Logger,
|
db: SqliteDatabase,
|
||||||
):
|
download_queue: DownloadQueueServiceBase,
|
||||||
|
events: EventServiceBase,
|
||||||
|
) -> Self:
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Construct the model manager service instance.
|
||||||
Optional parameters are the torch device type, precision, max_models,
|
|
||||||
and sequential_offload boolean. Note that the default device
|
Use it rather than the __init__ constructor. This class
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
method simplifies the construction considerably.
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
node: Optional[BaseInvocation] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> ModelInfo:
|
|
||||||
"""Retrieve the indicated model with name and type.
|
|
||||||
submodel can be used to get a part (such as the vae)
|
|
||||||
of a diffusers pipeline."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def logger(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_exists(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
|
||||||
"""
|
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
|
||||||
Uses the exact format as the omegaconf stanza.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
|
||||||
"""
|
|
||||||
Return a dict of models in the format:
|
|
||||||
{ model_type1:
|
|
||||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
|
||||||
'model_name' : name,
|
|
||||||
'model_type' : SDModelType,
|
|
||||||
'description': description,
|
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
},
|
|
||||||
model_type2:
|
|
||||||
{ model_name_n: etc
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
|
||||||
"""
|
|
||||||
Return information about the model using the same format as list_models()
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns a list of all the model names known.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
clobber: bool = False,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
|
||||||
ModelNotFoundException if the name does not already exist.
|
|
||||||
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def del_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete the named model from configuration. If delete_files is true,
|
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
|
||||||
as well. Call commit() to write to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def rename_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
new_name: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Rename the indicated model.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_checkpoint_configs(self) -> List[Path]:
|
|
||||||
"""
|
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def convert_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
|
||||||
version and deleting the original checkpoint file if it is in the models
|
|
||||||
directory.
|
|
||||||
:param model_name: Name of the model to convert
|
|
||||||
:param base_model: Base model type
|
|
||||||
:param model_type: Type of model ['vae' or 'main']
|
|
||||||
|
|
||||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
|
||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
|
||||||
directory already in place.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
items_to_import: set[str],
|
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
||||||
) -> dict[str, AddModelResult]:
|
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
|
||||||
|
|
||||||
The prediction type helper is necessary to distinguish between
|
|
||||||
models based on Stable Diffusion 2 Base (requiring
|
|
||||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
|
||||||
(requiring SchedulerPredictionType.VPrediction). It is
|
|
||||||
generally impossible to do this programmatically, so the
|
|
||||||
prediction_type_helper usually asks the user to choose.
|
|
||||||
|
|
||||||
The result is a set of successfully installed models. Each element
|
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
|
||||||
that model.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def merge_models(
|
|
||||||
self,
|
|
||||||
model_names: List[str] = Field(
|
|
||||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
|
||||||
),
|
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
|
||||||
default=None, description="Base model shared by all models to be merged"
|
|
||||||
),
|
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
|
||||||
alpha: Optional[float] = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: Optional[bool] = False,
|
|
||||||
merge_dest_directory: Optional[Path] = None,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
|
||||||
:param model_names: List of 2-3 models to merge
|
|
||||||
:param base_model: Base model to use for all models
|
|
||||||
:param merged_model_name: Name of destination merged model
|
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
|
||||||
:param interp: Interpolation method. None (default)
|
|
||||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_for_models(self, directory: Path) -> List[Path]:
|
|
||||||
"""
|
|
||||||
Return list of all models found in the designated directory.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def sync_to_config(self):
|
|
||||||
"""
|
|
||||||
Re-read models.yaml, rescan the models directory, and reimport models
|
|
||||||
in the autoimport directories. Call after making changes outside the
|
|
||||||
model manager API.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
|
||||||
"""
|
|
||||||
Reset model cache statistics for graph with graph_id.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
|
||||||
"""
|
|
||||||
Write current configuration out to the indicated file.
|
|
||||||
If no conf_file is provided, then replaces the
|
|
||||||
original file/database used to initialize the object.
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
@ -1,413 +1,67 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from typing_extensions import Self
|
||||||
|
|
||||||
from logging import Logger
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from pathlib import Path
|
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
import torch
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
|
||||||
from invokeai.backend.model_management import (
|
|
||||||
AddModelResult,
|
|
||||||
BaseModelType,
|
|
||||||
MergeInterpolationMethod,
|
|
||||||
ModelInfo,
|
|
||||||
ModelManager,
|
|
||||||
ModelMerger,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
SubModelType,
|
|
||||||
)
|
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
|
||||||
from invokeai.backend.model_management.model_search import FindModels
|
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
|
||||||
|
|
||||||
|
from ..config import InvokeAIAppConfig
|
||||||
|
from ..download import DownloadQueueServiceBase
|
||||||
|
from ..events.events_base import EventServiceBase
|
||||||
|
from ..model_install import ModelInstallService
|
||||||
|
from ..model_load import ModelLoadService
|
||||||
|
from ..model_records import ModelRecordServiceSQL
|
||||||
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from .model_manager_base import ModelManagerServiceBase
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
|
||||||
|
|
||||||
|
|
||||||
# simple implementation
|
|
||||||
class ModelManagerService(ModelManagerServiceBase):
|
class ModelManagerService(ModelManagerServiceBase):
|
||||||
"""Responsible for managing models on disk and in memory"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: InvokeAIAppConfig,
|
|
||||||
logger: Logger,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
The ModelManagerService handles various aspects of model installation, maintenance and loading.
|
||||||
Optional parameters are the torch device type, precision, max_models,
|
|
||||||
and sequential_offload boolean. Note that the default device
|
It bundles three distinct services:
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
model_manager.store -- Routines to manage the database of model configuration records.
|
||||||
|
model_manager.install -- Routines to install, move and delete models.
|
||||||
|
model_manager.load -- Routines to load models into memory.
|
||||||
"""
|
"""
|
||||||
if config.model_conf_path and config.model_conf_path.exists():
|
|
||||||
config_file = config.model_conf_path
|
|
||||||
else:
|
|
||||||
config_file = config.root_dir / "configs/models.yaml"
|
|
||||||
|
|
||||||
logger.debug(f"Config file={config_file}")
|
@classmethod
|
||||||
|
def build_model_manager(
|
||||||
|
cls,
|
||||||
|
app_config: InvokeAIAppConfig,
|
||||||
|
db: SqliteDatabase,
|
||||||
|
download_queue: DownloadQueueServiceBase,
|
||||||
|
events: EventServiceBase,
|
||||||
|
) -> Self:
|
||||||
|
"""
|
||||||
|
Construct the model manager service instance.
|
||||||
|
|
||||||
device = torch.device(choose_torch_device())
|
For simplicity, use this class method rather than the __init__ constructor.
|
||||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
"""
|
||||||
logger.info(f"GPU device = {device} {device_name}")
|
logger = InvokeAILogger.get_logger(cls.__name__)
|
||||||
|
logger.setLevel(app_config.log_level.upper())
|
||||||
|
|
||||||
precision = config.precision
|
ram_cache = ModelCache(
|
||||||
if precision == "auto":
|
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
|
||||||
precision = choose_precision(device)
|
|
||||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
|
||||||
|
|
||||||
# this is transitional backward compatibility
|
|
||||||
# support for the deprecated `max_loaded_models`
|
|
||||||
# configuration value. If present, then the
|
|
||||||
# cache size is set to 2.5 GB times
|
|
||||||
# the number of max_loaded_models. Otherwise
|
|
||||||
# use new `ram_cache_size` config setting
|
|
||||||
max_cache_size = config.ram_cache_size
|
|
||||||
|
|
||||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
|
||||||
|
|
||||||
sequential_offload = config.sequential_guidance
|
|
||||||
|
|
||||||
self.mgr = ModelManager(
|
|
||||||
config=config_file,
|
|
||||||
device_type=device,
|
|
||||||
precision=dtype,
|
|
||||||
max_cache_size=max_cache_size,
|
|
||||||
sequential_offload=sequential_offload,
|
|
||||||
logger=logger,
|
|
||||||
)
|
)
|
||||||
logger.info("Model manager service initialized")
|
convert_cache = ModelConvertCache(
|
||||||
|
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> ModelInfo:
|
|
||||||
"""
|
|
||||||
Retrieve the indicated model. submodel can be used to get a
|
|
||||||
part (such as the vae) of a diffusers mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# we can emit model loading events if we are executing with access to the invocation context
|
|
||||||
if context:
|
|
||||||
self._emit_load_event(
|
|
||||||
context=context,
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=submodel,
|
|
||||||
)
|
)
|
||||||
|
record_store = ModelRecordServiceSQL(db=db)
|
||||||
model_info = self.mgr.get_model(
|
loader = ModelLoadService(
|
||||||
model_name,
|
app_config=app_config,
|
||||||
base_model,
|
record_store=record_store,
|
||||||
model_type,
|
ram_cache=ram_cache,
|
||||||
submodel,
|
convert_cache=convert_cache,
|
||||||
)
|
)
|
||||||
|
record_store._loader = loader # yeah, there is a circular reference here
|
||||||
if context:
|
installer = ModelInstallService(
|
||||||
self._emit_load_event(
|
app_config=app_config,
|
||||||
context=context,
|
record_store=record_store,
|
||||||
model_name=model_name,
|
download_queue=download_queue,
|
||||||
base_model=base_model,
|
metadata_store=ModelMetadataStore(db=db),
|
||||||
model_type=model_type,
|
event_bus=events,
|
||||||
submodel=submodel,
|
|
||||||
model_info=model_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_info
|
|
||||||
|
|
||||||
def model_exists(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Given a model name, returns True if it is a valid
|
|
||||||
identifier.
|
|
||||||
"""
|
|
||||||
return self.mgr.model_exists(
|
|
||||||
model_name,
|
|
||||||
base_model,
|
|
||||||
model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
|
||||||
"""
|
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
|
||||||
"""
|
|
||||||
return self.mgr.model_info(model_name, base_model, model_type)
|
|
||||||
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns a list of all the model names known.
|
|
||||||
"""
|
|
||||||
return self.mgr.model_names()
|
|
||||||
|
|
||||||
def list_models(
|
|
||||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Return a list of models.
|
|
||||||
"""
|
|
||||||
return self.mgr.list_models(base_model, model_type)
|
|
||||||
|
|
||||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
|
||||||
"""
|
|
||||||
Return information about the model using the same format as list_models()
|
|
||||||
"""
|
|
||||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
|
||||||
|
|
||||||
def add_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
clobber: bool = False,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"add/update model {model_name}")
|
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
|
||||||
|
|
||||||
def update_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
|
||||||
ModelNotFoundException exception if the name does not already exist.
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"update model {model_name}")
|
|
||||||
if not self.model_exists(model_name, base_model, model_type):
|
|
||||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
|
||||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
|
||||||
|
|
||||||
def del_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete the named model from configuration. If delete_files is true,
|
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
|
||||||
as well.
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"delete model {model_name}")
|
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
|
||||||
self.mgr.commit()
|
|
||||||
|
|
||||||
def convert_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
|
||||||
convert_dest_directory: Optional[Path] = Field(
|
|
||||||
default=None, description="Optional directory location for merged model"
|
|
||||||
),
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
|
||||||
version and deleting the original checkpoint file if it is in the models
|
|
||||||
directory.
|
|
||||||
:param model_name: Name of the model to convert
|
|
||||||
:param base_model: Base model type
|
|
||||||
:param model_type: Type of model ['vae' or 'main']
|
|
||||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
|
||||||
|
|
||||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
|
||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
|
||||||
directory already in place.
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"convert model {model_name}")
|
|
||||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
|
||||||
|
|
||||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
|
||||||
"""
|
|
||||||
Reset model cache statistics for graph with graph_id.
|
|
||||||
"""
|
|
||||||
self.mgr.cache.stats = cache_stats
|
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path] = None):
|
|
||||||
"""
|
|
||||||
Write current configuration out to the indicated file.
|
|
||||||
If no conf_file is provided, then replaces the
|
|
||||||
original file/database used to initialize the object.
|
|
||||||
"""
|
|
||||||
return self.mgr.commit(conf_file)
|
|
||||||
|
|
||||||
def _emit_load_event(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
model_info: Optional[ModelInfo] = None,
|
|
||||||
):
|
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
|
||||||
raise CanceledException()
|
|
||||||
|
|
||||||
if model_info:
|
|
||||||
context.services.events.emit_model_load_completed(
|
|
||||||
queue_id=context.queue_id,
|
|
||||||
queue_item_id=context.queue_item_id,
|
|
||||||
queue_batch_id=context.queue_batch_id,
|
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=submodel,
|
|
||||||
model_info=model_info,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
context.services.events.emit_model_load_started(
|
|
||||||
queue_id=context.queue_id,
|
|
||||||
queue_item_id=context.queue_item_id,
|
|
||||||
queue_batch_id=context.queue_batch_id,
|
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=submodel,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def logger(self):
|
|
||||||
return self.mgr.logger
|
|
||||||
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
items_to_import: set[str],
|
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
||||||
) -> dict[str, AddModelResult]:
|
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
|
||||||
|
|
||||||
The prediction type helper is necessary to distinguish between
|
|
||||||
models based on Stable Diffusion 2 Base (requiring
|
|
||||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
|
||||||
(requiring SchedulerPredictionType.VPrediction). It is
|
|
||||||
generally impossible to do this programmatically, so the
|
|
||||||
prediction_type_helper usually asks the user to choose.
|
|
||||||
|
|
||||||
The result is a set of successfully installed models. Each element
|
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
|
||||||
that model.
|
|
||||||
"""
|
|
||||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
|
||||||
|
|
||||||
def merge_models(
|
|
||||||
self,
|
|
||||||
model_names: List[str] = Field(
|
|
||||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
|
||||||
),
|
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
|
||||||
default=None, description="Base model shared by all models to be merged"
|
|
||||||
),
|
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: bool = False,
|
|
||||||
merge_dest_directory: Optional[Path] = Field(
|
|
||||||
default=None, description="Optional directory location for merged model"
|
|
||||||
),
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
|
||||||
:param model_names: List of 2-3 models to merge
|
|
||||||
:param base_model: Base model to use for all models
|
|
||||||
:param merged_model_name: Name of destination merged model
|
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
|
||||||
:param interp: Interpolation method. None (default)
|
|
||||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
|
||||||
"""
|
|
||||||
merger = ModelMerger(self.mgr)
|
|
||||||
try:
|
|
||||||
result = merger.merge_diffusion_models_and_save(
|
|
||||||
model_names=model_names,
|
|
||||||
base_model=base_model,
|
|
||||||
merged_model_name=merged_model_name,
|
|
||||||
alpha=alpha,
|
|
||||||
interp=interp,
|
|
||||||
force=force,
|
|
||||||
merge_dest_directory=merge_dest_directory,
|
|
||||||
)
|
|
||||||
except AssertionError as e:
|
|
||||||
raise ValueError(e)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def search_for_models(self, directory: Path) -> List[Path]:
|
|
||||||
"""
|
|
||||||
Return list of all models found in the designated directory.
|
|
||||||
"""
|
|
||||||
search = FindModels([directory], self.logger)
|
|
||||||
return search.list_models()
|
|
||||||
|
|
||||||
def sync_to_config(self):
|
|
||||||
"""
|
|
||||||
Re-read models.yaml, rescan the models directory, and reimport models
|
|
||||||
in the autoimport directories. Call after making changes outside the
|
|
||||||
model manager API.
|
|
||||||
"""
|
|
||||||
return self.mgr.sync_to_config()
|
|
||||||
|
|
||||||
def list_checkpoint_configs(self) -> List[Path]:
|
|
||||||
"""
|
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
|
||||||
"""
|
|
||||||
config = self.mgr.app_config
|
|
||||||
conf_path = config.legacy_conf_path
|
|
||||||
root_path = config.root_path
|
|
||||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
|
||||||
|
|
||||||
def rename_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
new_name: Optional[str] = None,
|
|
||||||
new_base: Optional[BaseModelType] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Rename the indicated model. Can provide a new name and/or a new base.
|
|
||||||
:param model_name: Current name of the model
|
|
||||||
:param base_model: Current base of the model
|
|
||||||
:param model_type: Model type (can't be changed)
|
|
||||||
:param new_name: New name for the model
|
|
||||||
:param new_base: New base for the model
|
|
||||||
"""
|
|
||||||
self.mgr.rename_model(
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
model_name=model_name,
|
|
||||||
new_name=new_name,
|
|
||||||
new_base=new_base,
|
|
||||||
)
|
)
|
||||||
|
return cls(store=record_store, install=installer, load=loader)
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend
|
Initialization file for invokeai.backend
|
||||||
"""
|
"""
|
||||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
|
||||||
from .model_management.models import SilenceWarnings # noqa: F401
|
|
||||||
|
@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional, Type, Union
|
from typing import Literal, Optional, Type, Union, Class
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ModelMixin
|
from diffusers import ModelMixin
|
||||||
@ -333,9 +333,9 @@ class ModelConfigFactory(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def make_config(
|
def make_config(
|
||||||
cls,
|
cls,
|
||||||
model_data: Union[dict, AnyModelConfig],
|
model_data: Union[Dict[str, Any], AnyModelConfig],
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
dest_class: Optional[Type] = None,
|
dest_class: Optional[Type[Class]] = None,
|
||||||
timestamp: Optional[float] = None,
|
timestamp: Optional[float] = None,
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -18,7 +18,7 @@ loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.
|
|||||||
for module in loaders:
|
for module in loaders:
|
||||||
import_module(f"{__package__}.model_loaders.{module}")
|
import_module(f"{__package__}.model_loaders.{module}")
|
||||||
|
|
||||||
__all__ = ["AnyModelLoader", "LoadedModel"]
|
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
||||||
|
|
||||||
|
|
||||||
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
||||||
|
@ -26,10 +26,10 @@ from pathlib import Path
|
|||||||
from typing import Callable, Optional, Set, Union
|
from typing import Callable, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from logging import Logger
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
default_logger = InvokeAILogger.get_logger()
|
default_logger: Logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class SearchStats(BaseModel):
|
class SearchStats(BaseModel):
|
||||||
@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
|
|||||||
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
||||||
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
||||||
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
||||||
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -128,13 +128,13 @@ class ModelSearch(ModelSearchBase):
|
|||||||
|
|
||||||
def model_found(self, model: Path) -> None:
|
def model_found(self, model: Path) -> None:
|
||||||
self.stats.models_found += 1
|
self.stats.models_found += 1
|
||||||
if not self.on_model_found or self.on_model_found(model):
|
if self.on_model_found is None or self.on_model_found(model):
|
||||||
self.stats.models_filtered += 1
|
self.stats.models_filtered += 1
|
||||||
self.models_found.add(model)
|
self.models_found.add(model)
|
||||||
|
|
||||||
def search_completed(self) -> None:
|
def search_completed(self) -> None:
|
||||||
if self.on_search_completed:
|
if self.on_search_completed is not None:
|
||||||
self.on_search_completed(self._models_found)
|
self.on_search_completed(self.models_found)
|
||||||
|
|
||||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||||
self._directory = Path(directory)
|
self._directory = Path(directory)
|
||||||
|
Loading…
Reference in New Issue
Block a user