From 2f9ebdec694ab99c473fb9a730a3d082e537023b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 29 May 2024 17:29:51 +1000 Subject: [PATCH] fix(app): openapi schema generation Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why - I think it's a pydantic bug in a remote edge case that we are hitting. After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before. - Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. Actually, they are kinda the same, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`. - Revise the custom openapi function to work with the new models. - Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema. - Update makefile scripts. --- Makefile | 4 + invokeai/app/api_app.py | 92 +----------- invokeai/app/invocations/baseinvocation.py | 20 +-- invokeai/app/services/events/events_common.py | 32 ++--- invokeai/app/services/shared/graph.py | 136 +++++------------- invokeai/app/util/custom_openapi.py | 114 +++++++++++++++ scripts/generate_openapi_schema.py | 5 +- 7 files changed, 177 insertions(+), 226 deletions(-) create mode 100644 invokeai/app/util/custom_openapi.py diff --git a/Makefile b/Makefile index 7344b2e8d2..e858a89e2b 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ help: @echo "frontend-typegen Generate types for the frontend from the OpenAPI schema" @echo "installer-zip Build the installer .zip file for the current version" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" + @echo "openapi Generate the OpenAPI schema for the app, outputting to stdout" # Runs ruff, fixing any safely-fixable errors and formatting ruff: @@ -70,3 +71,6 @@ installer-zip: tag-release: cd installer && ./tag_release.sh +# Generate the OpenAPI Schema for the app +openapi: + python scripts/generate_openapi_schema.py diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index b7da548377..e69d95af71 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -3,9 +3,7 @@ import logging import mimetypes import socket from contextlib import asynccontextmanager -from inspect import signature from pathlib import Path -from typing import Any import torch import uvicorn @@ -13,11 +11,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware -from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available # for PyCharm: @@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles -from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config -from invokeai.app.services.events.events_common import EventBase -from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.util.custom_openapi import get_openapi_func from invokeai.backend.util.devices import TorchDevice from ..backend.util.logging import InvokeAILogger @@ -45,11 +39,6 @@ from .api.routers import ( workflows, ) from .api.sockets import SocketIO -from .invocations.baseinvocation import ( - BaseInvocation, - UIConfigBase, -) -from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra app_config = get_config() @@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") - -# Build a custom OpenAPI to include all outputs -# TODO: can outputs be included on metadata of invocation schemas somehow? -def custom_openapi() -> dict[str, Any]: - if app.openapi_schema: - return app.openapi_schema - openapi_schema = get_openapi( - title=app.title, - description="An API for invoking AI image operations", - version="1.0.0", - routes=app.routes, - separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ - ) - - # Add all outputs - all_invocations = BaseInvocation.get_invocations() - output_types = set() - output_type_titles = {} - for invoker in all_invocations: - output_type = signature(invoker.invoke).return_annotation - output_types.add(output_type) - - output_schemas = models_json_schema( - models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}" - ) - for schema_key, output_schema in output_schemas[1]["$defs"].items(): - # TODO: note that we assume the schema_key here is the TYPE.__name__ - # This could break in some cases, figure out a better way to do it - output_type_titles[schema_key] = output_schema["title"] - openapi_schema["components"]["schemas"][schema_key] = output_schema - openapi_schema["components"]["schemas"][schema_key]["class"] = "output" - - # Some models don't end up in the schemas as standalone definitions - additional_schemas = models_json_schema( - [ - (UIConfigBase, "serialization"), - (InputFieldJSONSchemaExtra, "serialization"), - (OutputFieldJSONSchemaExtra, "serialization"), - (ModelIdentifierField, "serialization"), - (ProgressImage, "serialization"), - ], - ref_template="#/components/schemas/{model}", - ) - for schema_key, schema_json in additional_schemas[1]["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema_json - - openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { - "type": "object", - "properties": {}, - "required": [], - } - - # Add a reference to the output type to additionalProperties of the invoker schema - for invoker in all_invocations: - invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute - output_type = signature(obj=invoker.invoke).return_annotation - output_type_title = output_type_titles[output_type.__name__] - invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] - outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} - invoker_schema["output"] = outputs_ref - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) - invoker_schema["class"] = "invocation" - - # Add all event schemas - for event in sorted(EventBase.get_events(), key=lambda e: e.__name__): - json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") - if "$defs" in json_schema: - for schema_key, schema in json_schema["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema - del json_schema["$defs"] - openapi_schema["components"]["schemas"][event.__name__] = json_schema - - app.openapi_schema = openapi_schema - return app.openapi_schema - - -app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment +app.openapi = get_openapi_func(app) @app.get("/docs", include_in_schema=False) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 40c7b41cae..9545179e21 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -113,10 +113,10 @@ class BaseInvocationOutput(BaseModel): def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation output types.""" if not cls._typeadapter: - InvocationOutputsUnion = TypeAliasType( - "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] + AnyInvocationOutput = TypeAliasType( + "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] ) - cls._typeadapter = TypeAdapter(InvocationOutputsUnion) + cls._typeadapter = TypeAdapter(AnyInvocationOutput) return cls._typeadapter @classmethod @@ -125,12 +125,13 @@ class BaseInvocationOutput(BaseModel): return (i.get_type() for i in BaseInvocationOutput.get_outputs()) @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None: """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] + schema["class"] = "output" schema["required"].extend(["type"]) @classmethod @@ -182,10 +183,10 @@ class BaseInvocation(ABC, BaseModel): def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation types.""" if not cls._typeadapter: - InvocationsUnion = TypeAliasType( - "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + AnyInvocation = TypeAliasType( + "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] ) - cls._typeadapter = TypeAdapter(InvocationsUnion) + cls._typeadapter = TypeAdapter(AnyInvocation) return cls._typeadapter @classmethod @@ -221,7 +222,7 @@ class BaseInvocation(ABC, BaseModel): return signature(cls.invoke).return_annotation @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) if uiconfig is not None: @@ -237,6 +238,7 @@ class BaseInvocation(ABC, BaseModel): schema["version"] = uiconfig.version if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] + schema["class"] = "invocation" schema["required"].extend(["type", "id"]) @abstractmethod @@ -310,7 +312,7 @@ class BaseInvocation(ABC, BaseModel): protected_namespaces=(), validate_assignment=True, json_schema_extra=json_schema_extra, - json_schema_serialization_defaults_required=True, + json_schema_serialization_defaults_required=False, coerce_numbers_to_str=True, ) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 3ae4468b83..0adcaa2ab1 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -3,9 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema -from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator +from pydantic import BaseModel, ConfigDict, Field -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( QUEUE_ITEM_STATUS, @@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItem, SessionQueueStatus, ) +from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase): item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") session_id: str = Field(description="The ID of the session (aka graph execution state)") - invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") + invocation: AnyInvocation = Field(description="The ID of the invocation") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") - @field_validator("invocation", mode="plain") - @classmethod - def validate_invocation(cls, v: Any): - """Validates the invocation using the dynamic type adapter.""" - - invocation = BaseInvocation.get_typeadapter().validate_python(v) - return invocation - @payload_schema.register class InvocationStartedEvent(InvocationEventBase): @@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase): __event_name__ = "invocation_started" @classmethod - def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): def build( cls, queue_item: SessionQueueItem, - invocation: BaseInvocation, + invocation: AnyInvocation, intermediate_state: PipelineIntermediateState, progress_image: ProgressImage, ) -> "InvocationDenoiseProgressEvent": @@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase): __event_name__ = "invocation_complete" - result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") - - @field_validator("result", mode="plain") - @classmethod - def validate_results(cls, v: Any): - """Validates the invocation result using the dynamic type adapter.""" - - result = BaseInvocationOutput.get_typeadapter().validate_python(v) - return result + result: AnyInvocationOutput = Field(description="The result of the invocation") @classmethod def build( - cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput ) -> "InvocationCompleteEvent": return cls( queue_id=queue_item.queue_id, @@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase): def build( cls, queue_item: SessionQueueItem, - invocation: BaseInvocation, + invocation: AnyInvocation, error_type: str, error_message: str, error_traceback: str, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 8508d2484c..7f5b277ad8 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,11 +2,12 @@ import copy import itertools -from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import ( BaseModel, + GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationError, field_validator, @@ -277,73 +278,46 @@ class CollectInvocation(BaseInvocation): return CollectInvocationOutput(collection=copy.copy(self.collection)) +class AnyInvocation(BaseInvocation): + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): + return BaseInvocation.get_typeadapter().core_schema + + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # Nodes are too powerful, we have to make our own OpenAPI schema manually + # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually + oneOf: list[dict[str, str]] = [] + for i in BaseInvocation.get_invocations(): + oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"}) + return {"oneOf": oneOf} + + +class AnyInvocationOutput(BaseInvocationOutput): + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): + return BaseInvocationOutput.get_typeadapter().core_schema + + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # Nodes are too powerful, we have to make our own OpenAPI schema manually + # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually + + oneOf: list[dict[str, str]] = [] + for i in BaseInvocationOutput.get_outputs(): + oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"}) + return {"oneOf": oneOf} + + class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) + nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict) edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) - @field_validator("nodes", mode="plain") - @classmethod - def validate_nodes(cls, v: dict[str, Any]): - """Validates the nodes in the graph by retrieving a union of all node types and validating each node.""" - - # Invocations register themselves as their python modules are executed. The union of all invocations is - # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union. - # - # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If - # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing - # invocations will cause a graph to fail if they are used. - # - # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the - # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime. - # - # This same pattern is used in `GraphExecutionState`. - - nodes: dict[str, BaseInvocation] = {} - typeadapter = BaseInvocation.get_typeadapter() - for node_id, node in v.items(): - nodes[node_id] = typeadapter.validate_python(node) - return nodes - - @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for - # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to - # the generated schema as options for the `nodes` field. - # - # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and - # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as - # expected. - # - # You might be tempted to do something like this: - # - # ```py - # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...) - # delattr(cloned_model, "validate_nodes") - # cloned_model.model_rebuild(force=True) - # json_schema = handler(cloned_model.__pydantic_core_schema__) - # ``` - # - # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts - # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model. - # - # This same pattern is used in `GraphExecutionState`. - - class Graph(BaseModel): - id: Optional[str] = Field(default=None, description="The id of this graph") - nodes: dict[ - str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")] - ] = Field(description="The nodes in this graph") - edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph") - - json_schema = handler(Graph.__pydantic_core_schema__) - json_schema = handler.resolve_ref_schema(json_schema) - return json_schema - def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -774,7 +748,7 @@ class GraphExecutionState(BaseModel): ) # The results of executed nodes - results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) + results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict) # Errors raised when executing nodes errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) @@ -791,52 +765,12 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) - @field_validator("results", mode="plain") - @classmethod - def validate_results(cls, v: dict[str, BaseInvocationOutput]): - """Validates the results in the GES by retrieving a union of all output types and validating each result.""" - - # See the comment in `Graph.validate_nodes` for an explanation of this logic. - results: dict[str, BaseInvocationOutput] = {} - typeadapter = BaseInvocationOutput.get_typeadapter() - for result_id, result in v.items(): - results[result_id] = typeadapter.validate_python(result) - return results - @field_validator("graph") def graph_is_valid(cls, v: Graph): """Validates that the graph is valid""" v.validate_self() return v - @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic. - class GraphExecutionState(BaseModel): - """Tracks the state of a graph execution""" - - id: str = Field(description="The id of the execution state") - graph: Graph = Field(description="The graph being executed") - execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes") - executed: set[str] = Field(description="The set of node ids that have been executed") - executed_history: list[str] = Field( - description="The list of node ids that have been executed, in order of execution" - ) - results: dict[ - str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")] - ] = Field(description="The results of node executions") - errors: dict[str, str] = Field(description="Errors raised when executing nodes") - prepared_source_mapping: dict[str, str] = Field( - description="The map of prepared nodes to original graph nodes" - ) - source_prepared_mapping: dict[str, set[str]] = Field( - description="The map of original graph nodes to prepared nodes" - ) - - json_schema = handler(GraphExecutionState.__pydantic_core_schema__) - json_schema = handler.resolve_ref_schema(json_schema) - return json_schema - def next(self) -> Optional[BaseInvocation]: """Gets the next node ready to execute.""" diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py new file mode 100644 index 0000000000..9313f63b84 --- /dev/null +++ b/invokeai/app/util/custom_openapi.py @@ -0,0 +1,114 @@ +from typing import Any, Callable, Optional + +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from pydantic.json_schema import models_json_schema + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase +from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.services.events.events_common import EventBase +from invokeai.app.services.session_processor.session_processor_common import ProgressImage + + +def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None: + """Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema + for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and + component_schema.""" + + defs = component_schema.pop("$defs", {}) + for schema_key, json_schema in defs.items(): + if schema_key in openapi_schema["components"]["schemas"]: + continue + openapi_schema["components"]["schemas"][schema_key] = json_schema + + +def get_openapi_func( + app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None +) -> Callable[[], dict[str, Any]]: + """Gets the OpenAPI schema generator function. + + Args: + app (FastAPI): The FastAPI app to generate the schema for. + post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the + generated schema before returning it. Defaults to None. + + Returns: + Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is + cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour + matches FastAPI's default schema generation caching. + """ + + def openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + + openapi_schema = get_openapi( + title=app.title, + description="An API for invoking AI image operations", + version="1.0.0", + routes=app.routes, + separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ + ) + + # We'll create a map of invocation type to output schema to make some types simpler on the client. + invocation_output_map_properties: dict[str, Any] = {} + invocation_output_map_required: list[str] = [] + + # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly. + for output in BaseInvocationOutput.get_outputs(): + json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") + move_defs_to_top_level(openapi_schema, json_schema) + openapi_schema["components"]["schemas"][output.__name__] = json_schema + + # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output + # property, so we'll just do it all manually. + for invocation in BaseInvocation.get_invocations(): + json_schema = invocation.model_json_schema( + mode="serialization", ref_template="#/components/schemas/{model}" + ) + move_defs_to_top_level(openapi_schema, json_schema) + output_title = invocation.get_output_annotation().__name__ + outputs_ref = {"$ref": f"#/components/schemas/{output_title}"} + json_schema["output"] = outputs_ref + openapi_schema["components"]["schemas"][invocation.__name__] = json_schema + + # Add this invocation and its output to the output map + invocation_type = invocation.get_type() + invocation_output_map_properties[invocation_type] = json_schema["output"] + invocation_output_map_required.append(invocation_type) + + # Add the output map to the schema + openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { + "type": "object", + "properties": invocation_output_map_properties, + "required": invocation_output_map_required, + } + + # Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API. + # We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get + # a schema. This has something to do with schema refs - not totally clear. For whatever reason, using + # `models_json_schema` seems to work fine. + additional_models = [ + *EventBase.get_events(), + UIConfigBase, + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, + ModelIdentifierField, + ProgressImage, + ] + + additional_schemas = models_json_schema( + [(m, "serialization") for m in additional_models], + ref_template="#/components/schemas/{model}", + ) + # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema + move_defs_to_top_level(openapi_schema, additional_schemas[1]) + + if post_transform is not None: + openapi_schema = post_transform(openapi_schema) + + app.openapi_schema = openapi_schema + return app.openapi_schema + + return openapi diff --git a/scripts/generate_openapi_schema.py b/scripts/generate_openapi_schema.py index dd1f5b85dd..70baa194dc 100644 --- a/scripts/generate_openapi_schema.py +++ b/scripts/generate_openapi_schema.py @@ -7,9 +7,10 @@ def main(): # Change working directory to the repo root os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - from invokeai.app.api_app import custom_openapi + from invokeai.app.api_app import app + from invokeai.app.util.custom_openapi import get_openapi_func - schema = custom_openapi() + schema = get_openapi_func(app)() json.dump(schema, sys.stdout, indent=2)