fix(config): use yaml module instead of omegaconf when migrating models.yaml

Also use new paths.
This commit is contained in:
psychedelicious 2024-03-11 22:52:58 +11:00
parent ebd0cb6113
commit 4df28f1de6

View File

@ -11,8 +11,8 @@ from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
import yaml
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
from omegaconf import DictConfig, OmegaConf
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import Session from requests import Session
@ -287,12 +287,13 @@ class ModelInstallService(ModelInstallServiceBase):
def _migrate_yaml(self) -> None: def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models() db_models = self.record_store.all_models()
legacy_models_yaml_path = self._app_config.root_path / "configs" / "models.yaml"
try: try:
yaml = self._get_yaml() legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
except OSError: except OSError:
return return
yaml_metadata = yaml.pop("__metadata__") yaml_metadata = legacy_models_yaml.pop("__metadata__")
yaml_version = yaml_metadata.get("version") yaml_version = yaml_metadata.get("version")
if yaml_version != "3.0.0": if yaml_version != "3.0.0":
@ -301,11 +302,11 @@ class ModelInstallService(ModelInstallServiceBase):
) )
self._logger.info( self._logger.info(
f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes." f"Starting one-time migration of {len(legacy_models_yaml.items())} models from `models.yaml` to database. This may take a few minutes."
) )
if len(db_models) == 0 and len(yaml.items()) != 0: if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
for model_key, stanza in yaml.items(): for model_key, stanza in legacy_models_yaml.items():
_, _, model_name = str(model_key).split("/") _, _, model_name = str(model_key).split("/")
model_path = Path(stanza["path"]) model_path = Path(stanza["path"])
if not model_path.is_absolute(): if not model_path.is_absolute():
@ -324,8 +325,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.warning(f"Model at {model_path} could not be migrated: {e}") self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
yaml_path = self._app_config.model_conf_path legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
@ -602,16 +602,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1 self._next_job_id += 1
return id return id
# --------------------------------------------------------------------------------------------
# Internal functions that manage the old yaml config
# --------------------------------------------------------------------------------------------
def _get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self._app_config.model_conf_path
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
@staticmethod @staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]: def _guess_variant() -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download.""" """Guess the best HuggingFace variant type to download."""