add route for model conversion from safetensors to diffusers

- Begin to add SwaggerUI documentation for AnyModelConfig and other
  discriminated unions.
Lincoln Stein 2024-02-12 21:25:42 -05:00 committed by Brandon Rising
parent 93fb2d1a55
commit 6e91d5baaf
7 changed files with 113 additions and 21 deletions
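
The new endpoint can be exercised with any HTTP client. A minimal caller sketch, assuming a local server and that this router is mounted under /api/v2/models (host, port, and prefix are assumptions, not shown in this commit):

import requests

BASE_URL = "http://localhost:9090/api/v2/models"  # hypothetical host and route prefix

def convert_to_diffusers(key: str) -> dict:
    """Convert the safetensors model identified by `key`; returns the NEW config record."""
    response = requests.put(f"{BASE_URL}/convert/{key}", timeout=600)  # conversion can be slow
    response.raise_for_status()
    return response.json()  # the record carries a new key and hash; the old key is gone

new_config = convert_to_diffusers("0123abcd")  # hypothetical model key
print(new_config)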

View File

@@ -2,6 +2,7 @@
 """FastAPI route for model configuration records."""
 import pathlib
+import shutil
 from hashlib import sha1
 from random import randbytes
 from typing import Any, Dict, List, Optional, Set
@@ -24,8 +25,10 @@ from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
+    MainCheckpointConfig,
     ModelFormat,
     ModelType,
+    SubModelType,
 )
 from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@@ -318,7 +321,7 @@ async def heuristic_import(
 @model_manager_v2_router.post(
-    "/import",
+    "/install",
     operation_id="import_model",
     responses={
         201: {"description": "The model imported successfully"},
@@ -490,6 +493,81 @@ async def sync_models_to_config() -> Response:
     return Response(status_code=204)
 
 
+@model_manager_v2_router.put(
+    "/convert/{key}",
+    operation_id="convert_model",
+    responses={
+        200: {"description": "Model converted successfully"},
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
+        409: {"description": "There is already a model registered at this location"},
+    },
+)
+async def convert_model(
+    key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
+) -> AnyModelConfig:
+    """
+    Permanently convert a model into diffusers format, replacing the safetensors version.
+
+    Note that the key and model hash will change. Use the model configuration record returned
+    by this call to get the new values.
+    """
+    logger = ApiDependencies.invoker.services.logger
+    loader = ApiDependencies.invoker.services.model_manager.load
+    store = ApiDependencies.invoker.services.model_manager.store
+    installer = ApiDependencies.invoker.services.model_manager.install
+    try:
+        model_config = store.get_model(key)
+    except UnknownModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))
+
+    if not isinstance(model_config, MainCheckpointConfig):
+        logger.error(f"The model with key {key} is not a main checkpoint model.")
+        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
+
+    # Loading the model will convert it into a cached diffusers directory.
+    loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
+
+    # Get the path of the converted model from the loader.
+    cache_path = loader.convert_cache.cache_path(key)
+    assert cache_path.exists()
+
+    # Temporarily rename the original safetensors file so that there is no naming conflict.
+    original_name = model_config.name
+    model_config.name = f"{original_name}.DELETE"
+    store.update_model(key, config=model_config)
+
+    # Install the diffusers version.
+    try:
+        new_key = installer.install_path(
+            cache_path,
+            config={
+                "name": original_name,
+                "description": model_config.description,
+                "original_hash": model_config.original_hash,
+                "source": model_config.source,
+            },
+        )
+    except DuplicateModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+
+    # Copy the original metadata, if any, to the new record.
+    if orig_metadata := store.get_metadata(key):
+        store.metadata_store.add_metadata(new_key, orig_metadata)
+
+    # Delete the original safetensors file.
+    installer.delete(key)
+
+    # Delete the cached conversion.
+    shutil.rmtree(cache_path)
+
+    # Return the config record for the new diffusers directory.
+    new_config: AnyModelConfig = store.get_model(new_key)
+    return new_config
+
+
 @model_manager_v2_router.put(
     "/merge",
     operation_id="merge",

View File

@@ -162,8 +162,10 @@ class ModelInstallService(ModelInstallServiceBase):
         config["source"] = model_path.resolve().as_posix()
         info: AnyModelConfig = self._probe_model(Path(model_path), config)
-        old_hash = info.original_hash
-        dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
+        old_hash = info.current_hash
+        dest_path = (
+            self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
+        )
         try:
             new_path = self._copy_model(model_path, dest_path)
         except FileExistsError as excp:
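
The dest_path change is what lets a caller control the installed directory name: install_path now prefers config["name"] over the source file name. The conversion route relies on this to install the converted directory under the model's original name. A sketch (paths and model key are illustrative; `installer` is the same install service the route uses):

import pathlib

installer = ApiDependencies.invoker.services.model_manager.install

# The converted model sits in the convert cache under an opaque directory name...
cache_path = pathlib.Path("/invokeai/cache/converted/abc123")  # illustrative path

# ...but installs under models/<base>/<type>/my-model rather than .../abc123.
new_key = installer.install_path(cache_path, config={"name": "my-model"})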

View File

@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
 from typing import Optional
 
 from invokeai.app.invocations.baseinvocation import InvocationContext
-from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
 
 
 class ModelLoadServiceBase(ABC):
@@ -70,3 +72,13 @@
             NotImplementedException -- a model loader was not provided at initialization time
             ValueError -- more than one model matches this combination
         """
+
+    @property
+    @abstractmethod
+    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+        """Return the RAM cache used by this loader."""
+
+    @property
+    @abstractmethod
+    def convert_cache(self) -> ModelConvertCacheBase:
+        """Return the checkpoint convert cache used by this loader."""

View File

@@ -10,7 +10,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
-from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
 from invokeai.backend.util.logging import InvokeAILogger
 
 from .model_load_base import ModelLoadServiceBase
@@ -46,6 +46,16 @@ class ModelLoadService(ModelLoadServiceBase):
             ),
         )
 
+    @property
+    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+        """Return the RAM cache used by this loader."""
+        return self._any_loader.ram_cache
+
+    @property
+    def convert_cache(self) -> ModelConvertCacheBase:
+        """Return the checkpoint convert cache used by this loader."""
+        return self._any_loader.convert_cache
+
     def load_model_by_key(
         self,
         key: str,

View File

@@ -17,7 +17,6 @@ from invokeai.backend.model_manager import (
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.load import AnyModelLoader
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@@ -195,12 +194,6 @@
         """
         pass
 
-    @property
-    @abstractmethod
-    def loader(self) -> Optional[AnyModelLoader]:
-        """Return the model loader used by this instance."""
-        pass
-
     def all_models(self) -> List[AnyModelConfig]:
         """Return all the model configs in the database."""
         return self.search_by_attr()

View File

@@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.load import AnyModelLoader
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
 
 from ..shared.sqlite.sqlite_database import SqliteDatabase
@@ -70,28 +69,21 @@ from .model_records_base import (
 class ModelRecordServiceSQL(ModelRecordServiceBase):
     """Implementation of the ModelConfigStore ABC using a SQL database."""
 
-    def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
+    def __init__(self, db: SqliteDatabase):
         """
         Initialize a new object from preexisting sqlite3 connection and threading lock objects.
 
         :param db: Sqlite connection object
-        :param loader: Initialized model loader object (optional)
         """
         super().__init__()
         self._db = db
         self._cursor = db.conn.cursor()
-        self._loader = loader
 
     @property
     def db(self) -> SqliteDatabase:
         """Return the underlying database."""
         return self._db
 
-    @property
-    def loader(self) -> Optional[AnyModelLoader]:
-        """Return the model loader used by this instance."""
-        return self._loader
-
     def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
         """
         Add a model to the database.

View File

@@ -117,6 +117,11 @@ class AnyModelLoader:
         """Return the RAM cache used by the loaders."""
         return self._ram_cache
 
+    @property
+    def convert_cache(self) -> ModelConvertCacheBase:
+        """Return the convert cache used by the loaders."""
+        return self._convert_cache
+
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Return a model given its configuration.
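
Taken together, these cache properties are what make the conversion route possible: loading any submodel of a checkpoint model converts it into the convert cache as a side effect, and the cache can then report where the converted copy landed on disk. A sketch of that calling pattern against AnyModelLoader (`model_config.key` is an assumption, inferred from the route's store.get_model(key)):

# `loader` is an AnyModelLoader instance wired with a RAM cache and a convert cache.
# Loading a small submodel (the scheduler) triggers the checkpoint-to-diffusers
# conversion as a side effect.
loaded = loader.load_model(model_config, submodel_type=SubModelType.Scheduler)

# The convert cache reports the on-disk diffusers directory for this model.
converted_path = loader.convert_cache.cache_path(model_config.key)
assert converted_path.exists()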