This commit is contained in:
Brandon Rising 2024-03-07 17:20:57 -05:00 committed by Mary Hipp Rogers
parent 42d606f07c
commit a3dfa161a8

View File

@ -1,6 +1,7 @@
"""Model installation class.""" """Model installation class."""
import os import os
import pathlib
import re import re
import threading import threading
import time import time
@ -10,9 +11,9 @@ from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
from omegaconf import DictConfig, OmegaConf
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
from omegaconf import DictConfig, OmegaConf
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import Session from requests import Session
@ -297,7 +298,6 @@ class ModelInstallService(ModelInstallServiceBase):
except OSError: except OSError:
return return
if len(db_models) == 0 and len(yaml.items()) != 0: if len(db_models) == 0 and len(yaml.items()) != 0:
self._logger.info("No models in DB, yaml items need to be migrated") self._logger.info("No models in DB, yaml items need to be migrated")
for model_key, stanza in yaml.items(): for model_key, stanza in yaml.items():
@ -308,20 +308,18 @@ class ModelInstallService(ModelInstallServiceBase):
continue continue
_, _, model_name = str(model_key).split("/") _, _, model_name = str(model_key).split("/")
import pathlib
model_path = pathlib.Path(stanza["path"]) model_path = pathlib.Path(stanza["path"])
if not model_path.is_absolute(): if not model_path.is_absolute():
model_path = (self._app_config.models_path / model_path) model_path = self._app_config.models_path / model_path
model_path.resolve() model_path = model_path.resolve()
description = stanza["description"] description = stanza["description"]
model_info = {"name": model_name, "description":description } model_info = {"name": model_name, "description": description}
try: try:
self.register_path(model_path=model_path, config=model_info) self.register_path(model_path=model_path, config=model_info)
except Exception as e: except Exception as e:
self._logger.warning(f"Model at {model_path} could not be loaded into database: {e}") self._logger.warning(f"Model at {model_path} could not be loaded into database: {e}")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register