diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index 1b4b1e6094..cb34beed2d 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
     ModelType,
     SubModelType,
     UnknownModelException,
-    DuplicateModelException
+    DuplicateModelException,
 )
 from invokeai.backend.model_manager.cache import CacheStats
 from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any
@@ -51,10 +51,10 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def get_model(
-            self,
-            key: str,
-            submodel_type: Optional[SubModelType] = None,
-            context: Optional[InvocationContext] = None,
+        self,
+        key: str,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
     ) -> ModelInfo:
         """Retrieve the indicated model identified by key.
 
@@ -71,8 +71,8 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def model_exists(
-            self,
-            key: str,
+        self,
+        key: str,
     ) -> bool:
         pass
 
@@ -85,10 +85,11 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def list_models(self,
-                    model_name: Optional[str] = None,
-                    base_model: Optional[BaseModelType] = None,
-                    model_type: Optional[ModelType] = None,
+    def list_models(
+        self,
+        model_name: Optional[str] = None,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
     ) -> List[ModelConfigBase]:
         """
         Return a list of ModelConfigBases that match the base, type and name criteria.
@@ -104,13 +105,11 @@ class ModelManagerServiceBase(ABC):
        If there are more than one model that match, raises a DuplicateModelException.
        If no model matches, raises an UnknownModelException
        """
-        model_configs = self.list_models(
-            model_name=model_name,
-            base_model=base_model,
-            model_type=model_type
-        )
+        model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
         if len(model_configs) > 1:
-            raise DuplicateModelException("More than one model share the same name and type: {base_model}/{model_type}/{model_name}")
+            raise DuplicateModelException(
+                "More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
+            )
         if len(model_configs) == 0:
             raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
         return model_configs[0]
@@ -123,22 +122,19 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def add_model(
-            self,
-            model_path: Path,
-            probe_overrides: Optional[Dict[str, Any]] = None,
-            wait: bool = False
+        self, model_path: Path, probe_overrides: Optional[Dict[str, Any]] = None, wait: bool = False
     ) -> ModelInstallJob:
         """
         Add a model using its path, with a dictionary of attributes. Will fail with an
-        assertion error if the name already exists. 
+        assertion error if the name already exists.
         """
         pass
 
     @abstractmethod
     def update_model(
-            self,
-            key: str,
-            new_config: Union[dict, ModelConfigBase],
+        self,
+        key: str,
+        new_config: Union[dict, ModelConfigBase],
     ) -> ModelConfigBase:
         """
         Update the named model with a dictionary of attributes. Will fail with a
@@ -151,11 +147,7 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def del_model(
-        self,
-        key: str,
-        delete_files: bool = False
-    ):
+    def del_model(self, key: str, delete_files: bool = False):
         """
         Delete the named model from configuration.
         If delete_files is true, then the underlying file or directory will be
@@ -164,9 +156,9 @@ class ModelManagerServiceBase(ABC):
         pass
 
     def rename_model(
-            self,
-            key: str,
-            new_name: str,
+        self,
+        key: str,
+        new_name: str,
     ) -> ModelConfigBase:
         """
         Rename the indicated model.
@@ -182,14 +174,14 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def convert_model(
-            self,
-            key: str,
-            convert_dest_directory: Path,
+        self,
+        key: str,
+        convert_dest_directory: Path,
     ) -> ModelConfigBase:
         """
         Convert a checkpoint file into a diffusers folder.
 
-        This will delete the cached version if there is any and delete the original 
+        This will delete the cached version if there is any and delete the original
        checkpoint file if it is in the models directory.
        :param key: Unique key for the model to convert.
        :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
@@ -201,10 +193,10 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def install_model (
-            self,
-            source: Union[str, Path, AnyHttpUrl],
-            model_attributes: Optional[Dict[str, Any]] = None,
+    def install_model(
+        self,
+        source: Union[str, Path, AnyHttpUrl],
+        model_attributes: Optional[Dict[str, Any]] = None,
     ) -> ModelInstallJob:
         """Import a path, repo_id or URL. Returns an ModelInstallJob.
 
@@ -270,7 +262,7 @@ class ModelManagerServiceBase(ABC):
         pass
 
 
-# implementation 
+# implementation
 class ModelManagerService(ModelManagerServiceBase):
     """Responsible for managing models on disk and in memory"""
 
