fix(config): use yaml module instead of omegaconf when migrating models.yaml

Also use new paths.
This commit is contained in:
psychedelicious 2024-03-11 22:52:58 +11:00
parent ebd0cb6113
commit 4df28f1de6

View File

@ -11,8 +11,8 @@ from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
import yaml
from huggingface_hub import HfFolder
from omegaconf import DictConfig, OmegaConf
from pydantic.networks import AnyHttpUrl
from requests import Session
@ -287,12 +287,13 @@ class ModelInstallService(ModelInstallServiceBase):
def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models()
legacy_models_yaml_path = self._app_config.root_path / "configs" / "models.yaml"
try:
yaml = self._get_yaml()
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
except OSError:
return
yaml_metadata = yaml.pop("__metadata__")
yaml_metadata = legacy_models_yaml.pop("__metadata__")
yaml_version = yaml_metadata.get("version")
if yaml_version != "3.0.0":
@ -301,11 +302,11 @@ class ModelInstallService(ModelInstallServiceBase):
)
self._logger.info(
f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes."
f"Starting one-time migration of {len(legacy_models_yaml.items())} models from `models.yaml` to database. This may take a few minutes."
)
if len(db_models) == 0 and len(yaml.items()) != 0:
for model_key, stanza in yaml.items():
if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
for model_key, stanza in legacy_models_yaml.items():
_, _, model_name = str(model_key).split("/")
model_path = Path(stanza["path"])
if not model_path.is_absolute():
@ -324,8 +325,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
yaml_path = self._app_config.model_conf_path
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
@ -602,16 +602,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1
return id
# --------------------------------------------------------------------------------------------
# Internal functions that manage the old yaml config
# --------------------------------------------------------------------------------------------
def _get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self._app_config.model_conf_path
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
@staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""