add route for model conversion from safetensors to diffusers

- Begin to add SwaggerUI documentation for AnyModelConfig and other
  discriminated Unions.
This commit is contained in:
Lincoln Stein 2024-02-12 21:25:42 -05:00
parent b71f53ba86
commit 433eb73d8e
7 changed files with 113 additions and 21 deletions

View File

@ -2,6 +2,7 @@
"""FastAPI route for model configuration records.""" """FastAPI route for model configuration records."""
import pathlib import pathlib
import shutil
from hashlib import sha1 from hashlib import sha1
from random import randbytes from random import randbytes
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
@ -24,8 +25,10 @@ from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
MainCheckpointConfig,
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType,
) )
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@ -318,7 +321,7 @@ async def heuristic_import(
@model_manager_v2_router.post( @model_manager_v2_router.post(
"/import", "/install",
operation_id="import_model", operation_id="import_model",
responses={ responses={
201: {"description": "The model imported successfully"}, 201: {"description": "The model imported successfully"},
@ -490,6 +493,81 @@ async def sync_models_to_config() -> Response:
return Response(status_code=204) return Response(status_code=204)
@model_manager_v2_router.put(
"/convert/{key}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def convert_model(
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig:
"""
Permanently convert a model into diffusers format, replacing the safetensors version.
Note that the key and model hash will change. Use the model configuration record returned
by this call to get the new values.
"""
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
try:
model_config = store.get_model(key)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config)
# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"original_hash": model_config.original_hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file
installer.delete(key)
# delete the cached version
shutil.rmtree(cache_path)
# return the config record for the new diffusers directory
new_config: AnyModelConfig = store.get_model(new_key)
return new_config
@model_manager_v2_router.put( @model_manager_v2_router.put(
"/merge", "/merge",
operation_id="merge", operation_id="merge",

View File

@ -162,8 +162,10 @@ class ModelInstallService(ModelInstallServiceBase):
config["source"] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config) info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash old_hash = info.current_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
)
try: try:
new_path = self._copy_model(model_path, dest_path) new_path = self._copy_model(model_path, dest_path)
except FileExistsError as excp: except FileExistsError as excp:

View File

@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
class ModelLoadServiceBase(ABC): class ModelLoadServiceBase(ABC):
@ -70,3 +72,13 @@ class ModelLoadServiceBase(ABC):
NotImplementedException -- a model loader was not provided at initialization time NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination ValueError -- more than one model matches this combination
""" """
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""

View File

@ -10,7 +10,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownM
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase from .model_load_base import ModelLoadServiceBase
@ -46,6 +46,16 @@ class ModelLoadService(ModelLoadServiceBase):
), ),
) )
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
return self._any_loader.ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._any_loader.convert_cache
def load_model_by_key( def load_model_by_key(
self, self,
key: str, key: str,

View File

@ -17,7 +17,6 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -195,12 +194,6 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@property
@abstractmethod
def loader(self) -> Optional[AnyModelLoader]:
"""Return the model loader used by this instance."""
pass
def all_models(self) -> List[AnyModelConfig]: def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database.""" """Return all the model configs in the database."""
return self.search_by_attr() return self.search_by_attr()

View File

@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
@ -70,28 +69,21 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None): def __init__(self, db: SqliteDatabase):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param db: Sqlite connection object :param db: Sqlite connection object
:param loader: Initialized model loader object (optional)
""" """
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._loader = loader
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
"""Return the underlying database.""" """Return the underlying database."""
return self._db return self._db
@property
def loader(self) -> Optional[AnyModelLoader]:
"""Return the model loader used by this instance."""
return self._loader
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.

View File

@ -117,6 +117,11 @@ class AnyModelLoader:
"""Return the RAM cache associated used by the loaders.""" """Return the RAM cache associated used by the loaders."""
return self._ram_cache return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated used by the loaders."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """
Return a model given its configuration. Return a model given its configuration.