mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(api): tidy mm routes
Rename MM routes to be consistent: - "import" -> "install" - "model_record" -> "model" Comment several unused routes while I work (may end up removing them?): - list model summary (we use the search route instead) - add model record - convert model - merge models
This commit is contained in:
parent
78895b3e80
commit
4f9bb00275
@ -2,8 +2,7 @@
|
|||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
from typing import Any, Dict, List, Optional
|
||||||
from typing import Any, Dict, List, Optional, Set
|
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
@ -13,23 +12,16 @@ from typing_extensions import Annotated
|
|||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallJob
|
from invokeai.app.services.model_install import ModelInstallJob
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
ModelRecordOrderBy,
|
|
||||||
ModelSummary,
|
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
MainCheckpointConfig,
|
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -45,15 +37,6 @@ class ModelsList(BaseModel):
|
|||||||
model_config = ConfigDict(use_enum_values=True)
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
|
|
||||||
class ModelTagSet(BaseModel):
|
|
||||||
"""Return tags for a set of models."""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
name: str
|
|
||||||
author: str
|
|
||||||
tags: Set[str]
|
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# These are example inputs and outputs that are used in places where Swagger
|
# These are example inputs and outputs that are used in places where Swagger
|
||||||
# is unable to generate a correct example.
|
# is unable to generate a correct example.
|
||||||
@ -167,16 +150,16 @@ async def get_model_record(
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
# @model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||||
async def list_model_summary(
|
# async def list_model_summary(
|
||||||
page: int = Query(default=0, description="The page to get"),
|
# page: int = Query(default=0, description="The page to get"),
|
||||||
per_page: int = Query(default=10, description="The number of models per page"),
|
# per_page: int = Query(default=10, description="The number of models per page"),
|
||||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||||
) -> PaginatedResults[ModelSummary]:
|
# ) -> PaginatedResults[ModelSummary]:
|
||||||
"""Gets a page of model summary data."""
|
# """Gets a page of model summary data."""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||||
return results
|
# return results
|
||||||
|
|
||||||
|
|
||||||
class FoundModel(BaseModel):
|
class FoundModel(BaseModel):
|
||||||
@ -282,14 +265,14 @@ async def update_model_record(
|
|||||||
|
|
||||||
@model_manager_router.delete(
|
@model_manager_router.delete(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="del_model_record",
|
operation_id="delete_model",
|
||||||
responses={
|
responses={
|
||||||
204: {"description": "Model deleted successfully"},
|
204: {"description": "Model deleted successfully"},
|
||||||
404: {"description": "Model not found"},
|
404: {"description": "Model not found"},
|
||||||
},
|
},
|
||||||
status_code=204,
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def del_model_record(
|
async def delete_model(
|
||||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
@ -310,39 +293,39 @@ async def del_model_record(
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.post(
|
# @model_manager_router.post(
|
||||||
"/i/",
|
# "/i/",
|
||||||
operation_id="add_model_record",
|
# operation_id="add_model_record",
|
||||||
responses={
|
# responses={
|
||||||
201: {
|
# 201: {
|
||||||
"description": "The model added successfully",
|
# "description": "The model added successfully",
|
||||||
"content": {"application/json": {"example": example_model_config}},
|
# "content": {"application/json": {"example": example_model_config}},
|
||||||
},
|
# },
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
# 415: {"description": "Unrecognized file/folder format"},
|
||||||
},
|
# },
|
||||||
status_code=201,
|
# status_code=201,
|
||||||
)
|
# )
|
||||||
async def add_model_record(
|
# async def add_model_record(
|
||||||
config: Annotated[
|
# config: Annotated[
|
||||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||||
],
|
# ],
|
||||||
) -> AnyModelConfig:
|
# ) -> AnyModelConfig:
|
||||||
"""Add a model using the configuration information appropriate for its type."""
|
# """Add a model using the configuration information appropriate for its type."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
# logger = ApiDependencies.invoker.services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
try:
|
# try:
|
||||||
record_store.add_model(config)
|
# record_store.add_model(config)
|
||||||
except DuplicateModelException as e:
|
# except DuplicateModelException as e:
|
||||||
logger.error(str(e))
|
# logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
# raise HTTPException(status_code=409, detail=str(e))
|
||||||
except InvalidModelException as e:
|
# except InvalidModelException as e:
|
||||||
logger.error(str(e))
|
# logger.error(str(e))
|
||||||
raise HTTPException(status_code=415)
|
# raise HTTPException(status_code=415)
|
||||||
|
|
||||||
# now fetch it out
|
# # now fetch it out
|
||||||
result: AnyModelConfig = record_store.get_model(config.key)
|
# result: AnyModelConfig = record_store.get_model(config.key)
|
||||||
return result
|
# return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.post(
|
@model_manager_router.post(
|
||||||
@ -417,10 +400,10 @@ async def install_model(
|
|||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
@model_manager_router.get(
|
||||||
"/import",
|
"/install",
|
||||||
operation_id="list_model_install_jobs",
|
operation_id="list_model_installs",
|
||||||
)
|
)
|
||||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
async def list_model_installs() -> List[ModelInstallJob]:
|
||||||
"""Return the list of model install jobs.
|
"""Return the list of model install jobs.
|
||||||
|
|
||||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||||
@ -444,7 +427,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
|||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
@model_manager_router.get(
|
||||||
"/import/{id}",
|
"/install/{id}",
|
||||||
operation_id="get_model_install_job",
|
operation_id="get_model_install_job",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "Success"},
|
200: {"description": "Success"},
|
||||||
@ -464,7 +447,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
|||||||
|
|
||||||
|
|
||||||
@model_manager_router.delete(
|
@model_manager_router.delete(
|
||||||
"/import/{id}",
|
"/install/{id}",
|
||||||
operation_id="cancel_model_install_job",
|
operation_id="cancel_model_install_job",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The job was cancelled successfully"},
|
201: {"description": "The job was cancelled successfully"},
|
||||||
@ -483,7 +466,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
|||||||
|
|
||||||
|
|
||||||
@model_manager_router.patch(
|
@model_manager_router.patch(
|
||||||
"/import",
|
"/install/prune",
|
||||||
operation_id="prune_model_install_jobs",
|
operation_id="prune_model_install_jobs",
|
||||||
responses={
|
responses={
|
||||||
204: {"description": "All completed and errored jobs have been pruned"},
|
204: {"description": "All completed and errored jobs have been pruned"},
|
||||||
@ -515,142 +498,142 @@ async def sync_models_to_config() -> Response:
|
|||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.put(
|
# @model_manager_router.put(
|
||||||
"/convert/{key}",
|
# "/convert/{key}",
|
||||||
operation_id="convert_model",
|
# operation_id="convert_model",
|
||||||
responses={
|
# responses={
|
||||||
200: {
|
# 200: {
|
||||||
"description": "Model converted successfully",
|
# "description": "Model converted successfully",
|
||||||
"content": {"application/json": {"example": example_model_config}},
|
# "content": {"application/json": {"example": example_model_config}},
|
||||||
},
|
# },
|
||||||
400: {"description": "Bad request"},
|
# 400: {"description": "Bad request"},
|
||||||
404: {"description": "Model not found"},
|
# 404: {"description": "Model not found"},
|
||||||
409: {"description": "There is already a model registered at this location"},
|
# 409: {"description": "There is already a model registered at this location"},
|
||||||
},
|
# },
|
||||||
)
|
# )
|
||||||
async def convert_model(
|
# async def convert_model(
|
||||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
# key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||||
) -> AnyModelConfig:
|
# ) -> AnyModelConfig:
|
||||||
"""
|
# """
|
||||||
Permanently convert a model into diffusers format, replacing the safetensors version.
|
# Permanently convert a model into diffusers format, replacing the safetensors version.
|
||||||
Note that during the conversion process the key and model hash will change.
|
# Note that during the conversion process the key and model hash will change.
|
||||||
The return value is the model configuration for the converted model.
|
# The return value is the model configuration for the converted model.
|
||||||
"""
|
# """
|
||||||
model_manager = ApiDependencies.invoker.services.model_manager
|
# model_manager = ApiDependencies.invoker.services.model_manager
|
||||||
logger = ApiDependencies.invoker.services.logger
|
# logger = ApiDependencies.invoker.services.logger
|
||||||
loader = ApiDependencies.invoker.services.model_manager.load
|
# loader = ApiDependencies.invoker.services.model_manager.load
|
||||||
store = ApiDependencies.invoker.services.model_manager.store
|
# store = ApiDependencies.invoker.services.model_manager.store
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
model_config = store.get_model(key)
|
# model_config = store.get_model(key)
|
||||||
except UnknownModelException as e:
|
# except UnknownModelException as e:
|
||||||
logger.error(str(e))
|
# logger.error(str(e))
|
||||||
raise HTTPException(status_code=424, detail=str(e))
|
# raise HTTPException(status_code=424, detail=str(e))
|
||||||
|
|
||||||
if not isinstance(model_config, MainCheckpointConfig):
|
# if not isinstance(model_config, MainCheckpointConfig):
|
||||||
logger.error(f"The model with key {key} is not a main checkpoint model.")
|
# logger.error(f"The model with key {key} is not a main checkpoint model.")
|
||||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
# raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||||
|
|
||||||
# loading the model will convert it into a cached diffusers file
|
# # loading the model will convert it into a cached diffusers file
|
||||||
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
# model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||||
|
|
||||||
# Get the path of the converted model from the loader
|
# # Get the path of the converted model from the loader
|
||||||
cache_path = loader.convert_cache.cache_path(key)
|
# cache_path = loader.convert_cache.cache_path(key)
|
||||||
assert cache_path.exists()
|
# assert cache_path.exists()
|
||||||
|
|
||||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
# # temporarily rename the original safetensors file so that there is no naming conflict
|
||||||
original_name = model_config.name
|
# original_name = model_config.name
|
||||||
model_config.name = f"{original_name}.DELETE"
|
# model_config.name = f"{original_name}.DELETE"
|
||||||
changes = ModelRecordChanges(name=model_config.name)
|
# changes = ModelRecordChanges(name=model_config.name)
|
||||||
store.update_model(key, changes=changes)
|
# store.update_model(key, changes=changes)
|
||||||
|
|
||||||
# install the diffusers
|
# # install the diffusers
|
||||||
try:
|
# try:
|
||||||
new_key = installer.install_path(
|
# new_key = installer.install_path(
|
||||||
cache_path,
|
# cache_path,
|
||||||
config={
|
# config={
|
||||||
"name": original_name,
|
# "name": original_name,
|
||||||
"description": model_config.description,
|
# "description": model_config.description,
|
||||||
"hash": model_config.hash,
|
# "hash": model_config.hash,
|
||||||
"source": model_config.source,
|
# "source": model_config.source,
|
||||||
},
|
# },
|
||||||
)
|
# )
|
||||||
except DuplicateModelException as e:
|
# except DuplicateModelException as e:
|
||||||
logger.error(str(e))
|
# logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
# raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
# delete the original safetensors file
|
# # delete the original safetensors file
|
||||||
installer.delete(key)
|
# installer.delete(key)
|
||||||
|
|
||||||
# delete the cached version
|
# # delete the cached version
|
||||||
shutil.rmtree(cache_path)
|
# shutil.rmtree(cache_path)
|
||||||
|
|
||||||
# return the config record for the new diffusers directory
|
# # return the config record for the new diffusers directory
|
||||||
new_config: AnyModelConfig = store.get_model(new_key)
|
# new_config: AnyModelConfig = store.get_model(new_key)
|
||||||
return new_config
|
# return new_config
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.put(
|
# @model_manager_router.put(
|
||||||
"/merge",
|
# "/merge",
|
||||||
operation_id="merge",
|
# operation_id="merge",
|
||||||
responses={
|
# responses={
|
||||||
200: {
|
# 200: {
|
||||||
"description": "Model converted successfully",
|
# "description": "Model converted successfully",
|
||||||
"content": {"application/json": {"example": example_model_config}},
|
# "content": {"application/json": {"example": example_model_config}},
|
||||||
},
|
# },
|
||||||
400: {"description": "Bad request"},
|
# 400: {"description": "Bad request"},
|
||||||
404: {"description": "Model not found"},
|
# 404: {"description": "Model not found"},
|
||||||
409: {"description": "There is already a model registered at this location"},
|
# 409: {"description": "There is already a model registered at this location"},
|
||||||
},
|
# },
|
||||||
)
|
# )
|
||||||
async def merge(
|
# async def merge(
|
||||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
force: bool = Body(
|
# force: bool = Body(
|
||||||
description="Force merging of models created with different versions of diffusers",
|
# description="Force merging of models created with different versions of diffusers",
|
||||||
default=False,
|
# default=False,
|
||||||
),
|
# ),
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||||
merge_dest_directory: Optional[str] = Body(
|
# merge_dest_directory: Optional[str] = Body(
|
||||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||||
default=None,
|
# default=None,
|
||||||
),
|
# ),
|
||||||
) -> AnyModelConfig:
|
# ) -> AnyModelConfig:
|
||||||
"""
|
# """
|
||||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||||
```
|
# ```
|
||||||
Argument Description [default]
|
# Argument Description [default]
|
||||||
-------- ----------------------
|
# -------- ----------------------
|
||||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
# keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||||
merged_model_name Name for the merged model [Concat model names]
|
# merged_model_name Name for the merged model [Concat model names]
|
||||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
# merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||||
```
|
# ```
|
||||||
"""
|
# """
|
||||||
logger = ApiDependencies.invoker.services.logger
|
# logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
# try:
|
||||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||||
merger = ModelMerger(installer)
|
# merger = ModelMerger(installer)
|
||||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
# model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||||
response = merger.merge_diffusion_models_and_save(
|
# response = merger.merge_diffusion_models_and_save(
|
||||||
model_keys=keys,
|
# model_keys=keys,
|
||||||
merged_model_name=merged_model_name or "+".join(model_names),
|
# merged_model_name=merged_model_name or "+".join(model_names),
|
||||||
alpha=alpha,
|
# alpha=alpha,
|
||||||
interp=interp,
|
# interp=interp,
|
||||||
force=force,
|
# force=force,
|
||||||
merge_dest_directory=dest,
|
# merge_dest_directory=dest,
|
||||||
)
|
# )
|
||||||
except UnknownModelException:
|
# except UnknownModelException:
|
||||||
raise HTTPException(
|
# raise HTTPException(
|
||||||
status_code=404,
|
# status_code=404,
|
||||||
detail=f"One or more of the models '{keys}' not found",
|
# detail=f"One or more of the models '{keys}' not found",
|
||||||
)
|
# )
|
||||||
except ValueError as e:
|
# except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
# raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
# return response
|
||||||
|
Loading…
Reference in New Issue
Block a user