From 9063b1ae618a7ea1188a166eb7eb24dd9f64913d Mon Sep 17 00:00:00 2001 From: maryhipp Date: Thu, 7 Mar 2024 14:23:33 -0500 Subject: [PATCH] on model manager start, look to see if yaml needs to be migrated and do it if so --- .../model_install/model_install_default.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 7ed7f651ef..4289c6e946 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -10,6 +10,7 @@ from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp from typing import Any, Dict, List, Optional, Set, Union +from omegaconf import DictConfig, OmegaConf from huggingface_hub import HfFolder from pydantic.networks import AnyHttpUrl @@ -115,6 +116,7 @@ class ModelInstallService(ModelInstallServiceBase): raise Exception("Attempt to start the installer service twice") self._start_installer_thread() self._remove_dangling_install_dirs() + self._migrate_yaml() self.sync_to_config() def stop(self, invoker: Optional[Invoker] = None) -> None: @@ -183,6 +185,7 @@ class ModelInstallService(ModelInstallServiceBase): access_token: Optional[str] = None, inplace: Optional[bool] = False, ) -> ModelInstallJob: + print(f"starting import for model {source}") variants = "|".join(ModelRepoVariant.__members__.values()) hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" source_obj: Optional[StringLikeSource] = None @@ -287,6 +290,31 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"{len(installed)} new models registered") self._logger.info("Model installer (re)initialized") + def _migrate_yaml(self) -> None: + db_models = self.record_store.all_models() + try: + yaml = self._get_yaml() + except OSError: + return + if len(db_models) == 0 and len(yaml.items()) != 0: + + self._logger.info("No models in DB, yaml items need to be migrated") + + for model_key, stanza in yaml.items(): + if model_key == "__metadata__": + assert ( + stanza["version"] == "3.0.0" + ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version" + continue + + base_type, model_type, model_name = str(model_key).split("/") + model_path = stanza["path"] + description = stanza["description"] + model_info = {"name": model_name, "description":description } + + self.heuristic_import(source=model_path, config=model_info, inplace=True) + + def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} callback = self._scan_install if install else self._scan_register @@ -559,6 +587,16 @@ class ModelInstallService(ModelInstallServiceBase): self._next_job_id += 1 return id + # -------------------------------------------------------------------------------------------- + # Internal functions that manage the old yaml config + # -------------------------------------------------------------------------------------------- + def _get_yaml(self) -> DictConfig: + """Fetch the models.yaml DictConfig for this installation.""" + yaml_path = self._app_config.model_conf_path + omegaconf = OmegaConf.load(yaml_path) + assert isinstance(omegaconf, DictConfig) + return omegaconf + @staticmethod def _guess_variant() -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download."""