@@ -289,18 +281,18 @@ class ModelManagerService(ModelManagerServiceBase):
         self._loader = ModelLoader(config)
 
     def get_model(
-            self,
-            key: str,
-            submodel_type: Optional[SubModelType] = None,
-            context: Optional[InvocationContext] = None,
+        self,
+        key: str,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
     ) -> ModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
         part (such as the vae) of a diffusers mode.
         """
-        
+
         model_info: ModelInfo = self._loader.get_model(key, submodel_type)
-        
+
         # we can emit model loading events if we are executing with access to the invocation context
         if context:
             self._emit_load_event(
@@ -313,8 +305,8 @@ class ModelManagerService(ModelManagerServiceBase):
         return model_info
 
     def model_exists(
-            self,
-            key: str,
+        self,
+        key: str,
     ) -> bool:
         """
         Given a model key, returns True if it is a valid
@@ -331,10 +323,11 @@ class ModelManagerService(ModelManagerServiceBase):
     # def all_models(self) -> List[ModelConfigBase] -- defined in base class, same as list_models()
     # def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -- defined in base class
 
-    def list_models(self,
-                    model_name: Optional[str] = None,
-                    base_model: Optional[BaseModelType] = None,
-                    model_type: Optional[ModelType] = None,
+    def list_models(
+        self,
+        model_name: Optional[str] = None,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
     ) -> List[ModelConfigBase]:
         """
         Return a ModelConfigBase object for each model in the database.
@@ -347,14 +340,11 @@ class ModelManagerService(ModelManagerServiceBase):
         return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
 
     def add_model(
-            self,
-            model_path: Path,
-            model_attributes: Optional[dict] = None,
-            wait: bool = False
+        self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
     ) -> ModelInstallJob:
         """
         Add a model using its path, with a dictionary of attributes. Will fail with an
-        assertion error if the name already exists. 
+        assertion error if the name already exists.
         """
         self.logger.debug(f"add/update model {model_path}")
         return self._loader.installer.install(
@@ -363,16 +353,16 @@ class ModelManagerService(ModelManagerServiceBase):
         )
 
     def install_model(
-            self,
-            source: Union[str, Path, AnyHttpUrl],
-            model_attributes: Optional[Dict[str, Any]] = None,
+        self,
+        source: Union[str, Path, AnyHttpUrl],
+        model_attributes: Optional[Dict[str, Any]] = None,
     ) -> ModelInstallJob:
         """
         Add a model using its path, with a dictionary of attributes. Will fail with an
-        assertion error if the name already exists. 
+        assertion error if the name already exists.
         """
         self.logger.debug(f"add/update model {source}")
-        variant = 'fp16' if self._loader.precision == 'float16' else None
+        variant = "fp16" if self._loader.precision == "float16" else None
         return self._loader.installer.install(
             source,
             probe_override=model_attributes,
@@ -380,9 +370,9 @@ class ModelManagerService(ModelManagerServiceBase):
         )
 
     def update_model(
-            self,
-            key: str,
-            new_config: Union[dict, ModelConfigBase],
+        self,
+        key: str,
+        new_config: Union[dict, ModelConfigBase],
     ) -> ModelConfigBase:
         """
         Update the named model with a dictionary of attributes. Will fail with a
@@ -398,9 +388,9 @@ class ModelManagerService(ModelManagerServiceBase):
         return self._loader.store.update_model(key, new_config)
 
     def del_model(
-            self,
-            key: str,
-            delete_files: bool = False,
+        self,
+        key: str,
+        delete_files: bool = False,
     ):
         """
         Delete the named model from configuration. If delete_files is true,
@@ -418,9 +408,9 @@ class ModelManagerService(ModelManagerServiceBase):
                 path.unlink()
 
     def convert_model(
-            self,
-            key: str,
-            convert_dest_directory: Path,
+        self,
+        key: str,
+        convert_dest_directory: Path,
     ) -> ModelConfigBase:
         """
         Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -436,7 +426,7 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         model_info = self.model_info(key)
         self.logger.debug(f"convert model {model_info.name}")
-        self.logger.warning('This is not implemented yet')
+        self.logger.warning("This is not implemented yet")
         return self._loader.convert_model(key, convert_dest_directory)
 
     def collect_cache_stats(self, cache_stats: CacheStats):
@@ -478,17 +468,17 @@ class ModelManagerService(ModelManagerServiceBase):
     @property
     def logger(self):
         return self._loader.logger
-    
+
     def merge_models(
-            self,
-            model_keys: List[str] = Field(
-                default=None, min_items=2, max_items=3, description="List of model keys to merge"
-            ),
-            merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
-            alpha: Optional[float] = 0.5,
-            interp: Optional[MergeInterpolationMethod] = None,
-            force: Optional[bool] = False,
-            merge_dest_directory: Optional[Path] = None,
+        self,
+        model_keys: List[str] = Field(
+            default=None, min_items=2, max_items=3, description="List of model keys to merge"
+        ),
+        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
+        alpha: Optional[float] = 0.5,
+        interp: Optional[MergeInterpolationMethod] = None,
+        force: Optional[bool] = False,
+        merge_dest_directory: Optional[Path] = None,
     ) -> ModelConfigBase:
         """
         Merge two to three diffusrs pipeline models and save as a new model.
@@ -500,7 +490,7 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         merger = ModelMerger(self.mgr)
         try:
-            self.logger.error('ModelMerger needs to be rewritten.')
+            self.logger.error("ModelMerger needs to be rewritten.")
             result = merger.merge_diffusion_models_and_save(
                 model_keys=model_keys,
                 merged_model_name=merged_model_name,
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 4b7f8c5aae..378e87306d 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -172,11 +172,13 @@ class VaeDiffusersConfig(ModelConfigBase):
 
     model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
 
+
 class ControlNetDiffusersConfig(ModelConfigBase):
     """Model config for ControlNet models (diffusers version)."""
 
     model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
 
+
 class TextualInversionConfig(ModelConfigBase):
     """Model config for textual inversion embeddings."""
 
diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py
index ce6825f8c9..48dab6ce59 100644
--- a/invokeai/backend/model_manager/download/queue.py
+++ b/invokeai/backend/model_manager/download/queue.py
@@ -344,14 +344,12 @@ class DownloadQueue(DownloadQueueBase):
             if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
                 version = match.group(1)
                 resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
-                metadata.thumbnail_url = metadata.thumbnail_url \
-                    or resp["images"][0]["url"]
-                metadata.description = metadata.description \
-                    or (
-                        f"Trigger terms: {(', ').join(resp['trainedWords'])}"
-                        if resp["trainedWords"]
-                        else resp["description"]
-                    )
+                metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
+                metadata.description = metadata.description or (
+                    f"Trigger terms: {(', ').join(resp['trainedWords'])}"
+                    if resp["trainedWords"]
+                    else resp["description"]
+                )
                 metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
 
             # a Civitai model page
@@ -364,10 +362,11 @@ class DownloadQueue(DownloadQueueBase):
 
                 metadata.author = metadata.author or resp["creator"]["username"]
                 metadata.tags = metadata.tags or resp["tags"]
-                metadata.thumbnail_url = metadata.thumbnail_url \
-                    or resp["modelVersions"][0]["images"][0]["url"]
-                metadata.license = metadata.license \
+                metadata.thumbnail_url = metadata.thumbnail_url or resp["modelVersions"][0]["images"][0]["url"]
+                metadata.license = (
+                    metadata.license
                     or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
+                )
         except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
             self._logger.warn(excp)
 
diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py
index 54a854e837..bda585b5ed 100644
--- a/invokeai/backend/model_manager/install.py
+++ b/invokeai/backend/model_manager/install.py
@@ -316,9 +316,8 @@ class ModelInstall(ModelInstallBase):
         """Return the queue."""
         return self._download_queue
 
-    def register_path(self,
-                      model_path: Union[Path, str],
-                      overrides: Optional[Dict[str, Any]] = None
+    def register_path(
+        self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
     ) -> str:  # noqa D102
         model_path = Path(model_path)
         info: ModelProbeInfo = self._probe_model(model_path, overrides)
@@ -354,13 +353,13 @@ class ModelInstall(ModelInstallBase):
         return key
 
     def install_path(
-            self,
-            model_path: Union[Path, str],
-            overrides: Optional[Dict[str, Any]] = None,
+        self,
+        model_path: Union[Path, str],
+        overrides: Optional[Dict[str, Any]] = None,
     ) -> str:  # noqa D102
         model_path = Path(model_path)
         info: ModelProbeInfo = self._probe_model(model_path, overrides)
-        
+
         dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
         dest_path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -375,14 +374,11 @@ class ModelInstall(ModelInstallBase):
             info,
         )
 
-    def _probe_model(self,
-                     model_path: Union[Path, str],
-                     overrides: Optional[Dict[str,Any]] = None
-    ) -> ModelProbeInfo:
+    def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
         info: ModelProbeInfo = ModelProbe.probe(model_path)
         if overrides:  # used to override probe fields
             for key, value in overrides.items():
-                setattr(info, key, value) # may generate a pydantic validation error
+                setattr(info, key, value)  # may generate a pydantic validation error
         return info
 
     def unregister(self, key: str):  # noqa D102
@@ -394,12 +390,12 @@ class ModelInstall(ModelInstallBase):
         self.unregister(key)
 
     def install(
-            self,
-            source: Union[str, Path, AnyHttpUrl],
-            inplace: bool = True,
-            variant: Optional[str] = None,
-            probe_override: Optional[Dict[str, Any]] = None,
-            access_token: Optional[str] = None,
+        self,
+        source: Union[str, Path, AnyHttpUrl],
+        inplace: bool = True,
+        variant: Optional[str] = None,
+        probe_override: Optional[Dict[str, Any]] = None,
+        access_token: Optional[str] = None,
     ) -> DownloadJobBase:  # noqa D102
         queue = self._download_queue
 
diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py
index 6bf7842750..32a3e25fa8 100644
--- a/invokeai/backend/model_manager/loader.py
+++ b/invokeai/backend/model_manager/loader.py
@@ -75,13 +75,10 @@ class ModelLoaderBase(ABC):
         """Return the current logger."""
         pass
 
-    @abstractmethod
-    def collect_cache_stats(
-        self,
-        cache_stats: CacheStats
-    ):
+    def collect_cache_stats(self, cache_stats: CacheStats):
         """Replace cache statistics."""
+        pass
 
     @property
@@ -98,6 +95,7 @@ class ModelLoaderBase(ABC):
         """
         pass
 
+
 class ModelLoader(ModelLoaderBase):
     """Implementation of ModelLoaderBase."""
 
@@ -228,10 +226,7 @@ class ModelLoader(ModelLoaderBase):
             _cache=self._cache,
         )
 
