tidy(api): tidy mm routes

Rename MM routes to be consistent:
- "import" -> "install"
- "model_record" -> "model"

Comment several unused routes while I work (may end up removing them?):
- list model summary (we use the search route instead)
- add model record
- convert model
- merge models
This commit is contained in:
psychedelicious 2024-03-05 16:32:16 +11:00
parent 78895b3e80
commit 4f9bb00275

View File

@ -2,8 +2,7 @@
"""FastAPI route for model configuration records.""" """FastAPI route for model configuration records."""
import pathlib import pathlib
import shutil from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -13,23 +12,16 @@ from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException, InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
MainCheckpointConfig,
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType,
) )
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -45,15 +37,6 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
############################################################################## ##############################################################################
# These are example inputs and outputs that are used in places where Swagger # These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example. # is unable to generate a correct example.
@ -167,16 +150,16 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.get("/summary", operation_id="list_model_summary") # @model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary( # async def list_model_summary(
page: int = Query(default=0, description="The page to get"), # page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"), # per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), # order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]: # ) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data.""" # """Gets a page of model summary data."""
record_store = ApiDependencies.invoker.services.model_manager.store # record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) # results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results # return results
class FoundModel(BaseModel): class FoundModel(BaseModel):
@ -282,14 +265,14 @@ async def update_model_record(
@model_manager_router.delete( @model_manager_router.delete(
"/i/{key}", "/i/{key}",
operation_id="del_model_record", operation_id="delete_model",
responses={ responses={
204: {"description": "Model deleted successfully"}, 204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"}, 404: {"description": "Model not found"},
}, },
status_code=204, status_code=204,
) )
async def del_model_record( async def delete_model(
key: str = Path(description="Unique key of model to remove from model registry."), key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response: ) -> Response:
""" """
@ -310,39 +293,39 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.post( # @model_manager_router.post(
"/i/", # "/i/",
operation_id="add_model_record", # operation_id="add_model_record",
responses={ # responses={
201: { # 201: {
"description": "The model added successfully", # "description": "The model added successfully",
"content": {"application/json": {"example": example_model_config}}, # "content": {"application/json": {"example": example_model_config}},
}, # },
409: {"description": "There is already a model corresponding to this path or repo_id"}, # 409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"}, # 415: {"description": "Unrecognized file/folder format"},
}, # },
status_code=201, # status_code=201,
) # )
async def add_model_record( # async def add_model_record(
config: Annotated[ # config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) # AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
], # ],
) -> AnyModelConfig: # ) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type.""" # """Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger # logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store # record_store = ApiDependencies.invoker.services.model_manager.store
try: # try:
record_store.add_model(config) # record_store.add_model(config)
except DuplicateModelException as e: # except DuplicateModelException as e:
logger.error(str(e)) # logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) # raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e: # except InvalidModelException as e:
logger.error(str(e)) # logger.error(str(e))
raise HTTPException(status_code=415) # raise HTTPException(status_code=415)
# now fetch it out # # now fetch it out
result: AnyModelConfig = record_store.get_model(config.key) # result: AnyModelConfig = record_store.get_model(config.key)
return result # return result
@model_manager_router.post( @model_manager_router.post(
@ -417,10 +400,10 @@ async def install_model(
@model_manager_router.get( @model_manager_router.get(
"/import", "/install",
operation_id="list_model_install_jobs", operation_id="list_model_installs",
) )
async def list_model_install_jobs() -> List[ModelInstallJob]: async def list_model_installs() -> List[ModelInstallJob]:
"""Return the list of model install jobs. """Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@ -444,7 +427,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
@model_manager_router.get( @model_manager_router.get(
"/import/{id}", "/install/{id}",
operation_id="get_model_install_job", operation_id="get_model_install_job",
responses={ responses={
200: {"description": "Success"}, 200: {"description": "Success"},
@ -464,7 +447,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
@model_manager_router.delete( @model_manager_router.delete(
"/import/{id}", "/install/{id}",
operation_id="cancel_model_install_job", operation_id="cancel_model_install_job",
responses={ responses={
201: {"description": "The job was cancelled successfully"}, 201: {"description": "The job was cancelled successfully"},
@ -483,7 +466,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
@model_manager_router.patch( @model_manager_router.patch(
"/import", "/install/prune",
operation_id="prune_model_install_jobs", operation_id="prune_model_install_jobs",
responses={ responses={
204: {"description": "All completed and errored jobs have been pruned"}, 204: {"description": "All completed and errored jobs have been pruned"},
@ -515,142 +498,142 @@ async def sync_models_to_config() -> Response:
return Response(status_code=204) return Response(status_code=204)
@model_manager_router.put( # @model_manager_router.put(
"/convert/{key}", # "/convert/{key}",
operation_id="convert_model", # operation_id="convert_model",
responses={ # responses={
200: { # 200: {
"description": "Model converted successfully", # "description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}}, # "content": {"application/json": {"example": example_model_config}},
}, # },
400: {"description": "Bad request"}, # 400: {"description": "Bad request"},
404: {"description": "Model not found"}, # 404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"}, # 409: {"description": "There is already a model registered at this location"},
}, # },
) # )
async def convert_model( # async def convert_model(
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), # key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig: # ) -> AnyModelConfig:
""" # """
Permanently convert a model into diffusers format, replacing the safetensors version. # Permanently convert a model into diffusers format, replacing the safetensors version.
Note that during the conversion process the key and model hash will change. # Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model. # The return value is the model configuration for the converted model.
""" # """
model_manager = ApiDependencies.invoker.services.model_manager # model_manager = ApiDependencies.invoker.services.model_manager
logger = ApiDependencies.invoker.services.logger # logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load # loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store # store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install # installer = ApiDependencies.invoker.services.model_manager.install
try: # try:
model_config = store.get_model(key) # model_config = store.get_model(key)
except UnknownModelException as e: # except UnknownModelException as e:
logger.error(str(e)) # logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e)) # raise HTTPException(status_code=424, detail=str(e))
if not isinstance(model_config, MainCheckpointConfig): # if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.") # logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") # raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file # # loading the model will convert it into a cached diffusers file
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) # model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader # # Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key) # cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists() # assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict # # temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name # original_name = model_config.name
model_config.name = f"{original_name}.DELETE" # model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name) # changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes) # store.update_model(key, changes=changes)
# install the diffusers # # install the diffusers
try: # try:
new_key = installer.install_path( # new_key = installer.install_path(
cache_path, # cache_path,
config={ # config={
"name": original_name, # "name": original_name,
"description": model_config.description, # "description": model_config.description,
"hash": model_config.hash, # "hash": model_config.hash,
"source": model_config.source, # "source": model_config.source,
}, # },
) # )
except DuplicateModelException as e: # except DuplicateModelException as e:
logger.error(str(e)) # logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) # raise HTTPException(status_code=409, detail=str(e))
# delete the original safetensors file # # delete the original safetensors file
installer.delete(key) # installer.delete(key)
# delete the cached version # # delete the cached version
shutil.rmtree(cache_path) # shutil.rmtree(cache_path)
# return the config record for the new diffusers directory # # return the config record for the new diffusers directory
new_config: AnyModelConfig = store.get_model(new_key) # new_config: AnyModelConfig = store.get_model(new_key)
return new_config # return new_config
@model_manager_router.put( # @model_manager_router.put(
"/merge", # "/merge",
operation_id="merge", # operation_id="merge",
responses={ # responses={
200: { # 200: {
"description": "Model converted successfully", # "description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}}, # "content": {"application/json": {"example": example_model_config}},
}, # },
400: {"description": "Bad request"}, # 400: {"description": "Bad request"},
404: {"description": "Model not found"}, # 404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"}, # 409: {"description": "There is already a model registered at this location"},
}, # },
) # )
async def merge( # async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), # keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), # merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), # alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
force: bool = Body( # force: bool = Body(
description="Force merging of models created with different versions of diffusers", # description="Force merging of models created with different versions of diffusers",
default=False, # default=False,
), # ),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), # interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
merge_dest_directory: Optional[str] = Body( # merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)", # description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None, # default=None,
), # ),
) -> AnyModelConfig: # ) -> AnyModelConfig:
""" # """
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. # Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
``` # ```
Argument Description [default] # Argument Description [default]
-------- ---------------------- # -------- ----------------------
keys List of 2-3 model keys to merge together. All models must use the same base type. # keys List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name Name for the merged model [Concat model names] # merged_model_name Name for the merged model [Concat model names]
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] # alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force If true, force the merge even if the models were generated by different versions of the diffusers library [False] # force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] # interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory] # merge_dest_directory Specify a directory to store the merged model in [models directory]
``` # ```
""" # """
logger = ApiDependencies.invoker.services.logger # logger = ApiDependencies.invoker.services.logger
try: # try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}") # logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None # dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install # installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer) # merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys] # model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save( # response = merger.merge_diffusion_models_and_save(
model_keys=keys, # model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names), # merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha, # alpha=alpha,
interp=interp, # interp=interp,
force=force, # force=force,
merge_dest_directory=dest, # merge_dest_directory=dest,
) # )
except UnknownModelException: # except UnknownModelException:
raise HTTPException( # raise HTTPException(
status_code=404, # status_code=404,
detail=f"One or more of the models '{keys}' not found", # detail=f"One or more of the models '{keys}' not found",
) # )
except ValueError as e: # except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) # raise HTTPException(status_code=400, detail=str(e))
return response # return response