diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 1b58ab6a15..5c7bd027c5 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -25,8 +25,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto from ..services.model_manager_service import ModelManagerService from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage -from ..services.model_manager_service import ModelManagerService -from ..services.download_manager import DownloadQueueService from ..services.invocation_stats import InvocationStatsService from .events import FastAPIEventService @@ -128,7 +126,6 @@ class ApiDependencies: processor=DefaultInvocationProcessor(), configuration=config, performance_statistics=InvocationStatsService(graph_execution_manager), - download_manager=DownloadQueueService(event_bus=events), logger=logger, ) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 1b36244966..0132310001 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -195,5 +195,5 @@ class EventServiceBase: def emit_model_download_event(self, job: DownloadJobBase): """Emit event when the status of a download job changes.""" self.dispatch( # use dispatch() directly here because we are not a session event. - event_name="download_job_event", payload=dict(job=job) + event_name="install_model_event", payload=dict(job=job) ) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 2585aaf211..f3ce4c9b40 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -14,7 +14,6 @@ if TYPE_CHECKING: from invokeai.app.services.invocation_queue import InvocationQueueABC from invokeai.app.services.invocation_stats import InvocationStatsServiceBase from invokeai.app.services.invoker import InvocationProcessorABC - from invokeai.app.services.download_manager import DownloadQueueServiceBase from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.model_manager_service import ModelManagerServiceBase @@ -36,7 +35,6 @@ class InvocationServices: model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" performance_statistics: "InvocationStatsServiceBase" - download_manager: Optional["DownloadQueueServiceBase"] queue: "InvocationQueueABC" def __init__( @@ -54,7 +52,6 @@ class InvocationServices: processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", - download_manager: Optional["DownloadQueueServiceBase"] = None, # optional for now pending design decisions ): self.board_images = board_images self.boards = boards @@ -67,7 +64,6 @@ class InvocationServices: self.latents = latents self.logger = logger self.model_manager = model_manager - self.download_manager = download_manager self.processor = processor self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index fbab9010a7..1b4b1e6094 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -2,38 +2,35 @@ from __future__ import annotations +import shutil from abc import ABC, abstractmethod -from logging import Logger from pathlib import Path -from types import ModuleType from invokeai.backend.model_manager import ( BaseModelType, - DownloadJobBase, MergeInterpolationMethod, ModelConfigBase, ModelInfo, + ModelInstallJob, ModelLoader, ModelMerger, + ModelSearch, ModelType, - SchedulerPredictionType, SubModelType, UnknownModelException, + DuplicateModelException ) -from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.cache import CacheStats -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any -import torch from pydantic import Field +from pydantic.networks import AnyHttpUrl from invokeai.app.models.exceptions import CanceledException - -from ...backend.util import choose_precision, choose_torch_device from .config import InvokeAIAppConfig if TYPE_CHECKING: - from ..invocations.baseinvocation import BaseInvocation, InvocationContext + from ..invocations.baseinvocation import InvocationContext class ModelManagerServiceBase(ABC): @@ -43,7 +40,6 @@ class ModelManagerServiceBase(ABC): def __init__( self, config: InvokeAIAppConfig, - logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -55,17 +51,17 @@ class ModelManagerServiceBase(ABC): @abstractmethod def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) - of a diffusers pipeline.""" + """Retrieve the indicated model identified by key. + + :param key: Unique key returned by the ModelConfigStore module. + :param submodel_type: Submodel to return (required for main models) + :param context" Optional InvocationContext, used in event reporting. + """ pass @property @@ -75,15 +71,13 @@ class ModelManagerServiceBase(ABC): @abstractmethod def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + self, + key: str, ) -> bool: pass @abstractmethod - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + def model_info(self, key: str) -> ModelConfigBase: """ Given a model name returns a dict-like (OmegaConf) object describing it. Uses the exact format as the omegaconf stanza. @@ -91,63 +85,60 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict: + def list_models(self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[ModelConfigBase]: """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } + Return a list of ModelConfigBases that match the base, type and name criteria. + :param base_model: Filter by the base model type. + :param model_type: Filter by the model type. + :param model_name: Filter by the model name. """ pass - @abstractmethod - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase: """ - Return information about the model using the same format as list_models() + Return information about the model using the same format as list_models(). + If there are more than one model that match, raises a DuplicateModelException. + If no model matches, raises an UnknownModelException """ - pass + model_configs = self.list_models( + model_name=model_name, + base_model=base_model, + model_type=model_type + ) + if len(model_configs) > 1: + raise DuplicateModelException("More than one model share the same name and type: {base_model}/{model_type}/{model_name}") + if len(model_configs) == 0: + raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}") + return model_configs[0] - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: + def all_models(self) -> List[ModelConfigBase]: """ - Returns a list of all the model names known. + Returns a list of all the models. """ - pass + return self.list_models() @abstractmethod def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> InstallJobBase: + self, + model_path: Path, + probe_overrides: Optional[Dict[str, Any]] = None, + wait: bool = False + ) -> ModelInstallJob: """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. + Add a model using its path, with a dictionary of attributes. Will fail with an + assertion error if the name already exists. """ pass @abstractmethod def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, + self, + key: str, + new_config: Union[dict, ModelConfigBase], ) -> ModelConfigBase: """ Update the named model with a dictionary of attributes. Will fail with a @@ -155,36 +146,32 @@ class ModelManagerServiceBase(ABC): On a successful update, the config will be changed in memory. Will fail with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. + the model key is unknown. """ pass @abstractmethod def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + self, + key: str, + delete_files: bool = False ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. + Delete the named model from configuration. If delete_files + is true, then the underlying file or directory will be + deleted as well. """ pass - @abstractmethod def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): + self, + key: str, + new_name: str, + ) -> ModelConfigBase: """ Rename the indicated model. """ - pass + return self.update_model(key, {"name": new_name}) @abstractmethod def list_checkpoint_configs(self) -> List[Path]: @@ -195,18 +182,17 @@ class ModelManagerServiceBase(ABC): @abstractmethod def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - ) -> InstallJobBase: + self, + key: str, + convert_dest_directory: Path, + ) -> ModelConfigBase: """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] + Convert a checkpoint file into a diffusers folder. + + This will delete the cached version if there is any and delete the original + checkpoint file if it is in the models directory. + :param key: Unique key for the model to convert. + :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) This will raise a ValueError unless the model is not a checkpoint. It will also raise a ValueError in the event that there is a similarly-named diffusers @@ -215,37 +201,34 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> InstallJobBase: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + def install_model ( + self, + source: Union[str, Path, AnyHttpUrl], + model_attributes: Optional[Dict[str, Any]] = None, + ) -> ModelInstallJob: + """Import a path, repo_id or URL. Returns an ModelInstallJob. - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. + :param model_attributes: Additional attributes to supplement/override + the model information gained from automated probing. - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. + Typical usage: + job = model_manager.install( + 'stabilityai/stable-diffusion-2-1', + model_attributes={'prediction_type": 'v-prediction'} + ) + + The result is an ModelInstallJob object, which provides + information on the asynchronous model download and install + process. A series of "install_model_event" events will be emitted + until the install is completed, cancelled or errors out. """ pass @abstractmethod def merge_models( self, - model_names: List[str] = Field( - default=None, min_items=2, max_items=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" + model_keys: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model keys to merge" ), merged_model_name: str = Field(default=None, description="Name of destination model after merging"), alpha: Optional[float] = 0.5, @@ -255,8 +238,7 @@ class ModelManagerServiceBase(ABC): ) -> ModelConfigBase: """ Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models + :param model_keys: List of 2-3 model unique keys to merge :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) @@ -287,24 +269,16 @@ class ModelManagerServiceBase(ABC): """ pass - @abstractmethod - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ - pass - -# simple implementation +# implementation class ModelManagerService(ModelManagerServiceBase): """Responsible for managing models on disk and in memory""" + _loader: ModelLoader = Field(description="InvokeAIAppConfig object for the current process") + def __init__( self, config: InvokeAIAppConfig, - logger: Logger, ): """ Initialize with the path to the models.yaml config file. @@ -312,218 +286,164 @@ class ModelManagerService(ModelManagerServiceBase): and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. """ - if config.model_conf_path and config.model_conf_path.exists(): - config_file = config.model_conf_path - else: - config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f"Config file={config_file}") - - device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" - logger.info(f"GPU device = {device} {device_name}") - - precision = config.precision - if precision == "auto": - precision = choose_precision(device) - dtype = torch.float32 if precision == "float32" else torch.float16 - - # this is transitional backward compatibility - # support for the deprecated `max_loaded_models` - # configuration value. If present, then the - # cache size is set to 2.5 GB times - # the number of max_loaded_models. Otherwise - # use new `ram_cache_size` config setting - max_cache_size = config.ram_cache_size - - logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") - - sequential_offload = config.sequential_guidance - - self.mgr = ModelManager( - config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger, - ) - logger.info("Model manager service initialized") + self._loader = ModelLoader(config) def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, ) -> ModelInfo: """ Retrieve the indicated model. submodel can be used to get a part (such as the vae) of a diffusers mode. """ - + + model_info: ModelInfo = self._loader.get_model(key, submodel_type) + # we can emit model loading events if we are executing with access to the invocation context if context: self._emit_load_event( context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - model_info = self.mgr.get_model( - model_name, - base_model, - model_type, - submodel, - ) - - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, + key=key, + submodel_type=submodel_type, model_info=model_info, ) return model_info def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + self, + key: str, ) -> bool: """ - Given a model name, returns True if it is a valid + Given a model key, returns True if it is a valid identifier. """ - return self.mgr.model_exists( - model_name, - base_model, - model_type, - ) + return self._loader.store.exists(key) - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: + def model_info(self, key: str) -> ModelConfigBase: """ Given a model name returns a dict-like (OmegaConf) object describing it. """ - return self.mgr.model_info(model_name, base_model, model_type) + return self._loader.store.get_model(key) - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - return self.mgr.model_names() + # def all_models(self) -> List[ModelConfigBase] -- defined in base class, same as list_models() + # def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -- defined in base class - def list_models( - self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> list[dict]: + def list_models(self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[ModelConfigBase]: """ - Return a list of models. + Return a ModelConfigBase object for each model in the database. """ - return self.mgr.list_models(base_model, model_type) + return self._loader.store.search_by_name(model_name, base_model, model_type) - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: """ Return information about the model using the same format as list_models() """ return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> InstallJobBase: + self, + model_path: Path, + model_attributes: Optional[dict] = None, + wait: bool = False + ) -> ModelInstallJob: """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. + Add a model using its path, with a dictionary of attributes. Will fail with an + assertion error if the name already exists. """ - self.logger.debug(f"add/update model {model_name}") - return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) + self.logger.debug(f"add/update model {model_path}") + return self._loader.installer.install( + model_path, + probe_override=model_attributes, + ) + + def install_model( + self, + source: Union[str, Path, AnyHttpUrl], + model_attributes: Optional[Dict[str, Any]] = None, + ) -> ModelInstallJob: + """ + Add a model using its path, with a dictionary of attributes. Will fail with an + assertion error if the name already exists. + """ + self.logger.debug(f"add/update model {source}") + variant = 'fp16' if self._loader.precision == 'float16' else None + return self._loader.installer.install( + source, + probe_override=model_attributes, + variant=variant, + ) def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> InstallJobBase: + self, + key: str, + new_config: Union[dict, ModelConfigBase], + ) -> ModelConfigBase: """ Update the named model with a dictionary of attributes. Will fail with a - UnknownModelException exception if the name does not already exist. + UnknownModelException if the name does not already exist. + On a successful update, the config will be changed in memory. Will fail with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. + the model key is unknown. """ - self.logger.debug(f"update model {model_name}") - if not self.model_exists(model_name, base_model, model_type): - raise UnknownModelException(f"Unknown model {model_name}") - return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) + model_info = self.model_info(key) + self.logger.debug(f"update model {model_info.name}") + self.logger.warning("TO DO: write code to move models around if base or type change") + return self._loader.store.update_model(key, new_config) def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + self, + key: str, + delete_files: bool = False, ): """ Delete the named model from configuration. If delete_files is true, then the underlying weight file or diffusers directory will be deleted as well. """ - self.logger.debug(f"delete model {model_name}") - self.mgr.del_model(model_name, base_model, model_type) - self.mgr.commit() + model_info = self.model_info(key) + self.logger.debug(f"delete model {model_info.name}") + self._loader.store.del_model(key) + if delete_files and Path(model_info.path).exists(): + path = Path(model_info) + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - convert_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> InstallJobBase: + self, + key: str, + convert_dest_directory: Path, + ) -> ModelConfigBase: """ Convert a checkpoint file into a diffusers folder, deleting the cached version and deleting the original checkpoint file if it is in the models directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] + + :param key: Key of the model to convert :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) This will raise a ValueError unless the model is not a checkpoint. It will also raise a ValueError in the event that there is a similarly-named diffusers directory already in place. """ - self.logger.debug(f"convert model {model_name}") - return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) + model_info = self.model_info(key) + self.logger.debug(f"convert model {model_info.name}") + self.logger.warning('This is not implemented yet') + return self._loader.convert_model(key, convert_dest_directory) def collect_cache_stats(self, cache_stats: CacheStats): """ - Reset model cache statistics for graph with graph_id. + Reset model cache statistics. Is this used? """ - self.mgr.cache.stats = cache_stats - - def commit(self, conf_file: Optional[Path] = None): - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ - return self.mgr.commit(conf_file) + self._loader.collect_cache_stats(cache_stats) def _emit_load_event( self, @@ -557,51 +477,22 @@ class ModelManagerService(ModelManagerServiceBase): @property def logger(self): - return self.mgr.logger - - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, InstallJobBase]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - """ - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) - + return self._loader.logger + def merge_models( - self, - model_names: List[str] = Field( - default=None, min_items=2, max_items=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> str: + self, + model_keys: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model keys to merge" + ), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: """ Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models + :param model_keys: List of 2-3 model unique keys to merge :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) @@ -609,9 +500,9 @@ class ModelManagerService(ModelManagerServiceBase): """ merger = ModelMerger(self.mgr) try: + self.logger.error('ModelMerger needs to be rewritten.') result = merger.merge_diffusion_models_and_save( - model_names=model_names, - base_model=base_model, + model_keys=model_keys, merged_model_name=merged_model_name, alpha=alpha, interp=interp, @@ -626,8 +517,7 @@ class ModelManagerService(ModelManagerServiceBase): """ Return list of all models found in the designated directory. """ - search = FindModels([directory], self.logger) - return search.list_models() + return ModelSearch().search(directory) def sync_to_config(self): """ @@ -635,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase): in the autoimport directories. Call after making changes outside the model manager API. """ - return self.mgr.sync_to_config() + return self._loader.sync_to_config() def list_checkpoint_configs(self) -> List[Path]: """ @@ -648,24 +538,13 @@ class ModelManagerService(ModelManagerServiceBase): def rename_model( self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, + key: str, + new_name: str, ): """ - Rename the indicated model. Can provide a new name and/or a new base. - :param model_name: Current name of the model - :param base_model: Current base of the model - :param model_type: Model type (can't be changed) + Rename the indicated model to the new name. + + :param key: Unique key for the model. :param new_name: New name for the model - :param new_base: New base for the model """ - self.mgr.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=new_name, - new_base=new_base, - ) + return self.update_model(key, {"name": new_name}) diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index ac696d29fd..e6ec12c896 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -16,7 +16,7 @@ from .config import ( # noqa F401 ) from .lora import ONNXModelPatcher, ModelPatcher from .loader import ModelLoader, ModelInfo # noqa F401 -from .install import ModelInstall, DownloadJobBase # noqa F401 +from .install import ModelInstall, ModelInstallJob # noqa F401 from .probe import ModelProbe, InvalidModelException # noqa F401 from .storage import ( UnknownModelException, diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index f104d6efcf..4b7f8c5aae 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -136,6 +136,11 @@ class ModelConfigBase(BaseModel): v = list(v) return v + def update(self, attributes: dict): + """Update the object with fields in dict.""" + for key, value in attributes.items(): + setattr(self, key, value) # may raise a validation error + class CheckpointConfig(ModelConfigBase): """Model config for checkpoint-style models.""" diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index b2f9452ace..ce6825f8c9 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -329,14 +329,14 @@ class DownloadQueue(DownloadQueueBase): self._dones += 1 self._queue.task_done() - def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]: + def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl: """ Fetch metadata from certain well-known URLs. The metadata will be stashed in job.metadata, if found Return the download URL. """ - metadata = ModelSourceMetadata() + metadata = job.metadata url = job.source metadata_url = url try: @@ -344,12 +344,14 @@ class DownloadQueue(DownloadQueueBase): if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url): version = match.group(1) resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json() - metadata.thumbnail_url = resp["images"][0]["url"] - metadata.description = ( - f"Trigger terms: {(', ').join(resp['trainedWords'])}" - if resp["trainedWords"] - else resp["description"] - ) + metadata.thumbnail_url = metadata.thumbnail_url \ + or resp["images"][0]["url"] + metadata.description = metadata.description \ + or ( + f"Trigger terms: {(', ').join(resp['trainedWords'])}" + if resp["trainedWords"] + else resp["description"] + ) metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) # a Civitai model page @@ -360,21 +362,22 @@ class DownloadQueue(DownloadQueueBase): # note that we munge the URL here to get the download URL of the first model url = resp["modelVersions"][0]["downloadUrl"] - metadata.author = resp["creator"]["username"] - metadata.tags = resp["tags"] - metadata.thumbnail_url = resp["modelVersions"][0]["images"][0]["url"] - metadata.license = f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}" + metadata.author = metadata.author or resp["creator"]["username"] + metadata.tags = metadata.tags or resp["tags"] + metadata.thumbnail_url = metadata.thumbnail_url \ + or resp["modelVersions"][0]["images"][0]["url"] + metadata.license = metadata.license \ + or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}" except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp: self._logger.warn(excp) - # update metadata and return the download url - return url, metadata + # return the download url + return url def _download_with_resume(self, job: DownloadJobBase): """Do the actual download.""" try: - url, metadata = self._fetch_metadata(job) - job.metadata = metadata + url = self._get_metadata_and_url(job) header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} open_mode = "wb" @@ -602,7 +605,6 @@ class DownloadQueue(DownloadQueueBase): """Call when the source is a Path or pathlike object.""" source = Path(job.source).resolve() destination = Path(job.destination).resolve() - job.metadata = ModelSourceMetadata() try: if source != destination: shutil.move(source, destination) diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index aff8ceb084..54a854e837 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -53,7 +53,7 @@ import tempfile from abc import ABC, abstractmethod from pathlib import Path from shutil import rmtree -from typing import Optional, List, Union, Dict, Set +from typing import Optional, List, Union, Dict, Set, Any from pydantic import Field from pydantic.networks import AnyHttpUrl from invokeai.app.services.config import InvokeAIAppConfig @@ -79,8 +79,8 @@ class ModelInstallJob(DownloadJobBase): model_key: Optional[str] = Field( description="After model installation, this field will hold its primary key", default=None ) - probe_info: Optional[ModelProbeInfo] = Field( - description="If provided, information here will be used instead of probing the model.", + probe_override: Optional[Dict[str, Any]] = Field( + description="Keys in this dict will override like-named attributes in the automatic probe info", default=None, ) @@ -316,9 +316,12 @@ class ModelInstall(ModelInstallBase): """Return the queue.""" return self._download_queue - def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: # noqa D102 + def register_path(self, + model_path: Union[Path, str], + overrides: Optional[Dict[str, Any]] = None + ) -> str: # noqa D102 model_path = Path(model_path) - info: ModelProbeInfo = info or ModelProbe.probe(model_path) + info: ModelProbeInfo = self._probe_model(model_path, overrides) return self._register(model_path, info) def _register(self, model_path: Path, info: ModelProbeInfo) -> str: @@ -351,12 +354,13 @@ class ModelInstall(ModelInstallBase): return key def install_path( - self, - model_path: Union[Path, str], - info: Optional[ModelProbeInfo] = None, + self, + model_path: Union[Path, str], + overrides: Optional[Dict[str, Any]] = None, ) -> str: # noqa D102 model_path = Path(model_path) - info: ModelProbeInfo = info or ModelProbe.probe(model_path) + info: ModelProbeInfo = self._probe_model(model_path, overrides) + dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name dest_path.parent.mkdir(parents=True, exist_ok=True) @@ -371,6 +375,16 @@ class ModelInstall(ModelInstallBase): info, ) + def _probe_model(self, + model_path: Union[Path, str], + overrides: Optional[Dict[str,Any]] = None + ) -> ModelProbeInfo: + info: ModelProbeInfo = ModelProbe.probe(model_path) + if overrides: # used to override probe fields + for key, value in overrides.items(): + setattr(info, key, value) # may generate a pydantic validation error + return info + def unregister(self, key: str): # noqa D102 self._store.del_model(key) @@ -380,12 +394,12 @@ class ModelInstall(ModelInstallBase): self.unregister(key) def install( - self, - source: Union[str, Path, AnyHttpUrl], - info: Optional[ModelProbeInfo] = None, - inplace: bool = True, - variant: Optional[str] = None, - access_token: Optional[str] = None, + self, + source: Union[str, Path, AnyHttpUrl], + inplace: bool = True, + variant: Optional[str] = None, + probe_override: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, ) -> DownloadJobBase: # noqa D102 queue = self._download_queue @@ -395,8 +409,8 @@ class ModelInstall(ModelInstallBase): if inplace and Path(source).exists() else self._complete_installation_handler ) + job.probe_override = probe_override job.add_event_handler(handler) - job.probe_info = info self._async_installs[source] = None queue.submit_download_job(job, True) @@ -405,7 +419,7 @@ class ModelInstall(ModelInstallBase): def _complete_installation_handler(self, job: DownloadJobBase): if job.status == "completed": self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.") - model_id = self.install_path(job.destination, job.probe_info) + model_id = self.install_path(job.destination, job.probe_override) info = self._store.get_model(model_id) info.source = str(job.source) metadata: ModelSourceMetadata = job.metadata @@ -429,7 +443,7 @@ class ModelInstall(ModelInstallBase): def _complete_registration_handler(self, job: DownloadJobBase): if job.status == "completed": self._logger.info(f"{job.source}: Installing in place.") - model_id = self.register_path(job.destination, job.probe_info) + model_id = self.register_path(job.destination, job.probe_override) info = self._store.get_model(model_id) info.source = str(job.source) info.description = f"Imported model {info.name}" diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py index ad2ad31ebf..6bf7842750 100644 --- a/invokeai/backend/model_manager/loader.py +++ b/invokeai/backend/model_manager/loader.py @@ -14,7 +14,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device, InvokeA from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase from .install import ModelInstallBase, ModelInstall from .storage import ModelConfigStore, get_config_store -from .cache import ModelCache, ModelLocker +from .cache import ModelCache, ModelLocker, CacheStats from .models import InvalidModelException, ModelBase, MODEL_CLASSES @@ -69,6 +69,34 @@ class ModelLoaderBase(ABC): """Return the ModelInstallBase object that supports this loader.""" pass + @property + @abstractmethod + def logger(self) -> InvokeAILogger: + """Return the current logger.""" + pass + + + @abstractmethod + def collect_cache_stats( + self, + cache_stats: CacheStats + ): + """Replace cache statistics.""" + pass + + @property + @abstractmethod + def precision(self) -> str: + """Return 'float32' or 'float16'.""" + pass + + @abstractmethod + def sync_to_config(self): + """ + Reinitialize the store to sync in-memory and in-disk + versions. + """ + pass class ModelLoader(ModelLoaderBase): """Implementation of ModelLoaderBase.""" @@ -79,6 +107,7 @@ class ModelLoader(ModelLoaderBase): _cache: ModelCache _logger: InvokeAILogger _cache_keys: dict + _models_file: Path def __init__( self, @@ -102,6 +131,7 @@ class ModelLoader(ModelLoaderBase): self._logger = InvokeAILogger.getLogger() self._installer = ModelInstall(store=self._store, logger=self._logger, config=self._app_config) self._cache_keys = dict() + self._models_file = models_file device = torch.device(choose_torch_device()) device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" precision = choose_precision(device) if config.precision == "auto" else config.precision @@ -130,11 +160,21 @@ class ModelLoader(ModelLoaderBase): """Return the ModelConfigStore instance used by this class.""" return self._store + @property + def precision(self) -> str: + """Return 'float32' or 'float16'.""" + return self._cache.precision + @property def installer(self) -> ModelInstallBase: """Return the ModelInstallBase instance used by this class.""" return self._installer + @property + def logger(self) -> InvokeAILogger: + """Return the current logger.""" + return self._logger + def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo: """ Get the ModelInfo corresponding to the model with key "key". @@ -188,6 +228,12 @@ class ModelLoader(ModelLoaderBase): _cache=self._cache, ) + def collect_cache_stats( + self, + cache_stats: CacheStats + ): + self._cache.stats = cache_stats + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: """Get the concrete implementation class for a specific model type.""" model_class = MODEL_CLASSES[base_model][model_type] @@ -220,6 +266,10 @@ class ModelLoader(ModelLoaderBase): model_path = self._resolve_model_path(model_path) return model_path, is_submodel_override + def sync_to_config(self): + self._store = get_config_store(self._models_file) + self._scan_models_directory() + def _scan_models_directory(self): defunct_models = set() installed = set() diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 90b773bec3..6dcb1cab2b 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -124,17 +124,26 @@ class ModelMerger(object): dump_path = (dump_path / merged_model_name).as_posix() merged_pipe.save_pretrained(dump_path, safe_serialization=True) - attributes = dict( - path=dump_path, - description=f"Merge of models {', '.join(model_names)}", - model_format="diffusers", - variant=ModelVariantType.Normal.value, - vae=vae, - ) - return self.manager.add_model( - merged_model_name, - base_model=base_model, + + # register model and get its unique key + info = ModelProbeInfo( model_type=ModelType.Main, - model_attributes=attributes, - clobber=True, + base_type=base_model, + format="diffusers", ) + key = self.manager.installer.register_path( + model_path=dump_path, + info=info, + ) + + # update model's config + model_config = self.manager.store.get_model(key) + model_config.update( + dict( + name=merged_model_name, + description=f"Merge of models {', '.join(model_names)}", + vae=vae, + ) + ) + self.manager.store.update_model(key, model_config) + return model_config diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml index 4f4a774a60..db773e0e45 100644 --- a/tests/test_model_manager/configs/relative_sub.models.yaml +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -1,5 +1,5 @@ __metadata__: - version: 3.1.0 + version: 3.1.1 ed799245c762f6d0a9ddfd4e31fdb010: name: sdxl-base-1-0 path: sdxl/main/SDXL base 1_0