mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
on model manager start, look to see if yaml needs to be migrated and do it if so
This commit is contained in:
parent
6aae88bd88
commit
9063b1ae61
@ -10,6 +10,7 @@ from queue import Empty, Queue
|
|||||||
from shutil import copyfile, copytree, move, rmtree
|
from shutil import copyfile, copytree, move, rmtree
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
@ -115,6 +116,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
raise Exception("Attempt to start the installer service twice")
|
raise Exception("Attempt to start the installer service twice")
|
||||||
self._start_installer_thread()
|
self._start_installer_thread()
|
||||||
self._remove_dangling_install_dirs()
|
self._remove_dangling_install_dirs()
|
||||||
|
self._migrate_yaml()
|
||||||
self.sync_to_config()
|
self.sync_to_config()
|
||||||
|
|
||||||
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
||||||
@ -183,6 +185,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
inplace: Optional[bool] = False,
|
inplace: Optional[bool] = False,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
|
print(f"starting import for model {source}")
|
||||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||||
source_obj: Optional[StringLikeSource] = None
|
source_obj: Optional[StringLikeSource] = None
|
||||||
@ -287,6 +290,31 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.info(f"{len(installed)} new models registered")
|
self._logger.info(f"{len(installed)} new models registered")
|
||||||
self._logger.info("Model installer (re)initialized")
|
self._logger.info("Model installer (re)initialized")
|
||||||
|
|
||||||
|
def _migrate_yaml(self) -> None:
|
||||||
|
db_models = self.record_store.all_models()
|
||||||
|
try:
|
||||||
|
yaml = self._get_yaml()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
if len(db_models) == 0 and len(yaml.items()) != 0:
|
||||||
|
|
||||||
|
self._logger.info("No models in DB, yaml items need to be migrated")
|
||||||
|
|
||||||
|
for model_key, stanza in yaml.items():
|
||||||
|
if model_key == "__metadata__":
|
||||||
|
assert (
|
||||||
|
stanza["version"] == "3.0.0"
|
||||||
|
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||||
|
continue
|
||||||
|
|
||||||
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
|
model_path = stanza["path"]
|
||||||
|
description = stanza["description"]
|
||||||
|
model_info = {"name": model_name, "description":description }
|
||||||
|
|
||||||
|
self.heuristic_import(source=model_path, config=model_info, inplace=True)
|
||||||
|
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||||
callback = self._scan_install if install else self._scan_register
|
callback = self._scan_install if install else self._scan_register
|
||||||
@ -559,6 +587,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._next_job_id += 1
|
self._next_job_id += 1
|
||||||
return id
|
return id
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------------------------
|
||||||
|
# Internal functions that manage the old yaml config
|
||||||
|
# --------------------------------------------------------------------------------------------
|
||||||
|
def _get_yaml(self) -> DictConfig:
|
||||||
|
"""Fetch the models.yaml DictConfig for this installation."""
|
||||||
|
yaml_path = self._app_config.model_conf_path
|
||||||
|
omegaconf = OmegaConf.load(yaml_path)
|
||||||
|
assert isinstance(omegaconf, DictConfig)
|
||||||
|
return omegaconf
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||||
"""Guess the best HuggingFace variant type to download."""
|
"""Guess the best HuggingFace variant type to download."""
|
||||||
|
Loading…
Reference in New Issue
Block a user