update_model and delete_model working; convert is WIP

This commit is contained in:
Lincoln Stein
2023-09-16 12:22:23 -04:00
parent db7fdc3555
commit c090c5f907
5 changed files with 181 additions and 100 deletions

View File

@ -2,6 +2,7 @@
import pathlib
import traceback
from typing import List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
@ -23,6 +24,10 @@ from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
# such as "MainCheckpointConfig" are repackaged by code original written by Stalker
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
@ -56,7 +61,7 @@ async def list_models(
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
"/{key}",
operation_id="update_model",
responses={
200: {"description": "The model was updated successfully"},
@ -68,58 +73,17 @@ async def list_models(
response_model=UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
key: str = Path(description="Unique key of model"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=info.model_name,
new_base=info.base_model,
)
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get("path") != previous_info.get(
"path"
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
model_response = parse_obj_as(UpdateModelResponse, new_config.dict())
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -221,25 +185,25 @@ async def add_model(
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
"/{key}",
operation_id="del_model",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
status_code=204,
response_model=None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
key: str = Path(description="Unique key of model to remove from model registry."),
delete_files: Optional[bool] = Query(
description="Delete underlying files and directories as well.",
default=False
)
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.del_model(
model_name, base_model=base_model, model_type=model_type
)
logger.info(f"Deleted model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(key, delete_files=delete_files)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
@ -247,7 +211,7 @@ async def delete_model(
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
"/convert/{key}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
@ -258,27 +222,19 @@ async def delete_model(
response_model=ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
key: str = Path(description="Unique key of model to remove from model registry."),
convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
info = ApiDependencies.invoker.services.model_manager.model_info(key)
try:
logger.info(f"Converting model: {model_name}")
logger.info(f"Converting model: {info.name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(
model_name,
base_model=base_model,
model_type=model_type,
convert_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest)
model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict()
response = parse_obj_as(ConvertModelResponse, model_raw)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")

View File

@ -370,7 +370,9 @@ class ModelManagerService(ModelManagerServiceBase):
self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
Add a model using its path, with a dictionary of attributes.
Will fail with an
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {model_path}")
@ -385,8 +387,11 @@ class ModelManagerService(ModelManagerServiceBase):
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
Add a model using a path, repo_id or URL.
:param model_attributes: Dictionary of ModelConfigBase fields to
attach to the model. When installing a URL or repo_id, some metadata
values, such as `tags` will be automagically added.
"""
self.logger.debug(f"add/update model {source}")
variant = "fp16" if self._loader.precision == "float16" else None
@ -402,17 +407,18 @@ class ModelManagerService(ModelManagerServiceBase):
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
Update the named model with a dictionary of attributes.
Will fail with a
UnknownModelException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model key is unknown.
"""
model_info = self.model_info(key)
self.logger.debug(f"update model {model_info.name}")
self.logger.warning("TO DO: write code to move models around if base or type change")
return self._loader.store.update_model(key, new_config)
self.logger.debug(f"update model {key}")
new_info = self._loader.store.update_model(key, new_config)
return self._loader.installer.sync_model_path(new_info.key)
def del_model(
self,
@ -420,7 +426,9 @@ class ModelManagerService(ModelManagerServiceBase):
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
Delete the named model from configuration.
If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
@ -428,7 +436,7 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f"delete model {model_info.name}")
self._loader.store.del_model(key)
if delete_files and Path(model_info.path).exists():
path = Path(model_info)
path = Path(model_info.path)
if path.is_dir():
shutil.rmtree(path)
else:

View File

@ -52,7 +52,7 @@ import re
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from shutil import rmtree, move
from typing import Any, Callable, Dict, List, Optional, Set, Union
from pydantic import Field
@ -61,7 +61,15 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType
from .config import (
ModelConfigBase,
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from .download import (
HTTP_RE,
REPO_ID_RE,
@ -111,8 +119,8 @@ class ModelInstallBase(ABC):
@abstractmethod
def __init__(
self,
store: Optional[ModelConfigStore] = None,
config: Optional[InvokeAIAppConfig] = None,
store: Optional[ModelConfigStore] = None,
logger: Optional[InvokeAILogger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
@ -271,6 +279,20 @@ class ModelInstallBase(ABC):
"""
pass
@abstractmethod
def sync_model_path(self, key) -> Path:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
pass
class ModelInstall(ModelInstallBase):
"""Model installer class handles installation from a local path."""
@ -372,18 +394,24 @@ class ModelInstall(ModelInstallBase):
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
return self._register(
self._move_model(model_path, dest_path),
info,
)
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while dest_path.exists():
dest_path = dest_path.with_stem(dest_path.stem + f"_{counter:02d}")
while new_path.exists():
new_path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
counter += 1
return old_path.replace(new_path)
return self._register(
model_path.replace(dest_path),
info,
)
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(model_path)
@ -436,7 +464,7 @@ class ModelInstall(ModelInstallBase):
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
self._async_installs[info.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
@ -455,13 +483,36 @@ class ModelInstall(ModelInstallBase):
info.source = str(job.source)
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
self._async_installs[info.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
def sync_model_path(self, key) -> ModelConfigBase:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self._store.get_model(key)
old_path = Path(model.path)
models_dir = self._config.models_path
if not old_path.is_relative_to(models_dir):
return old_path
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
model.path = self._move_model(old_path, new_path).as_posix()
self._store.update_model(key, model)
return model
def _make_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
@ -487,7 +538,7 @@ class ModelInstall(ModelInstallBase):
cls = ModelInstallURLJob
kwargs = {}
else:
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
@ -515,6 +566,64 @@ class ModelInstall(ModelInstallBase):
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
return FastModelHash.hash(model_path)
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param key: Unique key of model.
:dest_directory: Optional place to put converted file. If not specified,
will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
from .loader import ModelInfo, ModelLoader # to avoid circular imports
try:
info: ModelConfigBase = self._store.get_model(key)
print(f'DEBUG: requested_model={info}')
if info.model_format != "checkpoint":
raise ValueError(f"not a checkpoint format model: {info.name}")
# We are taking advantage of a side effect of get_model() that converts check points
# into cached diffusers directories stored at `path`. It doesn't matter
# what submodel type we request here, so we get the smallest.
loader = ModelLoader(self._config)
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
converted_model: ModelInfo = loader.get_model(key, **submodel)
checkpoint_path = loader.resolve_model_path(info.path)
old_diffusers_path = loader.resolve_model_path(converted_model.location)
new_diffusers_path = None
if dest_directory:
new_diffusers_path = Path(dest_directory) / info.name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
move(old_diffusers_path, new_diffusers_path)
info.path = new_diffusers_path.as_posix()
info.pop("config")
info.model_format = "diffusers"
self._store.update_model(key, info)
result = self.sync_model_path(key)
except Exception:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
if new_diffusers_path:
rmtree(new_diffusers_path)
raise
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self._config.models_path):
checkpoint_path.unlink()
return result
# the following two methods are callbacks to the ModelSearch object
def _scan_register(self, model: Path) -> bool:
try:

View File

@ -91,8 +91,12 @@ class ModelLoaderBase(ABC):
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""Replace cache statistics."""
pass
pass
@abstractmethod
def resolve_model_path(self, path: Union[Path, str]) -> Path:
"""Turn a potentially relative path into an absolute one in the models_dir."""
pass
@property
@abstractmethod
@ -214,6 +218,9 @@ class ModelLoader(ModelLoaderBase):
the model to retrieve (e.g. ModelType.Vae)
"""
model_config = self.store.get_model(key) # May raise a UnknownModelException
if model_config.model_type == 'main' and not submodel_type:
raise InvalidModelException('submodel_type is required when loading a main model')
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
@ -258,17 +265,17 @@ class ModelLoader(ModelLoaderBase):
def collect_cache_stats(self, cache_stats: CacheStats):
self._cache.stats = cache_stats
def resolve_model_path(self, path: Union[Path, str]) -> Path:
"""Turn a potentially relative path into an absolute one in the models_dir."""
return self._app_config.models_path / path
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _get_model_cache_path(self, model_path):
return self._resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
def _resolve_model_path(self, path: Union[Path, str]) -> Path:
"""Return relative paths based on configured models_path."""
return self._app_config.models_path / path
return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
@ -287,7 +294,7 @@ class ModelLoader(ModelLoaderBase):
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self._resolve_model_path(model_path)
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def sync_to_config(self):
@ -301,7 +308,7 @@ class ModelLoader(ModelLoaderBase):
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk.")
for model_config in self._store.all_models():
path = self._resolve_model_path(model_config.path)
path = self.resolve_model_path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering.")
defunct_models.add(model_config.key)
@ -311,6 +318,6 @@ class ModelLoader(ModelLoaderBase):
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = self._resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
installed.update(self._installer.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

View File

@ -163,6 +163,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._commit()
finally:
self._lock.release()
return self.get_model(key)
def get_model(self, key: str) -> ModelConfigBase:
"""