mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add route for model conversion from safetensors to diffusers
- Begin to add SwaggerUI documentation for AnyModelConfig and other discriminated Unions.
This commit is contained in:
parent
93fb2d1a55
commit
6e91d5baaf
@ -2,6 +2,7 @@
|
|||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import shutil
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from random import randbytes
|
from random import randbytes
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional, Set
|
||||||
@ -24,8 +25,10 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
|||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
MainCheckpointConfig,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
@ -318,7 +321,7 @@ async def heuristic_import(
|
|||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.post(
|
@model_manager_v2_router.post(
|
||||||
"/import",
|
"/install",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The model imported successfully"},
|
201: {"description": "The model imported successfully"},
|
||||||
@ -490,6 +493,81 @@ async def sync_models_to_config() -> Response:
|
|||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_v2_router.put(
|
||||||
|
"/convert/{key}",
|
||||||
|
operation_id="convert_model",
|
||||||
|
responses={
|
||||||
|
200: {"description": "Model converted successfully"},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
404: {"description": "Model not found"},
|
||||||
|
409: {"description": "There is already a model registered at this location"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def convert_model(
|
||||||
|
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Permanently convert a model into diffusers format, replacing the safetensors version.
|
||||||
|
Note that the key and model hash will change. Use the model configuration record returned
|
||||||
|
by this call to get the new values.
|
||||||
|
"""
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
loader = ApiDependencies.invoker.services.model_manager.load
|
||||||
|
store = ApiDependencies.invoker.services.model_manager.store
|
||||||
|
installer = ApiDependencies.invoker.services.model_manager.install
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_config = store.get_model(key)
|
||||||
|
except UnknownModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=424, detail=str(e))
|
||||||
|
|
||||||
|
if not isinstance(model_config, MainCheckpointConfig):
|
||||||
|
logger.error(f"The model with key {key} is not a main checkpoint model.")
|
||||||
|
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||||
|
|
||||||
|
# loading the model will convert it into a cached diffusers file
|
||||||
|
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||||
|
|
||||||
|
# Get the path of the converted model from the loader
|
||||||
|
cache_path = loader.convert_cache.cache_path(key)
|
||||||
|
assert cache_path.exists()
|
||||||
|
|
||||||
|
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||||
|
original_name = model_config.name
|
||||||
|
model_config.name = f"{original_name}.DELETE"
|
||||||
|
store.update_model(key, config=model_config)
|
||||||
|
|
||||||
|
# install the diffusers
|
||||||
|
try:
|
||||||
|
new_key = installer.install_path(
|
||||||
|
cache_path,
|
||||||
|
config={
|
||||||
|
"name": original_name,
|
||||||
|
"description": model_config.description,
|
||||||
|
"original_hash": model_config.original_hash,
|
||||||
|
"source": model_config.source,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except DuplicateModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
# get the original metadata
|
||||||
|
if orig_metadata := store.get_metadata(key):
|
||||||
|
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||||
|
|
||||||
|
# delete the original safetensors file
|
||||||
|
installer.delete(key)
|
||||||
|
|
||||||
|
# delete the cached version
|
||||||
|
shutil.rmtree(cache_path)
|
||||||
|
|
||||||
|
# return the config record for the new diffusers directory
|
||||||
|
new_config: AnyModelConfig = store.get_model(new_key)
|
||||||
|
return new_config
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.put(
|
@model_manager_v2_router.put(
|
||||||
"/merge",
|
"/merge",
|
||||||
operation_id="merge",
|
operation_id="merge",
|
||||||
|
@ -162,8 +162,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
config["source"] = model_path.resolve().as_posix()
|
config["source"] = model_path.resolve().as_posix()
|
||||||
|
|
||||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||||
old_hash = info.original_hash
|
old_hash = info.current_hash
|
||||||
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
dest_path = (
|
||||||
|
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
new_path = self._copy_model(model_path, dest_path)
|
new_path = self._copy_model(model_path, dest_path)
|
||||||
except FileExistsError as excp:
|
except FileExistsError as excp:
|
||||||
|
@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load import LoadedModel
|
from invokeai.backend.model_manager.load import LoadedModel
|
||||||
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
|
|
||||||
|
|
||||||
class ModelLoadServiceBase(ABC):
|
class ModelLoadServiceBase(ABC):
|
||||||
@ -70,3 +72,13 @@ class ModelLoadServiceBase(ABC):
|
|||||||
NotImplementedException -- a model loader was not provided at initialization time
|
NotImplementedException -- a model loader was not provided at initialization time
|
||||||
ValueError -- more than one model matches this combination
|
ValueError -- more than one model matches this combination
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||||
|
"""Return the RAM cache used by this loader."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def convert_cache(self) -> ModelConvertCacheBase:
|
||||||
|
"""Return the checkpoint convert cache used by this loader."""
|
||||||
|
@ -10,7 +10,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownM
|
|||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .model_load_base import ModelLoadServiceBase
|
from .model_load_base import ModelLoadServiceBase
|
||||||
@ -46,6 +46,16 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||||
|
"""Return the RAM cache used by this loader."""
|
||||||
|
return self._any_loader.ram_cache
|
||||||
|
|
||||||
|
@property
|
||||||
|
def convert_cache(self) -> ModelConvertCacheBase:
|
||||||
|
"""Return the checkpoint convert cache used by this loader."""
|
||||||
|
return self._any_loader.convert_cache
|
||||||
|
|
||||||
def load_model_by_key(
|
def load_model_by_key(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
|
@ -17,7 +17,6 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader
|
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||||
|
|
||||||
|
|
||||||
@ -195,12 +194,6 @@ class ModelRecordServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def loader(self) -> Optional[AnyModelLoader]:
|
|
||||||
"""Return the model loader used by this instance."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def all_models(self) -> List[AnyModelConfig]:
|
def all_models(self) -> List[AnyModelConfig]:
|
||||||
"""Return all the model configs in the database."""
|
"""Return all the model configs in the database."""
|
||||||
return self.search_by_attr()
|
return self.search_by_attr()
|
||||||
|
@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader
|
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||||
|
|
||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
@ -70,28 +69,21 @@ from .model_records_base import (
|
|||||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
|
def __init__(self, db: SqliteDatabase):
|
||||||
"""
|
"""
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||||
|
|
||||||
:param db: Sqlite connection object
|
:param db: Sqlite connection object
|
||||||
:param loader: Initialized model loader object (optional)
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._db = db
|
self._db = db
|
||||||
self._cursor = db.conn.cursor()
|
self._cursor = db.conn.cursor()
|
||||||
self._loader = loader
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db(self) -> SqliteDatabase:
|
def db(self) -> SqliteDatabase:
|
||||||
"""Return the underlying database."""
|
"""Return the underlying database."""
|
||||||
return self._db
|
return self._db
|
||||||
|
|
||||||
@property
|
|
||||||
def loader(self) -> Optional[AnyModelLoader]:
|
|
||||||
"""Return the model loader used by this instance."""
|
|
||||||
return self._loader
|
|
||||||
|
|
||||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model to the database.
|
Add a model to the database.
|
||||||
|
@ -117,6 +117,11 @@ class AnyModelLoader:
|
|||||||
"""Return the RAM cache associated used by the loaders."""
|
"""Return the RAM cache associated used by the loaders."""
|
||||||
return self._ram_cache
|
return self._ram_cache
|
||||||
|
|
||||||
|
@property
|
||||||
|
def convert_cache(self) -> ModelConvertCacheBase:
|
||||||
|
"""Return the convert cache associated used by the loaders."""
|
||||||
|
return self._convert_cache
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
Return a model given its configuration.
|
Return a model given its configuration.
|
||||||
|
Loading…
Reference in New Issue
Block a user