-    def collect_cache_stats(
-        self,
-        cache_stats: CacheStats
-    ):
+    def collect_cache_stats(self, cache_stats: CacheStats):
         self._cache.stats = cache_stats
 
     def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
@@ -275,7 +270,6 @@ class ModelLoader(ModelLoaderBase):
         installed = set()
 
         with Chdir(self._app_config.models_path):
-            self._logger.info("Checking for models that have been moved or deleted from disk.")
             for model_config in self._store.all_models():
                 path = self._resolve_model_path(model_config.path)
 
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index ecd17d027e..5c5baf589c 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -128,7 +128,7 @@ class ModelProbe(ModelProbeBase):
 
         format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
         probe_class = cls.PROBES[format_type].get(model_type)
-        
+
         if not probe_class:
             return None
 
@@ -138,7 +138,7 @@ class ModelProbe(ModelProbeBase):
         variant_type = probe.get_variant_type()
         prediction_type = probe.get_scheduler_prediction_type()
         format = probe.get_format()
-        
+
         model_info = ModelProbeInfo(
             model_type=model_type,
             base_type=base_type,
diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py
index 035e07fb99..b16a126540 100644
--- a/invokeai/backend/model_manager/storage/base.py
+++ b/invokeai/backend/model_manager/storage/base.py
@@ -11,6 +11,7 @@ from ..config import ModelConfigBase, BaseModelType, ModelType
 # should match the InvokeAI version when this is first released.
 CONFIG_FILE_VERSION = "3.1.1"
 
+
 class DuplicateModelException(Exception):
     """Raised on an attempt to add a model with the same key twice."""
 
@@ -122,4 +123,3 @@ class ModelConfigStore(ABC):
         Return all the model configs in the database.
         """
         return self.search_by_name()
-
diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py
index db66cbf864..35575ef3c0 100644
--- a/invokeai/backend/model_manager/storage/sql.py
+++ b/invokeai/backend/model_manager/storage/sql.py
@@ -89,8 +89,9 @@ class ModelConfigStoreSQL(ModelConfigStore):
             self._conn.commit()
         finally:
             self._lock.release()
-        assert self.version == CONFIG_FILE_VERSION, \
-            f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
+        assert (
+            self.version == CONFIG_FILE_VERSION
+        ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
 
     def _create_tables(self) -> None:
         """Create sqlite3 tables."""
@@ -182,9 +183,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
             )
             VALUES (?,?);
             """,
-            ("version",CONFIG_FILE_VERSION),
+            ("version", CONFIG_FILE_VERSION),
         )
-
 
     def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
         """
@@ -256,7 +256,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
         finally:
             self._lock.release()
 
-
     def _update_tags(self, key: str, tags: List[str]) -> None:
         """Update tags for model with key."""
         # remove previous tags from this model
diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py
index f36ba005e1..acc6a55adc 100644
--- a/invokeai/backend/model_manager/storage/yaml.py
+++ b/invokeai/backend/model_manager/storage/yaml.py
@@ -78,8 +78,9 @@ class ModelConfigStoreYAML(ModelConfigStore):
         if not self._filename.exists():
             self._initialize_yaml()
         self._config = OmegaConf.load(self._filename)
-        assert self.version == CONFIG_FILE_VERSION, \
-            f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
+        assert (
+            self.version == CONFIG_FILE_VERSION
+        ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
 
     def _initialize_yaml(self):
         try:
@@ -104,7 +105,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
     @property
     def version(self) -> str:
         """Return version of this config file/database."""
-        return self._config["__metadata__"].get('version')
+        return self._config["__metadata__"].get("version")
 
     def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
         """