mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(config): use yaml module instead of omegaconf when migrating models.yaml
Also use new paths.
This commit is contained in:
parent
ebd0cb6113
commit
4df28f1de6
@ -11,8 +11,8 @@ from shutil import copyfile, copytree, move, rmtree
|
|||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
|
||||||
@ -287,12 +287,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _migrate_yaml(self) -> None:
|
def _migrate_yaml(self) -> None:
|
||||||
db_models = self.record_store.all_models()
|
db_models = self.record_store.all_models()
|
||||||
|
legacy_models_yaml_path = self._app_config.root_path / "configs" / "models.yaml"
|
||||||
try:
|
try:
|
||||||
yaml = self._get_yaml()
|
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
|
||||||
except OSError:
|
except OSError:
|
||||||
return
|
return
|
||||||
|
|
||||||
yaml_metadata = yaml.pop("__metadata__")
|
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
||||||
yaml_version = yaml_metadata.get("version")
|
yaml_version = yaml_metadata.get("version")
|
||||||
|
|
||||||
if yaml_version != "3.0.0":
|
if yaml_version != "3.0.0":
|
||||||
@ -301,11 +302,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes."
|
f"Starting one-time migration of {len(legacy_models_yaml.items())} models from `models.yaml` to database. This may take a few minutes."
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(db_models) == 0 and len(yaml.items()) != 0:
|
if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
|
||||||
for model_key, stanza in yaml.items():
|
for model_key, stanza in legacy_models_yaml.items():
|
||||||
_, _, model_name = str(model_key).split("/")
|
_, _, model_name = str(model_key).split("/")
|
||||||
model_path = Path(stanza["path"])
|
model_path = Path(stanza["path"])
|
||||||
if not model_path.is_absolute():
|
if not model_path.is_absolute():
|
||||||
@ -324,8 +325,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||||
|
|
||||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||||
yaml_path = self._app_config.model_conf_path
|
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||||
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
|
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
||||||
@ -602,16 +602,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._next_job_id += 1
|
self._next_job_id += 1
|
||||||
return id
|
return id
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
|
||||||
# Internal functions that manage the old yaml config
|
|
||||||
# --------------------------------------------------------------------------------------------
|
|
||||||
def _get_yaml(self) -> DictConfig:
|
|
||||||
"""Fetch the models.yaml DictConfig for this installation."""
|
|
||||||
yaml_path = self._app_config.model_conf_path
|
|
||||||
omegaconf = OmegaConf.load(yaml_path)
|
|
||||||
assert isinstance(omegaconf, DictConfig)
|
|
||||||
return omegaconf
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||||
"""Guess the best HuggingFace variant type to download."""
|
"""Guess the best HuggingFace variant type to download."""
|
||||||
|
Loading…
Reference in New Issue
Block a user