fix(app): openapi schema generation

Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why - I think it's a pydantic bug in a remote edge case that we are hitting.

After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before.

- Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. Actually, they are kinda the same, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`.
- Revise the custom openapi function to work with the new models.
- Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema.
- Update makefile scripts.
This commit is contained in:
psychedelicious 2024-05-29 17:29:51 +10:00
parent e257a72f94
commit 2f9ebdec69
7 changed files with 177 additions and 226 deletions

View File

@ -18,6 +18,7 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema" @echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version" @echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting # Runs ruff, fixing any safely-fixable errors and formatting
ruff: ruff:
@ -70,3 +71,6 @@ installer-zip:
tag-release: tag-release:
cd installer && ./tag_release.sh cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@ -3,9 +3,7 @@ import logging
import mimetypes import mimetypes
import socket import socket
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any
import torch import torch
import uvicorn import uvicorn
@ -13,11 +11,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available from torch.backends.mps import is_available as is_mps_available
# for PyCharm: # for PyCharm:
@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
@ -45,11 +39,6 @@ from .api.routers import (
workflows, workflows,
) )
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config() app_config = get_config()
@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api")
app.openapi = get_openapi_func(app)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False) @app.get("/docs", include_in_schema=False)

View File

@ -113,10 +113,10 @@ class BaseInvocationOutput(BaseModel):
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types.""" """Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter: if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType( AnyInvocationOutput = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationOutputsUnion) cls._typeadapter = TypeAdapter(AnyInvocationOutput)
return cls._typeadapter return cls._typeadapter
@classmethod @classmethod
@ -125,12 +125,13 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs()) return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type, # Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually. # it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"]) schema["required"].extend(["type"])
@classmethod @classmethod
@ -182,10 +183,10 @@ class BaseInvocation(ABC, BaseModel):
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types.""" """Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter: if not cls._typeadapter:
InvocationsUnion = TypeAliasType( AnyInvocation = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationsUnion) cls._typeadapter = TypeAdapter(AnyInvocation)
return cls._typeadapter return cls._typeadapter
@classmethod @classmethod
@ -221,7 +222,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation return signature(cls.invoke).return_annotation
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None: if uiconfig is not None:
@ -237,6 +238,7 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"]) schema["required"].extend(["type", "id"])
@abstractmethod @abstractmethod
@ -310,7 +312,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(), protected_namespaces=(),
validate_assignment=True, validate_assignment=True,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True, json_schema_serialization_defaults_required=False,
coerce_numbers_to_str=True, coerce_numbers_to_str=True,
) )

View File

@ -3,9 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS, QUEUE_ITEM_STATUS,
@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItem, SessionQueueItem,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase):
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)") session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""
invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation
@payload_schema.register @payload_schema.register
class InvocationStartedEvent(InvocationEventBase): class InvocationStartedEvent(InvocationEventBase):
@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
__event_name__ = "invocation_started" __event_name__ = "invocation_started"
@classmethod @classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls( return cls(
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
def build( def build(
cls, cls,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
invocation: BaseInvocation, invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState, intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage, progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent": ) -> "InvocationDenoiseProgressEvent":
@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
__event_name__ = "invocation_complete" __event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") result: AnyInvocationOutput = Field(description="The result of the invocation")
@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result
@classmethod @classmethod
def build( def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent": ) -> "InvocationCompleteEvent":
return cls( return cls(
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
def build( def build(
cls, cls,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
invocation: BaseInvocation, invocation: AnyInvocation,
error_type: str, error_type: str,
error_message: str, error_message: str,
error_traceback: str, error_traceback: str,

View File

@ -2,11 +2,12 @@
import copy import copy
import itertools import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx import networkx as nx
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler, GetJsonSchemaHandler,
ValidationError, ValidationError,
field_validator, field_validator,
@ -277,73 +278,46 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection)) return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
return BaseInvocation.get_typeadapter().core_schema
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
for i in BaseInvocation.get_invocations():
oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
return BaseInvocationOutput.get_typeadapter().core_schema
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
for i in BaseInvocationOutput.get_outputs():
oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"})
return {"oneOf": oneOf}
class Graph(BaseModel): class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string) id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field( edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph", description="The connections between nodes and their fields in this graph",
default_factory=list, default_factory=list,
) )
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None: def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph """Adds a node to a graph
@ -774,7 +748,7 @@ class GraphExecutionState(BaseModel):
) )
# The results of executed nodes # The results of executed nodes
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes # Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@ -791,52 +765,12 @@ class GraphExecutionState(BaseModel):
default_factory=dict, default_factory=dict,
) )
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph") @field_validator("graph")
def graph_is_valid(cls, v: Graph): def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid""" """Validates that the graph is valid"""
v.validate_self() v.validate_self()
return v return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]: def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""

View File

@ -0,0 +1,114 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi

View File

@ -7,9 +7,10 @@ def main():
# Change working directory to the repo root # Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from invokeai.app.api_app import custom_openapi from invokeai.app.api_app import app
from invokeai.app.util.custom_openapi import get_openapi_func
schema = custom_openapi() schema = get_openapi_func(app)()
json.dump(schema, sys.stdout, indent=2) json.dump(schema, sys.stdout, indent=2)