improve swagger documentation

This commit is contained in:
Lincoln Stein 2024-02-14 11:10:50 -05:00 committed by psychedelicious
parent 631f6cae19
commit 3e82f63c7e
3 changed files with 160 additions and 74 deletions

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
@ -37,35 +37,6 @@ from ..dependencies import ApiDependencies
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
example_model_output = {
"path": "sd-1/main/openjourney",
"name": "openjourney",
"base": "sd-1",
"type": "main",
"format": "diffusers",
"key": "3a0e45ff858926fd4a63da630688b1e1",
"original_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
"current_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
"description": "sd-1 main model openjourney",
"source": "/opt/invokeai/models/sd-1/main/openjourney",
"last_modified": 1707794711,
"vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors",
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
}
example_model_input = {
"path": "base/type/name",
"name": "model_name",
"base": "sd-1",
"type": "main",
"format": "diffusers",
"description": "Model description",
"vae": None,
"variant": "normal",
}
class ModelsList(BaseModel): class ModelsList(BaseModel):
"""Return list of configs.""" """Return list of configs."""
@ -84,6 +55,86 @@ class ModelTagSet(BaseModel):
tags: Set[str] tags: Set[str]
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
##############################################################################
example_model_config = {
"path": "string",
"name": "string",
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "string",
"key": "string",
"original_hash": "string",
"current_hash": "string",
"description": "string",
"source": "string",
"last_modified": 0,
"vae": "string",
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
"upcast_attention": False,
"ztsnr_training": False,
}
example_model_input = {
"path": "/path/to/model",
"name": "model_name",
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description",
"vae": None,
"variant": "normal",
}
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
##############################################################################
# ROUTES
##############################################################################
@model_manager_v2_router.get( @model_manager_v2_router.get(
"/", "/",
operation_id="list_model_records", operation_id="list_model_records",
@ -119,7 +170,7 @@ async def list_model_records(
responses={ responses={
200: { 200: {
"description": "The model configuration was retrieved successfully", "description": "The model configuration was retrieved successfully",
"content": {"application/json": {"example": example_model_output}}, "content": {"application/json": {"example": example_model_config}},
}, },
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
404: {"description": "The model could not be found"}, 404: {"description": "The model could not be found"},
@ -137,7 +188,7 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_v2_router.get("/meta", operation_id="list_model_summary") @model_manager_v2_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary( async def list_model_summary(
page: int = Query(default=0, description="The page to get"), page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"), per_page: int = Query(default=10, description="The number of models per page"),
@ -153,7 +204,10 @@ async def list_model_summary(
"/meta/i/{key}", "/meta/i/{key}",
operation_id="get_model_metadata", operation_id="get_model_metadata",
responses={ responses={
200: {"description": "Success"}, 200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
404: {"description": "No metadata available"}, 404: {"description": "No metadata available"},
}, },
@ -199,7 +253,7 @@ async def search_by_metadata_tags(
responses={ responses={
200: { 200: {
"description": "The model was updated successfully", "description": "The model was updated successfully",
"content": {"application/json": {"example": example_model_output}}, "content": {"application/json": {"example": example_model_config}},
}, },
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
404: {"description": "The model could not be found"}, 404: {"description": "The model could not be found"},
@ -212,7 +266,7 @@ async def update_model_record(
info: Annotated[ info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
], ],
) -> Annotated[AnyModelConfig, Field(example="this is neat")]: ) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
@ -263,7 +317,7 @@ async def del_model_record(
responses={ responses={
201: { 201: {
"description": "The model added successfully", "description": "The model added successfully",
"content": {"application/json": {"example": example_model_output}}, "content": {"application/json": {"example": example_model_config}},
}, },
409: {"description": "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"}, 415: {"description": "Unrecognized file/folder format"},
@ -271,7 +325,9 @@ async def del_model_record(
status_code=201, status_code=201,
) )
async def add_model_record( async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type.""" """Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
@ -389,32 +445,38 @@ async def import_model(
appropriate value: appropriate value:
* To install a local path using LocalModelSource, pass a source of form: * To install a local path using LocalModelSource, pass a source of form:
`{ ```
{
"type": "local", "type": "local",
"path": "/path/to/model", "path": "/path/to/model",
"inplace": false "inplace": false
}` }
The "inplace" flag, if true, will register the model in place in its ```
current filesystem location. Otherwise, the model will be copied The "inplace" flag, if true, will register the model in place in its
into the InvokeAI models directory. current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form: * To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{ ```
{
"type": "hf", "type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0", "repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16", "variant": "fp16",
"subfolder": "vae", "subfolder": "vae",
"access_token": "f5820a918aaf01" "access_token": "f5820a918aaf01"
}` }
The `variant`, `subfolder` and `access_token` fields are optional. ```
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass: * To install a remote model using an arbitrary URL, pass:
`{ ```
{
"type": "url", "type": "url",
"url": "http://www.civitai.com/models/123456", "url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01" "access_token": "f5820a918aaf01"
}` }
The `access_token` field is optonal ```
The `access_token` field is optional
The model's configuration record will be probed and filled in The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata" automatically. To override the default guesses, pass "metadata"
@ -423,9 +485,9 @@ async def import_model(
Installation occurs in the background. Either use list_model_install_jobs() Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events: to poll for completion, or listen on the event bus for the following events:
"model_install_running" * "model_install_running"
"model_install_completed" * "model_install_completed"
"model_install_error" * "model_install_error"
On successful completion, the event's payload will contain the field "key" On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload containing the installed ID of the model. On an error, the event's payload
@ -459,7 +521,25 @@ async def import_model(
operation_id="list_model_install_jobs", operation_id="list_model_install_jobs",
) )
async def list_model_install_jobs() -> List[ModelInstallJob]: async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs.""" """Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
the nature of the job and its progress. The `status` is one of:
* "waiting" -- Job is waiting in the queue to run
* "downloading" -- Model file(s) are downloading
* "running" -- Model has downloaded and the model probing and registration process is running
* "completed" -- Installation completed successfully
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
* "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out`
field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
See the example and schema below for more information.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
return jobs return jobs
@ -473,7 +553,10 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
}, },
) )
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source.""" """
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
"""
try: try:
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
return result return result
@ -539,7 +622,7 @@ async def sync_models_to_config() -> Response:
responses={ responses={
200: { 200: {
"description": "Model converted successfully", "description": "Model converted successfully",
"content": {"application/json": {"example": example_model_output}}, "content": {"application/json": {"example": example_model_config}},
}, },
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
404: {"description": "Model not found"}, 404: {"description": "Model not found"},
@ -551,8 +634,8 @@ async def convert_model(
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Permanently convert a model into diffusers format, replacing the safetensors version. Permanently convert a model into diffusers format, replacing the safetensors version.
Note that the key and model hash will change. Use the model configuration record returned Note that during the conversion process the key and model hash will change.
by this call to get the new values. The return value is the model configuration for the converted model.
""" """
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load loader = ApiDependencies.invoker.services.model_manager.load
@ -617,7 +700,7 @@ async def convert_model(
responses={ responses={
200: { 200: {
"description": "Model converted successfully", "description": "Model converted successfully",
"content": {"application/json": {"example": example_model_output}}, "content": {"application/json": {"example": example_model_config}},
}, },
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
404: {"description": "Model not found"}, 404: {"description": "Model not found"},
@ -639,14 +722,17 @@ async def merge(
), ),
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Merge diffusers models. Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
```
keys: List of 2-3 model keys to merge together. All models must use the same base type. Argument Description [default]
merged_model_name: Name for the merged model [Concat model names] -------- ----------------------
alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] keys List of 2-3 model keys to merge together. All models must use the same base type.
force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] merged_model_name Name for the merged model [Concat model names]
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
merge_dest_directory: Specify a directory to store the merged model in [models directory] force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
""" """
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:

View File

@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Set
import requests import requests
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
from tqdm import tqdm, std from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp

View File

@ -123,11 +123,11 @@ class ModelRepoVariant(str, Enum):
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
"""Base class for model configuration information.""" """Base class for model configuration information."""
path: str path: str = Field(description="filesystem path to the model file or directory")
name: str name: str = Field(description="model name")
base: BaseModelType base: BaseModelType = Field(description="base model")
type: ModelType type: ModelType = Field(description="type of the model")
format: ModelFormat format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>") key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field( original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None description="original fasthash of model contents", default=None
@ -135,9 +135,9 @@ class ModelConfigBase(BaseModel):
current_hash: Optional[str] = Field( current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash ) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None) description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time) last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
model_config = ConfigDict( model_config = ConfigDict(
use_enum_values=False, use_enum_values=False,