on model manager start, look to see if yaml needs to be migrated and do it if so

This commit is contained in:
maryhipp 2024-03-07 14:23:33 -05:00 committed by Mary Hipp Rogers
parent 6aae88bd88
commit 9063b1ae61

View File

@ -10,6 +10,7 @@ from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
from omegaconf import DictConfig, OmegaConf
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
@ -115,6 +116,7 @@ class ModelInstallService(ModelInstallServiceBase):
raise Exception("Attempt to start the installer service twice") raise Exception("Attempt to start the installer service twice")
self._start_installer_thread() self._start_installer_thread()
self._remove_dangling_install_dirs() self._remove_dangling_install_dirs()
self._migrate_yaml()
self.sync_to_config() self.sync_to_config()
def stop(self, invoker: Optional[Invoker] = None) -> None: def stop(self, invoker: Optional[Invoker] = None) -> None:
@ -183,6 +185,7 @@ class ModelInstallService(ModelInstallServiceBase):
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
) -> ModelInstallJob: ) -> ModelInstallJob:
print(f"starting import for model {source}")
variants = "|".join(ModelRepoVariant.__members__.values()) variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None source_obj: Optional[StringLikeSource] = None
@ -287,6 +290,31 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{len(installed)} new models registered") self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized") self._logger.info("Model installer (re)initialized")
def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models()
try:
yaml = self._get_yaml()
except OSError:
return
if len(db_models) == 0 and len(yaml.items()) != 0:
self._logger.info("No models in DB, yaml items need to be migrated")
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
model_path = stanza["path"]
description = stanza["description"]
model_info = {"name": model_name, "description":description }
self.heuristic_import(source=model_path, config=model_info, inplace=True)
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register
@ -559,6 +587,16 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1 self._next_job_id += 1
return id return id
# --------------------------------------------------------------------------------------------
# Internal functions that manage the old yaml config
# --------------------------------------------------------------------------------------------
def _get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self._app_config.model_conf_path
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
@staticmethod @staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]: def _guess_variant() -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download.""" """Guess the best HuggingFace variant type to download."""