mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
update_model and delete_model working; convert is WIP
This commit is contained in:
@@ -2,6 +2,7 @@
import pathlib
import traceback
from typing import List, Literal, Optional, Union

from fastapi import Body, Path, Query, Response
@@ -23,6 +24,10 @@ from ..dependencies import ApiDependencies

models_router = APIRouter(prefix="/v1/models", tags=["models"])

# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
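The round trip through dict() and parse_obj_as() that the NOTE above refers to can be sketched roughly as follows; the two config classes here are simplified stand-ins rather than the actual InvokeAI classes, and pydantic v1 is assumed:

from typing import Literal, Union
from pydantic import BaseModel, parse_obj_as

class DiffusersConfig(BaseModel):        # stand-in for a generated config class
    name: str
    format: Literal["diffusers"] = "diffusers"

class CheckpointConfig(BaseModel):       # stand-in for a generated config class
    name: str
    format: Literal["checkpoint"] = "checkpoint"

ExampleResponse = Union[DiffusersConfig, CheckpointConfig]

raw = CheckpointConfig(name="sd-1.5").dict()   # backend object flattened to a plain dict
parsed = parse_obj_as(ExampleResponse, raw)    # re-validated as the Union the route declares
assert isinstance(parsed, CheckpointConfig)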
@@ -56,7 +61,7 @@ async def list_models(

@models_router.patch(
    "/{base_model}/{model_type}/{model_name}",
    "/{key}",
    operation_id="update_model",
    responses={
        200: {"description": "The model was updated successfully"},
@@ -68,58 +73,17 @@ async def list_models(
    response_model=UpdateModelResponse,
)
async def update_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
    key: str = Path(description="Unique key of model"),
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    logger = ApiDependencies.invoker.services.logger

    try:
        previous_info = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        )

        # rename operation requested
        if info.model_name != model_name or info.base_model != base_model:
            ApiDependencies.invoker.services.model_manager.rename_model(
                base_model=base_model,
                model_type=model_type,
                model_name=model_name,
                new_name=info.model_name,
                new_base=info.base_model,
            )
            logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
            # update information to support an update of attributes
            model_name = info.model_name
            base_model = info.base_model
            new_info = ApiDependencies.invoker.services.model_manager.list_model(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
            )
            if new_info.get("path") != previous_info.get(
                "path"
            ):  # model manager moved model path during rename - don't overwrite it
                info.path = new_info.get("path")

        # replace empty string values with None/null to avoid phenomenon of vae: ''
        info_dict = info.dict()
        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}

        ApiDependencies.invoker.services.model_manager.update_model(
            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
        )

        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        )
        model_response = parse_obj_as(UpdateModelResponse, model_raw)
        new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
        model_response = parse_obj_as(UpdateModelResponse, new_config.dict())
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))
    except ValueError as e:
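A rough client-side sketch of the reworked key-based PATCH route; the base URL, model key and config fields below are placeholders rather than values taken from this commit:

import requests

key = "8e2c7795"  # hypothetical unique model key
info = {"name": "my-model", "description": "updated via the API"}  # illustrative config fields only

resp = requests.patch(f"http://127.0.0.1:9090/api/v1/models/{key}", json=info, timeout=30)
resp.raise_for_status()
print(resp.json())  # updated config, re-parsed server-side via parse_obj_as(UpdateModelResponse, ...)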
@@ -221,25 +185,25 @@ async def add_model(


@models_router.delete(
    "/{base_model}/{model_type}/{model_name}",
    "/{key}",
    operation_id="del_model",
    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
    status_code=204,
    response_model=None,
)
async def delete_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    key: str = Path(description="Unique key of model to remove from model registry."),
    delete_files: Optional[bool] = Query(
        description="Delete underlying files and directories as well.",
        default=False
    )
) -> Response:
    """Delete Model"""
    logger = ApiDependencies.invoker.services.logger

    try:
        ApiDependencies.invoker.services.model_manager.del_model(
            model_name, base_model=base_model, model_type=model_type
        )
        logger.info(f"Deleted model: {model_name}")
        ApiDependencies.invoker.services.model_manager.del_model(key, delete_files=delete_files)
        logger.info(f"Deleted model: {key}")
        return Response(status_code=204)
    except UnknownModelException as e:
        logger.error(str(e))
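Deleting by key, with the optional delete_files flag, might look like this from a client (host and key are placeholders):

import requests

key = "8e2c7795"  # hypothetical unique model key
resp = requests.delete(
    f"http://127.0.0.1:9090/api/v1/models/{key}",
    params={"delete_files": "true"},  # also remove the weights file or diffusers directory on disk
    timeout=30,
)
assert resp.status_code == 204  # success returns no body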
@@ -247,7 +211,7 @@ async def delete_model(


@models_router.put(
    "/convert/{base_model}/{model_type}/{model_name}",
    "/convert/{key}",
    operation_id="convert_model",
    responses={
        200: {"description": "Model converted successfully"},
@@ -258,27 +222,19 @@ async def delete_model(
    response_model=ConvertModelResponse,
)
async def convert_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    key: str = Path(description="Unique key of model to remove from model registry."),
    convert_dest_directory: Optional[str] = Query(
        default=None, description="Save the converted model to the designated directory"
    ),
) -> ConvertModelResponse:
    """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
    logger = ApiDependencies.invoker.services.logger
    info = ApiDependencies.invoker.services.model_manager.model_info(key)
    try:
        logger.info(f"Converting model: {model_name}")
        logger.info(f"Converting model: {info.name}")
        dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
        ApiDependencies.invoker.services.model_manager.convert_model(
            model_name,
            base_model=base_model,
            model_type=model_type,
            convert_dest_directory=dest,
        )
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            model_name, base_model=base_model, model_type=model_type
        )
        ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest)
        model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict()
        response = parse_obj_as(ConvertModelResponse, model_raw)
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
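Triggering a checkpoint-to-diffusers conversion through the new key-based route could be sketched like this (host, key and destination directory are placeholders):

import requests

key = "8e2c7795"  # hypothetical unique model key
resp = requests.put(
    f"http://127.0.0.1:9090/api/v1/models/convert/{key}",
    params={"convert_dest_directory": "/data/converted"},
    timeout=600,  # conversion can take a while on large checkpoints
)
resp.raise_for_status()
print(resp.json())  # config of the now diffusers-format model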
@@ -370,7 +370,9 @@ class ModelManagerService(ModelManagerServiceBase):
        self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
    ) -> ModelInstallJob:
        """
        Add a model using its path, with a dictionary of attributes. Will fail with an
        Add a model using its path, with a dictionary of attributes.

        Will fail with an
        assertion error if the name already exists.
        """
        self.logger.debug(f"add/update model {model_path}")
@@ -385,8 +387,11 @@ class ModelManagerService(ModelManagerServiceBase):
        model_attributes: Optional[Dict[str, Any]] = None,
    ) -> ModelInstallJob:
        """
        Add a model using its path, with a dictionary of attributes. Will fail with an
        assertion error if the name already exists.
        Add a model using a path, repo_id or URL.

        :param model_attributes: Dictionary of ModelConfigBase fields to
        attach to the model. When installing a URL or repo_id, some metadata
        values, such as `tags` will be automagically added.
        """
        self.logger.debug(f"add/update model {source}")
        variant = "fp16" if self._loader.precision == "float16" else None
@@ -402,17 +407,18 @@ class ModelManagerService(ModelManagerServiceBase):
        new_config: Union[dict, ModelConfigBase],
    ) -> ModelConfigBase:
        """
        Update the named model with a dictionary of attributes. Will fail with a
        Update the named model with a dictionary of attributes.

        Will fail with a
        UnknownModelException if the name does not already exist.

        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or
        the model key is unknown.
        """
        model_info = self.model_info(key)
        self.logger.debug(f"update model {model_info.name}")
        self.logger.warning("TO DO: write code to move models around if base or type change")
        return self._loader.store.update_model(key, new_config)
        self.logger.debug(f"update model {key}")
        new_info = self._loader.store.update_model(key, new_config)
        return self._loader.installer.sync_model_path(new_info.key)

    def del_model(
        self,
@@ -420,7 +426,9 @@ class ModelManagerService(ModelManagerServiceBase):
        delete_files: bool = False,
    ):
        """
        Delete the named model from configuration. If delete_files is true,
        Delete the named model from configuration.

        If delete_files is true,
        then the underlying weight file or diffusers directory will be deleted
        as well.
        """
@@ -428,7 +436,7 @@ class ModelManagerService(ModelManagerServiceBase):
        self.logger.debug(f"delete model {model_info.name}")
        self._loader.store.del_model(key)
        if delete_files and Path(model_info.path).exists():
            path = Path(model_info)
            path = Path(model_info.path)
            if path.is_dir():
                shutil.rmtree(path)
            else:
@@ -52,7 +52,7 @@ import re
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from shutil import rmtree, move
from typing import Any, Callable, Dict, List, Optional, Set, Union

from pydantic import Field
@@ -61,7 +61,15 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType
from .config import (
    ModelConfigBase,
    BaseModelType,
    ModelFormat,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SubModelType,
)
from .download import (
    HTTP_RE,
    REPO_ID_RE,
@@ -111,8 +119,8 @@ class ModelInstallBase(ABC):
    @abstractmethod
    def __init__(
        self,
        store: Optional[ModelConfigStore] = None,
        config: Optional[InvokeAIAppConfig] = None,
        store: Optional[ModelConfigStore] = None,
        logger: Optional[InvokeAILogger] = None,
        download: Optional[DownloadQueueBase] = None,
        event_handlers: Optional[List[DownloadEventHandler]] = None,
@@ -271,6 +279,20 @@ class ModelInstallBase(ABC):
        """
        pass

    @abstractmethod
    def sync_model_path(self, key) -> Path:
        """
        Move model into the location indicated by its basetype, type and name.

        Call this after updating a model's attributes in order to move
        the model's path into the location indicated by its basetype, type and
        name. Applies only to models whose paths are within the root `models_dir`
        directory.

        May raise an UnknownModelException.
        """
        pass


class ModelInstall(ModelInstallBase):
    """Model installer class handles installation from a local path."""
@@ -372,18 +394,24 @@ class ModelInstall(ModelInstallBase):
        info: ModelProbeInfo = self._probe_model(model_path, overrides)

        dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        return self._register(
            self._move_model(model_path, dest_path),
            info,
        )

    def _move_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path

        new_path.parent.mkdir(parents=True, exist_ok=True)

        # if path already exists then we jigger the name to make it unique
        counter: int = 1
        while dest_path.exists():
            dest_path = dest_path.with_stem(dest_path.stem + f"_{counter:02d}")
        while new_path.exists():
            new_path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
            counter += 1
        return old_path.replace(new_path)

        return self._register(
            model_path.replace(dest_path),
            info,
        )

    def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
        info: ModelProbeInfo = ModelProbe.probe(model_path)
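The collision handling in _move_model above boils down to suffixing the stem with _01, _02, ... until the destination is free. A self-contained sketch of that pattern (Python 3.9+ for with_stem), using a throwaway temp directory purely for illustration:

import tempfile
from pathlib import Path

def move_unique(old_path: Path, new_path: Path) -> Path:
    if old_path == new_path:
        return old_path
    new_path.parent.mkdir(parents=True, exist_ok=True)
    counter = 1
    while new_path.exists():  # destination taken: jigger the name
        new_path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
        counter += 1
    return old_path.replace(new_path)  # rename/move within the same filesystem

tmp = Path(tempfile.mkdtemp())
(tmp / "model.safetensors").write_text("weights")
(tmp / "dest").mkdir()
(tmp / "dest" / "model.safetensors").write_text("already here")  # forces the rename
print(move_unique(tmp / "model.safetensors", tmp / "dest" / "model.safetensors"))
# -> .../dest/model_01.safetensors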
@@ -436,7 +464,7 @@ class ModelInstall(ModelInstallBase):
            info.license = metadata.license
            info.thumbnail_url = metadata.thumbnail_url
            self._store.update_model(model_id, info)
            self._async_installs[job.source] = model_id
            self._async_installs[info.source] = model_id
            job.model_key = model_id
        elif job.status == "error":
            self._logger.warning(f"{job.source}: Model installation error: {job.error}")
@@ -455,13 +483,36 @@ class ModelInstall(ModelInstallBase):
            info.source = str(job.source)
            info.description = f"Imported model {info.name}"
            self._store.update_model(model_id, info)
            self._async_installs[job.source] = model_id
            self._async_installs[info.source] = model_id
            job.model_key = model_id
        elif job.status == "error":
            self._logger.warning(f"{job.source}: Model installation error: {job.error}")
        elif job.status == "cancelled":
            self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")

    def sync_model_path(self, key) -> ModelConfigBase:
        """
        Move model into the location indicated by its basetype, type and name.

        Call this after updating a model's attributes in order to move
        the model's path into the location indicated by its basetype, type and
        name. Applies only to models whose paths are within the root `models_dir`
        directory.

        May raise an UnknownModelException.
        """
        model = self._store.get_model(key)
        old_path = Path(model.path)
        models_dir = self._config.models_path

        if not old_path.is_relative_to(models_dir):
            return old_path

        new_path = models_dir / model.base_model.value / model.model_type.value / model.name
        model.path = self._move_model(old_path, new_path).as_posix()
        self._store.update_model(key, model)
        return model

    def _make_download_job(
        self,
        source: Union[str, Path, AnyHttpUrl],
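The is_relative_to() guard above is what restricts path syncing (and, later, checkpoint cleanup) to models that already live under the configured models directory. A quick illustration with placeholder paths (Python 3.9+):

from pathlib import Path

models_dir = Path("/data/invokeai/models")          # placeholder root
inside = models_dir / "sd-1" / "main" / "my-model"
outside = Path("/mnt/external/my-model")

print(inside.is_relative_to(models_dir))   # True  -> eligible to be moved/cleaned up
print(outside.is_relative_to(models_dir))  # False -> left untouched where it is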
@@ -487,7 +538,7 @@ class ModelInstall(ModelInstallBase):
            cls = ModelInstallURLJob
            kwargs = {}
        else:
            raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
            raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
        return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)

    def wait_for_installs(self) -> Dict[str, str]:  # noqa D102
@@ -515,6 +566,64 @@ class ModelInstall(ModelInstallBase):
    def hash(self, model_path: Union[Path, str]) -> str:  # noqa D102
        return FastModelHash.hash(model_path)

    def convert_model(
        self,
        key: str,
        dest_directory: Optional[Path] = None,
    ) -> ModelConfigBase:
        """
        Convert a checkpoint file into a diffusers folder, deleting the cached
        version and deleting the original checkpoint file if it is in the models
        directory.

        :param key: Unique key of model.
        :param dest_directory: Optional place to put converted file. If not specified,
        will be stored in the `models_dir`.

        This will raise a ValueError unless the model is a checkpoint.
        This will raise an UnknownModelException if key is unknown.
        """
        from .loader import ModelInfo, ModelLoader  # to avoid circular imports

        try:
            info: ModelConfigBase = self._store.get_model(key)

            print(f'DEBUG: requested_model={info}')

            if info.model_format != "checkpoint":
                raise ValueError(f"not a checkpoint format model: {info.name}")

            # We are taking advantage of a side effect of get_model() that converts checkpoints
            # into cached diffusers directories stored at `path`. It doesn't matter
            # what submodel type we request here, so we get the smallest.
            loader = ModelLoader(self._config)
            submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
            converted_model: ModelInfo = loader.get_model(key, **submodel)

            checkpoint_path = loader.resolve_model_path(info.path)
            old_diffusers_path = loader.resolve_model_path(converted_model.location)
            new_diffusers_path = None
            if dest_directory:
                new_diffusers_path = Path(dest_directory) / info.name
                if new_diffusers_path.exists():
                    raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
                move(old_diffusers_path, new_diffusers_path)
                info.path = new_diffusers_path.as_posix()

            info.pop("config")
            info.model_format = "diffusers"
            self._store.update_model(key, info)
            result = self.sync_model_path(key)
        except Exception:
            # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
            if new_diffusers_path:
                rmtree(new_diffusers_path)
            raise

        if checkpoint_path.exists() and checkpoint_path.is_relative_to(self._config.models_path):
            checkpoint_path.unlink()

        return result

    # the following two methods are callbacks to the ModelSearch object
    def _scan_register(self, model: Path) -> bool:
        try:
@@ -91,8 +91,12 @@ class ModelLoaderBase(ABC):
    @abstractmethod
    def collect_cache_stats(self, cache_stats: CacheStats):
        """Replace cache statistics."""
        pass

        pass
    @abstractmethod
    def resolve_model_path(self, path: Union[Path, str]) -> Path:
        """Turn a potentially relative path into an absolute one in the models_dir."""
        pass

    @property
    @abstractmethod
@@ -214,6 +218,9 @@ class ModelLoader(ModelLoaderBase):
        the model to retrieve (e.g. ModelType.Vae)
        """
        model_config = self.store.get_model(key)  # May raise a UnknownModelException
        if model_config.model_type == 'main' and not submodel_type:
            raise InvalidModelException('submodel_type is required when loading a main model')

        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)

        if is_submodel_override:
@@ -258,17 +265,17 @@ class ModelLoader(ModelLoaderBase):
    def collect_cache_stats(self, cache_stats: CacheStats):
        self._cache.stats = cache_stats

    def resolve_model_path(self, path: Union[Path, str]) -> Path:
        """Turn a potentially relative path into an absolute one in the models_dir."""
        return self._app_config.models_path / path

    def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
        """Get the concrete implementation class for a specific model type."""
        model_class = MODEL_CLASSES[base_model][model_type]
        return model_class

    def _get_model_cache_path(self, model_path):
        return self._resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())

    def _resolve_model_path(self, path: Union[Path, str]) -> Path:
        """Return relative paths based on configured models_path."""
        return self._app_config.models_path / path
        return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())

    def _get_model_path(
        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
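The cache-path helper above derives a stable directory name from the model's stored path: converted copies land under models_dir/.cache/<md5 of the path>. A small illustration with a placeholder models_dir:

import hashlib
from pathlib import Path

models_dir = Path("/data/invokeai/models")        # placeholder root
model_path = "sd-1/main/my-model.safetensors"     # path as stored in the model config

cache_key = hashlib.md5(str(model_path).encode()).hexdigest()
print(models_dir / ".cache" / cache_key)          # deterministic cache location for this model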
@@ -287,7 +294,7 @@ class ModelLoader(ModelLoaderBase):
            model_path = getattr(model_config, submodel_type)
            is_submodel_override = True

        model_path = self._resolve_model_path(model_path)
        model_path = self.resolve_model_path(model_path)
        return model_path, is_submodel_override

    def sync_to_config(self):
@@ -301,7 +308,7 @@ class ModelLoader(ModelLoaderBase):
        with Chdir(self._app_config.models_path):
            self._logger.info("Checking for models that have been moved or deleted from disk.")
            for model_config in self._store.all_models():
                path = self._resolve_model_path(model_config.path)
                path = self.resolve_model_path(model_config.path)
                if not path.exists():
                    self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering.")
                    defunct_models.add(model_config.key)
@@ -311,6 +318,6 @@ class ModelLoader(ModelLoaderBase):
            self._logger.info(f"Scanning {self._app_config.models_path} for new models")
            for cur_base_model in BaseModelType:
                for cur_model_type in ModelType:
                    models_dir = self._resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
                    models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
                    installed.update(self._installer.scan_directory(models_dir))
            self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
@@ -163,6 +163,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
            self._commit()
        finally:
            self._lock.release()
        return self.get_model(key)

    def get_model(self, key: str) -> ModelConfigBase:
        """