consolidate model manager parts into a single class

This commit is contained in:
Lincoln Stein 2024-02-09 23:08:38 -05:00 committed by Brandon Rising
parent 8f1b7355df
commit 721ff58e44
10 changed files with 186 additions and 696 deletions

View File

@ -0,0 +1,6 @@
"""Initialization file for model load service module."""
from .model_load_base import ModelLoadServiceBase
from .model_load_default import ModelLoadService
__all__ = ["ModelLoadServiceBase", "ModelLoadService"]

View File

@ -0,0 +1,22 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Base class for model loader."""
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
pass
@abstractmethod
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
pass

View File

@ -0,0 +1,54 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
class ModelLoadService(ModelLoadServiceBase):
"""Wrapper around AnyModelLoader."""
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
ram_cache: Optional[ModelCacheBase] = None,
convert_cache: Optional[ModelConvertCacheBase] = None,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)
logger.setLevel(app_config.log_level.upper())
self._store = record_store
self._any_loader = AnyModelLoader(
app_config=app_config,
logger=logger,
ram_cache=ram_cache
or ModelCache(
max_cache_size=app_config.ram_cache_size,
max_vram_cache_size=app_config.vram_cache_size,
logger=logger,
),
convert_cache=convert_cache
or ModelConvertCache(
cache_path=app_config.models_convert_cache_path,
max_size=app_config.convert_cache_size,
),
)
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type)
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
return self._any_loader.load_model(config, submodel_type)

View File

@ -1 +1,16 @@
from .model_manager_default import ModelManagerService # noqa F401
"""Initialization file for model manager service."""
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerService
__all__ = [
"ModelManagerService",
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelType",
"SubModelType",
"LoadedModel",
]

View File

@ -1,283 +1,39 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from typing import Callable, List, Literal, Optional, Tuple, Union
from pydantic import Field
from pydantic import BaseModel, Field
from typing_extensions import Self
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
LoadedModelInfo,
MergeInterpolationMethod,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_management.model_cache import CacheStats
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallServiceBase
from ..model_load import ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
class ModelManagerServiceBase(BaseModel, ABC):
"""Abstract base class for the model manager service."""
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
@classmethod
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
):
def build_model_manager(
cls,
app_config: InvokeAIAppConfig,
db: SqliteDatabase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
) -> Self:
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
pass
@abstractmethod
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
pass
@property
@abstractmethod
def logger(self):
pass
@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Uses the exact format as the omegaconf stanza.
"""
pass
@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
pass
@abstractmethod
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_length=2, max_length=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
Construct the model manager service instance.
Use it rather than the __init__ constructor. This class
method simplifies the construction considerably.
"""
pass

View File

@ -1,421 +1,67 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from __future__ import annotations
from typing_extensions import Self
from logging import Logger
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
import torch
from pydantic import Field
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
LoadedModelInfo,
MergeInterpolationMethod,
ModelManager,
ModelMerger,
ModelNotFoundException,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService
from ..model_load import ModelLoadService
from ..model_records import ModelRecordServiceSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_manager_base import ModelManagerServiceBase
if TYPE_CHECKING:
pass
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
"""
The ModelManagerService handles various aspects of model installation, maintenance and loading.
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
):
It bundles three distinct services:
model_manager.store -- Routines to manage the database of model configuration records.
model_manager.install -- Routines to install, move and delete models.
model_manager.load -- Routines to load models into memory.
"""
@classmethod
def build_model_manager(
cls,
app_config: InvokeAIAppConfig,
db: SqliteDatabase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
) -> Self:
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
Construct the model manager service instance.
For simplicity, use this class method rather than the __init__ constructor.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
logger = InvokeAILogger.get_logger(cls.__name__)
logger.setLevel(app_config.log_level.upper())
logger.debug(f"Config file={config_file}")
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f"GPU device = {device} {device_name}")
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == "float32" else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `ram_cache_size` config setting
max_cache_size = config.ram_cache_size
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
ram_cache = ModelCache(
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
)
logger.info("Model manager service initialized")
def start(self, invoker: Invoker) -> None:
self._invoker: Optional[Invoker] = invoker
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode.
"""
# we can emit model loading events if we are executing with access to the invocation context
if context_data is not None:
self._emit_load_event(
context_data=context_data,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
loaded_model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
convert_cache = ModelConvertCache(
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
)
if context_data is not None:
self._emit_load_event(
context_data=context_data,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
loaded_model_info=loaded_model_info,
)
return loaded_model_info
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
record_store = ModelRecordServiceSQL(db=db)
loader = ModelLoadService(
app_config=app_config,
record_store=record_store,
ram_cache=ram_cache,
convert_cache=convert_cache,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
) -> list[dict]:
"""
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"add/update model {model_name}")
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type):
raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f"delete model {model_name}")
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
def _emit_load_event(
self,
context_data: InvocationContextData,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
loaded_model_info: Optional[LoadedModelInfo] = None,
):
if self._invoker is None:
return
if self._invoker.services.queue.is_canceled(context_data.session_id):
raise CanceledException()
if loaded_model_info:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
loaded_model_info=loaded_model_info,
)
else:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
@property
def logger(self):
return self.mgr.logger
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_length=2, max_length=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
try:
result = merger.merge_diffusion_models_and_save(
model_names=model_names,
base_model=base_model,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels([directory], self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
record_store._loader = loader # yeah, there is a circular reference here
installer = ModelInstallService(
app_config=app_config,
record_store=record_store,
download_queue=download_queue,
metadata_store=ModelMetadataStore(db=db),
event_bus=events,
)
return cls(store=record_store, install=installer, load=loader)

View File

@ -1,12 +1,3 @@
"""
Initialization file for invokeai.backend
"""
from .model_management import ( # noqa: F401
BaseModelType,
LoadedModelInfo,
ModelCache,
ModelManager,
ModelType,
SubModelType,
)
from .model_management.models import SilenceWarnings # noqa: F401

View File

@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
from typing import Literal, Optional, Type, Union, Class
import torch
from diffusers import ModelMixin
@ -333,9 +333,9 @@ class ModelConfigFactory(object):
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
dest_class: Optional[Type[Class]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""

View File

@ -18,7 +18,7 @@ loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:

View File

@ -26,10 +26,10 @@ from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from logging import Logger
from invokeai.backend.util.logging import InvokeAILogger
default_logger = InvokeAILogger.get_logger()
default_logger: Logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
@ -128,13 +128,13 @@ class ModelSearch(ModelSearchBase):
def model_found(self, model: Path) -> None:
self.stats.models_found += 1
if not self.on_model_found or self.on_model_found(model):
if self.on_model_found is None or self.on_model_found(model):
self.stats.models_filtered += 1
self.models_found.add(model)
def search_completed(self) -> None:
if self.on_search_completed:
self.on_search_completed(self._models_found)
if self.on_search_completed is not None:
self.on_search_completed(self.models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)