diff --git a/Makefile b/Makefile index 7344b2e8d2..e858a89e2b 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ help: @echo "frontend-typegen Generate types for the frontend from the OpenAPI schema" @echo "installer-zip Build the installer .zip file for the current version" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" + @echo "openapi Generate the OpenAPI schema for the app, outputting to stdout" # Runs ruff, fixing any safely-fixable errors and formatting ruff: @@ -70,3 +71,6 @@ installer-zip: tag-release: cd installer && ./tag_release.sh +# Generate the OpenAPI Schema for the app +openapi: + python scripts/generate_openapi_schema.py diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index b7da548377..e69d95af71 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -3,9 +3,7 @@ import logging import mimetypes import socket from contextlib import asynccontextmanager -from inspect import signature from pathlib import Path -from typing import Any import torch import uvicorn @@ -13,11 +11,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware -from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available # for PyCharm: @@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles -from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config -from invokeai.app.services.events.events_common import EventBase -from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.util.custom_openapi import get_openapi_func from invokeai.backend.util.devices import TorchDevice from ..backend.util.logging import InvokeAILogger @@ -45,11 +39,6 @@ from .api.routers import ( workflows, ) from .api.sockets import SocketIO -from .invocations.baseinvocation import ( - BaseInvocation, - UIConfigBase, -) -from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra app_config = get_config() @@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") - -# Build a custom OpenAPI to include all outputs -# TODO: can outputs be included on metadata of invocation schemas somehow? -def custom_openapi() -> dict[str, Any]: - if app.openapi_schema: - return app.openapi_schema - openapi_schema = get_openapi( - title=app.title, - description="An API for invoking AI image operations", - version="1.0.0", - routes=app.routes, - separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ - ) - - # Add all outputs - all_invocations = BaseInvocation.get_invocations() - output_types = set() - output_type_titles = {} - for invoker in all_invocations: - output_type = signature(invoker.invoke).return_annotation - output_types.add(output_type) - - output_schemas = models_json_schema( - models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}" - ) - for schema_key, output_schema in output_schemas[1]["$defs"].items(): - # TODO: note that we assume the schema_key here is the TYPE.__name__ - # This could break in some cases, figure out a better way to do it - output_type_titles[schema_key] = output_schema["title"] - openapi_schema["components"]["schemas"][schema_key] = output_schema - openapi_schema["components"]["schemas"][schema_key]["class"] = "output" - - # Some models don't end up in the schemas as standalone definitions - additional_schemas = models_json_schema( - [ - (UIConfigBase, "serialization"), - (InputFieldJSONSchemaExtra, "serialization"), - (OutputFieldJSONSchemaExtra, "serialization"), - (ModelIdentifierField, "serialization"), - (ProgressImage, "serialization"), - ], - ref_template="#/components/schemas/{model}", - ) - for schema_key, schema_json in additional_schemas[1]["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema_json - - openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { - "type": "object", - "properties": {}, - "required": [], - } - - # Add a reference to the output type to additionalProperties of the invoker schema - for invoker in all_invocations: - invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute - output_type = signature(obj=invoker.invoke).return_annotation - output_type_title = output_type_titles[output_type.__name__] - invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] - outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} - invoker_schema["output"] = outputs_ref - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) - invoker_schema["class"] = "invocation" - - # Add all event schemas - for event in sorted(EventBase.get_events(), key=lambda e: e.__name__): - json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") - if "$defs" in json_schema: - for schema_key, schema in json_schema["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema - del json_schema["$defs"] - openapi_schema["components"]["schemas"][event.__name__] = json_schema - - app.openapi_schema = openapi_schema - return app.openapi_schema - - -app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment +app.openapi = get_openapi_func(app) @app.get("/docs", include_in_schema=False) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 40c7b41cae..9545179e21 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -113,10 +113,10 @@ class BaseInvocationOutput(BaseModel): def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation output types.""" if not cls._typeadapter: - InvocationOutputsUnion = TypeAliasType( - "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] + AnyInvocationOutput = TypeAliasType( + "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] ) - cls._typeadapter = TypeAdapter(InvocationOutputsUnion) + cls._typeadapter = TypeAdapter(AnyInvocationOutput) return cls._typeadapter @classmethod @@ -125,12 +125,13 @@ class BaseInvocationOutput(BaseModel): return (i.get_type() for i in BaseInvocationOutput.get_outputs()) @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None: """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] + schema["class"] = "output" schema["required"].extend(["type"]) @classmethod @@ -182,10 +183,10 @@ class BaseInvocation(ABC, BaseModel): def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation types.""" if not cls._typeadapter: - InvocationsUnion = TypeAliasType( - "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + AnyInvocation = TypeAliasType( + "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] ) - cls._typeadapter = TypeAdapter(InvocationsUnion) + cls._typeadapter = TypeAdapter(AnyInvocation) return cls._typeadapter @classmethod @@ -221,7 +222,7 @@ class BaseInvocation(ABC, BaseModel): return signature(cls.invoke).return_annotation @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) if uiconfig is not None: @@ -237,6 +238,7 @@ class BaseInvocation(ABC, BaseModel): schema["version"] = uiconfig.version if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] + schema["class"] = "invocation" schema["required"].extend(["type", "id"]) @abstractmethod @@ -310,7 +312,7 @@ class BaseInvocation(ABC, BaseModel): protected_namespaces=(), validate_assignment=True, json_schema_extra=json_schema_extra, - json_schema_serialization_defaults_required=True, + json_schema_serialization_defaults_required=False, coerce_numbers_to_str=True, ) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 3ae4468b83..0adcaa2ab1 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -3,9 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema -from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator +from pydantic import BaseModel, ConfigDict, Field -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( QUEUE_ITEM_STATUS, @@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItem, SessionQueueStatus, ) +from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase): item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") session_id: str = Field(description="The ID of the session (aka graph execution state)") - invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") + invocation: AnyInvocation = Field(description="The ID of the invocation") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") - @field_validator("invocation", mode="plain") - @classmethod - def validate_invocation(cls, v: Any): - """Validates the invocation using the dynamic type adapter.""" - - invocation = BaseInvocation.get_typeadapter().validate_python(v) - return invocation - @payload_schema.register class InvocationStartedEvent(InvocationEventBase): @@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase): __event_name__ = "invocation_started" @classmethod - def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): def build( cls, queue_item: SessionQueueItem, - invocation: BaseInvocation, + invocation: AnyInvocation, intermediate_state: PipelineIntermediateState, progress_image: ProgressImage, ) -> "InvocationDenoiseProgressEvent": @@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase): __event_name__ = "invocation_complete" - result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") - - @field_validator("result", mode="plain") - @classmethod - def validate_results(cls, v: Any): - """Validates the invocation result using the dynamic type adapter.""" - - result = BaseInvocationOutput.get_typeadapter().validate_python(v) - return result + result: AnyInvocationOutput = Field(description="The result of the invocation") @classmethod def build( - cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput ) -> "InvocationCompleteEvent": return cls( queue_id=queue_item.queue_id, @@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase): def build( cls, queue_item: SessionQueueItem, - invocation: BaseInvocation, + invocation: AnyInvocation, error_type: str, error_message: str, error_traceback: str, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 8508d2484c..7f5b277ad8 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,11 +2,12 @@ import copy import itertools -from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import ( BaseModel, + GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationError, field_validator, @@ -277,73 +278,46 @@ class CollectInvocation(BaseInvocation): return CollectInvocationOutput(collection=copy.copy(self.collection)) +class AnyInvocation(BaseInvocation): + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): + return BaseInvocation.get_typeadapter().core_schema + + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # Nodes are too powerful, we have to make our own OpenAPI schema manually + # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually + oneOf: list[dict[str, str]] = [] + for i in BaseInvocation.get_invocations(): + oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"}) + return {"oneOf": oneOf} + + +class AnyInvocationOutput(BaseInvocationOutput): + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): + return BaseInvocationOutput.get_typeadapter().core_schema + + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # Nodes are too powerful, we have to make our own OpenAPI schema manually + # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually + + oneOf: list[dict[str, str]] = [] + for i in BaseInvocationOutput.get_outputs(): + oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"}) + return {"oneOf": oneOf} + + class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) + nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict) edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) - @field_validator("nodes", mode="plain") - @classmethod - def validate_nodes(cls, v: dict[str, Any]): - """Validates the nodes in the graph by retrieving a union of all node types and validating each node.""" - - # Invocations register themselves as their python modules are executed. The union of all invocations is - # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union. - # - # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If - # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing - # invocations will cause a graph to fail if they are used. - # - # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the - # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime. - # - # This same pattern is used in `GraphExecutionState`. - - nodes: dict[str, BaseInvocation] = {} - typeadapter = BaseInvocation.get_typeadapter() - for node_id, node in v.items(): - nodes[node_id] = typeadapter.validate_python(node) - return nodes - - @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for - # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to - # the generated schema as options for the `nodes` field. - # - # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and - # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as - # expected. - # - # You might be tempted to do something like this: - # - # ```py - # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...) - # delattr(cloned_model, "validate_nodes") - # cloned_model.model_rebuild(force=True) - # json_schema = handler(cloned_model.__pydantic_core_schema__) - # ``` - # - # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts - # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model. - # - # This same pattern is used in `GraphExecutionState`. - - class Graph(BaseModel): - id: Optional[str] = Field(default=None, description="The id of this graph") - nodes: dict[ - str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")] - ] = Field(description="The nodes in this graph") - edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph") - - json_schema = handler(Graph.__pydantic_core_schema__) - json_schema = handler.resolve_ref_schema(json_schema) - return json_schema - def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -774,7 +748,7 @@ class GraphExecutionState(BaseModel): ) # The results of executed nodes - results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) + results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict) # Errors raised when executing nodes errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) @@ -791,52 +765,12 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) - @field_validator("results", mode="plain") - @classmethod - def validate_results(cls, v: dict[str, BaseInvocationOutput]): - """Validates the results in the GES by retrieving a union of all output types and validating each result.""" - - # See the comment in `Graph.validate_nodes` for an explanation of this logic. - results: dict[str, BaseInvocationOutput] = {} - typeadapter = BaseInvocationOutput.get_typeadapter() - for result_id, result in v.items(): - results[result_id] = typeadapter.validate_python(result) - return results - @field_validator("graph") def graph_is_valid(cls, v: Graph): """Validates that the graph is valid""" v.validate_self() return v - @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic. - class GraphExecutionState(BaseModel): - """Tracks the state of a graph execution""" - - id: str = Field(description="The id of the execution state") - graph: Graph = Field(description="The graph being executed") - execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes") - executed: set[str] = Field(description="The set of node ids that have been executed") - executed_history: list[str] = Field( - description="The list of node ids that have been executed, in order of execution" - ) - results: dict[ - str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")] - ] = Field(description="The results of node executions") - errors: dict[str, str] = Field(description="Errors raised when executing nodes") - prepared_source_mapping: dict[str, str] = Field( - description="The map of prepared nodes to original graph nodes" - ) - source_prepared_mapping: dict[str, set[str]] = Field( - description="The map of original graph nodes to prepared nodes" - ) - - json_schema = handler(GraphExecutionState.__pydantic_core_schema__) - json_schema = handler.resolve_ref_schema(json_schema) - return json_schema - def next(self) -> Optional[BaseInvocation]: """Gets the next node ready to execute.""" diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py new file mode 100644 index 0000000000..9313f63b84 --- /dev/null +++ b/invokeai/app/util/custom_openapi.py @@ -0,0 +1,114 @@ +from typing import Any, Callable, Optional + +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from pydantic.json_schema import models_json_schema + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase +from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.services.events.events_common import EventBase +from invokeai.app.services.session_processor.session_processor_common import ProgressImage + + +def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None: + """Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema + for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and + component_schema.""" + + defs = component_schema.pop("$defs", {}) + for schema_key, json_schema in defs.items(): + if schema_key in openapi_schema["components"]["schemas"]: + continue + openapi_schema["components"]["schemas"][schema_key] = json_schema + + +def get_openapi_func( + app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None +) -> Callable[[], dict[str, Any]]: + """Gets the OpenAPI schema generator function. + + Args: + app (FastAPI): The FastAPI app to generate the schema for. + post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the + generated schema before returning it. Defaults to None. + + Returns: + Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is + cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour + matches FastAPI's default schema generation caching. + """ + + def openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + + openapi_schema = get_openapi( + title=app.title, + description="An API for invoking AI image operations", + version="1.0.0", + routes=app.routes, + separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ + ) + + # We'll create a map of invocation type to output schema to make some types simpler on the client. + invocation_output_map_properties: dict[str, Any] = {} + invocation_output_map_required: list[str] = [] + + # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly. + for output in BaseInvocationOutput.get_outputs(): + json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") + move_defs_to_top_level(openapi_schema, json_schema) + openapi_schema["components"]["schemas"][output.__name__] = json_schema + + # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output + # property, so we'll just do it all manually. + for invocation in BaseInvocation.get_invocations(): + json_schema = invocation.model_json_schema( + mode="serialization", ref_template="#/components/schemas/{model}" + ) + move_defs_to_top_level(openapi_schema, json_schema) + output_title = invocation.get_output_annotation().__name__ + outputs_ref = {"$ref": f"#/components/schemas/{output_title}"} + json_schema["output"] = outputs_ref + openapi_schema["components"]["schemas"][invocation.__name__] = json_schema + + # Add this invocation and its output to the output map + invocation_type = invocation.get_type() + invocation_output_map_properties[invocation_type] = json_schema["output"] + invocation_output_map_required.append(invocation_type) + + # Add the output map to the schema + openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { + "type": "object", + "properties": invocation_output_map_properties, + "required": invocation_output_map_required, + } + + # Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API. + # We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get + # a schema. This has something to do with schema refs - not totally clear. For whatever reason, using + # `models_json_schema` seems to work fine. + additional_models = [ + *EventBase.get_events(), + UIConfigBase, + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, + ModelIdentifierField, + ProgressImage, + ] + + additional_schemas = models_json_schema( + [(m, "serialization") for m in additional_models], + ref_template="#/components/schemas/{model}", + ) + # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema + move_defs_to_top_level(openapi_schema, additional_schemas[1]) + + if post_transform is not None: + openapi_schema = post_transform(openapi_schema) + + app.openapi_schema = openapi_schema + return app.openapi_schema + + return openapi diff --git a/scripts/generate_openapi_schema.py b/scripts/generate_openapi_schema.py index dd1f5b85dd..70baa194dc 100644 --- a/scripts/generate_openapi_schema.py +++ b/scripts/generate_openapi_schema.py @@ -7,9 +7,10 @@ def main(): # Change working directory to the repo root os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - from invokeai.app.api_app import custom_openapi + from invokeai.app.api_app import app + from invokeai.app.util.custom_openapi import get_openapi_func - schema = custom_openapi() + schema = get_openapi_func(app)() json.dump(schema, sys.stdout, indent=2)