improve swagger documentation

This commit is contained in:
Lincoln Stein 2024-02-14 11:10:50 -05:00 committed by psychedelicious
parent 3e330d7d9d
commit b0835db47d
3 changed files with 160 additions and 74 deletions

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -37,35 +37,6 @@ from ..dependencies import ApiDependencies
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
example_model_output = {
"path": "sd-1/main/openjourney",
"name": "openjourney",
"base": "sd-1",
"type": "main",
"format": "diffusers",
"key": "3a0e45ff858926fd4a63da630688b1e1",
"original_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
"current_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
"description": "sd-1 main model openjourney",
"source": "/opt/invokeai/models/sd-1/main/openjourney",
"last_modified": 1707794711,
"vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors",
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
}
example_model_input = {
"path": "base/type/name",
"name": "model_name",
"base": "sd-1",
"type": "main",
"format": "diffusers",
"description": "Model description",
"vae": None,
"variant": "normal",
}
class ModelsList(BaseModel):
"""Return list of configs."""
@ -84,6 +55,86 @@ class ModelTagSet(BaseModel):
tags: Set[str]
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
##############################################################################
example_model_config = {
"path": "string",
"name": "string",
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "string",
"key": "string",
"original_hash": "string",
"current_hash": "string",
"description": "string",
"source": "string",
"last_modified": 0,
"vae": "string",
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
"upcast_attention": False,
"ztsnr_training": False,
}
example_model_input = {
"path": "/path/to/model",
"name": "model_name",
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description",
"vae": None,
"variant": "normal",
}
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
##############################################################################
# ROUTES
##############################################################################
@model_manager_v2_router.get(
"/",
operation_id="list_model_records",
@ -119,7 +170,7 @@ async def list_model_records(
responses={
200: {
"description": "The model configuration was retrieved successfully",
"content": {"application/json": {"example": example_model_output}},
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
@ -137,7 +188,7 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
@model_manager_v2_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
@ -153,7 +204,10 @@ async def list_model_summary(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
200: {"description": "Success"},
200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
404: {"description": "No metadata available"},
},
@ -199,7 +253,7 @@ async def search_by_metadata_tags(
responses={
200: {
"description": "The model was updated successfully",
"content": {"application/json": {"example": example_model_output}},
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
@ -212,7 +266,7 @@ async def update_model_record(
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> Annotated[AnyModelConfig, Field(example="this is neat")]:
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
@ -263,7 +317,7 @@ async def del_model_record(
responses={
201: {
"description": "The model added successfully",
"content": {"application/json": {"example": example_model_output}},
"content": {"application/json": {"example": example_model_config}},
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
@ -271,7 +325,9 @@ async def del_model_record(
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
@ -389,32 +445,38 @@ async def import_model(
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
`{
```
{
"type": "local",
"path": "/path/to/model",
"inplace": false
}`
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
}
```
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{
```
{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}`
The `variant`, `subfolder` and `access_token` fields are optional.
}
```
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
`{
```
{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}`
The `access_token` field is optonal
}
```
The `access_token` field is optonal
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
@ -423,9 +485,9 @@ async def import_model(
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
"model_install_running"
"model_install_completed"
"model_install_error"
* "model_install_running"
* "model_install_completed"
* "model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
@ -459,7 +521,25 @@ async def import_model(
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
the nature of the job and its progress. The `status` is one of:
* "waiting" -- Job is waiting in the queue to run
* "downloading" -- Model file(s) are downloading
* "running" -- Model has downloaded and the model probing and registration process is running
* "completed" -- Installation completed successfully
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
* "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out`
field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
See the example and schema below for more information.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
return jobs
@ -473,7 +553,10 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
"""
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
"""
try:
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
return result
@ -539,7 +622,7 @@ async def sync_models_to_config() -> Response:
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_output}},
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
@ -551,8 +634,8 @@ async def convert_model(
) -> AnyModelConfig:
"""
Permanently convert a model into diffusers format, replacing the safetensors version.
Note that the key and model hash will change. Use the model configuration record returned
by this call to get the new values.
Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model.
"""
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
@ -617,7 +700,7 @@ async def convert_model(
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_output}},
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
@ -639,14 +722,17 @@ async def merge(
),
) -> AnyModelConfig:
"""
Merge diffusers models.
keys: List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name: Name for the merged model [Concat model names]
alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory: Specify a directory to store the merged model in [models directory]
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
```
Argument Description [default]
-------- ----------------------
keys List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name Name for the merged model [Concat model names]
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
"""
logger = ApiDependencies.invoker.services.logger
try:

View File

@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm, std
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp

View File

@ -123,11 +123,11 @@ class ModelRepoVariant(str, Enum):
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str
name: str
base: BaseModelType
type: ModelType
format: ModelFormat
path: str = Field(description="filesystem path to the model file or directory")
name: str = Field(description="model name")
base: BaseModelType = Field(description="base model")
type: ModelType = Field(description="type of the model")
format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
@ -135,9 +135,9 @@ class ModelConfigBase(BaseModel):
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time)
description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
model_config = ConfigDict(
use_enum_values=False,