fix: openapi stuff (#6454)

## Summary

Fix some issues with openapi schema generation. See commits for details.

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1049495067846524939/1245141831394529352

## QA Instructions

App should work, workflows should work.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:
blessedcoolant 2024-05-30 08:22:34 +05:30 committed by GitHub
commit cfb12615e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2491 additions and 1668 deletions

View File

@ -18,6 +18,7 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema" @echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version" @echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting # Runs ruff, fixing any safely-fixable errors and formatting
ruff: ruff:
@ -70,3 +71,6 @@ installer-zip:
tag-release: tag-release:
cd installer && ./tag_release.sh cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@ -3,9 +3,7 @@ import logging
import mimetypes import mimetypes
import socket import socket
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any
import torch import torch
import uvicorn import uvicorn
@ -13,11 +11,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available from torch.backends.mps import is_available as is_mps_available
# for PyCharm: # for PyCharm:
@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
@ -45,11 +39,6 @@ from .api.routers import (
workflows, workflows,
) )
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config() app_config = get_config()
@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api")
app.openapi = get_openapi_func(app)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False) @app.get("/docs", include_in_schema=False)

View File

@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set() _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod @classmethod
def register_output(cls, output: BaseInvocationOutput) -> None: def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output.""" """Registers an invocation output."""
cls._output_classes.add(output) cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod @classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]: def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
@classmethod @classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types.""" """Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter: if not cls._typeadapter or cls._typeadapter_needs_update:
InvocationOutputsUnion = TypeAliasType( AnyInvocationOutput = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationOutputsUnion) cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter return cls._typeadapter
@classmethod @classmethod
@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs()) return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type, # Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually. # it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"]) schema["required"].extend(["type"])
@classmethod @classmethod
@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set() _invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod @classmethod
def get_type(cls) -> str: def get_type(cls) -> str:
@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None: def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation.""" """Registers an invocation."""
cls._invocation_classes.add(invocation) cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod @classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types.""" """Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter: if not cls._typeadapter or cls._typeadapter_needs_update:
InvocationsUnion = TypeAliasType( AnyInvocation = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationsUnion) cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter return cls._typeadapter
@classmethod @classmethod
@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation return signature(cls.invoke).return_annotation
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None: if uiconfig is not None:
@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"]) schema["required"].extend(["type", "id"])
@abstractmethod @abstractmethod
@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(), protected_namespaces=(),
validate_assignment=True, validate_assignment=True,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True, json_schema_serialization_defaults_required=False,
coerce_numbers_to_str=True, coerce_numbers_to_str=True,
) )

View File

@ -3,9 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS, QUEUE_ITEM_STATUS,
@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItem, SessionQueueItem,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase):
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)") session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""
invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation
@payload_schema.register @payload_schema.register
class InvocationStartedEvent(InvocationEventBase): class InvocationStartedEvent(InvocationEventBase):
@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
__event_name__ = "invocation_started" __event_name__ = "invocation_started"
@classmethod @classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls( return cls(
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
def build( def build(
cls, cls,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
invocation: BaseInvocation, invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState, intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage, progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent": ) -> "InvocationDenoiseProgressEvent":
@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
__event_name__ = "invocation_complete" __event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") result: AnyInvocationOutput = Field(description="The result of the invocation")
@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result
@classmethod @classmethod
def build( def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent": ) -> "InvocationCompleteEvent":
return cls( return cls(
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
def build( def build(
cls, cls,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
invocation: BaseInvocation, invocation: AnyInvocation,
error_type: str, error_type: str,
error_message: str, error_message: str,
error_traceback: str, error_traceback: str,

View File

@ -2,18 +2,19 @@
import copy import copy
import itertools import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx import networkx as nx
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler, GetJsonSchemaHandler,
ValidationError, ValidationError,
field_validator, field_validator,
) )
from pydantic.fields import Field from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema from pydantic_core import core_schema
# Importing * is bad karma but needed here for node detection # Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations import * # noqa: F401 F403
@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection)) return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class Graph(BaseModel): class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string) id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field( edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph", description="The connections between nodes and their fields in this graph",
default_factory=list, default_factory=list,
) )
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None: def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph """Adds a node to a graph
@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
) )
# The results of executed nodes # The results of executed nodes
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes # Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
default_factory=dict, default_factory=dict,
) )
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph") @field_validator("graph")
def graph_is_valid(cls, v: Graph): def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid""" """Validates that the graph is valid"""
v.validate_self() v.validate_self()
return v return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]: def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""

View File

@ -0,0 +1,116 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi

View File

@ -13,7 +13,6 @@ import {
isControlAdapterLayer, isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -139,7 +138,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// We still have to check the output type // We still have to check the output type
assert( assert(
isImageOutput(invocationCompleteAction.payload.data.result), invocationCompleteAction.payload.data.result.type === 'image_output',
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}` `Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
); );
const { image_name } = invocationCompleteAction.payload.data.result.image; const { image_name } = invocationCompleteAction.payload.data.result.image;

View File

@ -9,7 +9,6 @@ import {
selectControlAdapterById, selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
); );
// We still have to check the output type // We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) { if (invocationCompleteAction.payload.data.result.type === 'image_output') {
const { image_name } = invocationCompleteAction.payload.data.result.image; const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received // Wait for the ImageDTO to be received

View File

@ -11,7 +11,6 @@ import {
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const { result, invocation_source_id } = data; const { result, invocation_source_id } = data;
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) { if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image; const { image_name } = data.result.image;
const { canvas, gallery } = getState(); const { canvas, gallery } = getState();

View File

@ -11,8 +11,7 @@ import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { ImageOutput } from 'services/api/types'; import type { AnyInvocationOutput, ImageOutput } from 'services/api/types';
import type { AnyResult } from 'services/events/types';
import ImageOutputPreview from './outputs/ImageOutputPreview'; import ImageOutputPreview from './outputs/ImageOutputPreview';
@ -66,4 +65,4 @@ const InspectorOutputsTab = () => {
export default memo(InspectorOutputsTab); export default memo(InspectorOutputsTab);
const getKey = (result: AnyResult, i: number) => `${result.type}-${i}`; const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`;

View File

@ -144,5 +144,4 @@ const zImageOutput = z.object({
type: z.literal('image_output'), type: z.literal('image_output'),
}); });
export type ImageOutput = z.infer<typeof zImageOutput>; export type ImageOutput = z.infer<typeof zImageOutput>;
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
// #endregion // #endregion

View File

@ -1,8 +1,7 @@
import type { NodesState } from 'features/nodes/store/types'; import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es'; import { omit, reduce } from 'lodash-es';
import type { Graph } from 'services/api/types'; import type { AnyInvocation, Graph } from 'services/api/types';
import type { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
/** /**

File diff suppressed because one or more lines are too long

View File

@ -122,7 +122,6 @@ export type ModelInstallStatus = S['InstallStatus'];
// Graphs // Graphs
export type Graph = S['Graph']; export type Graph = S['Graph'];
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>; export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
export type GraphExecutionState = S['GraphExecutionState'];
export type Batch = S['Batch']; export type Batch = S['Batch'];
export type SessionQueueItemDTO = S['SessionQueueItemDTO']; export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy']; export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
@ -132,14 +131,14 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
type KeysOfUnion<T> = T extends T ? keyof T : never; type KeysOfUnion<T> = T extends T ? keyof T : never;
export type AnyInvocation = Exclude< export type AnyInvocation = Exclude<
Graph['nodes'][string], NonNullable<S['Graph']['nodes']>[string],
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation'] S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
>; >;
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string]; export type AnyInvocationIncMetadata = NonNullable<S['Graph']['nodes']>[string];
export type InvocationType = AnyInvocation['type']; export type InvocationType = AnyInvocation['type'];
type InvocationOutputMap = S['InvocationOutputMap']; type InvocationOutputMap = S['InvocationOutputMap'];
type AnyInvocationOutput = InvocationOutputMap[InvocationType]; export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>; export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
// export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T]; // export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];

View File

@ -1,21 +1,12 @@
import type { Graph, GraphExecutionState, S } from 'services/api/types'; import type { S } from 'services/api/types';
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent']; export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent']; export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation }; export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & { export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
invocation: AnyInvocation; export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
}; export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
result: AnyResult;
invocation: AnyInvocation;
};
export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image']; export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];

View File

@ -55,10 +55,10 @@ dependencies = [
# Core application dependencies, pinned for reproducible builds. # Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.0", "fastapi-events==0.11.0",
"fastapi==0.110.0", "fastapi==0.111.0",
"huggingface-hub==0.23.1", "huggingface-hub==0.23.1",
"pydantic-settings==2.2.1", "pydantic-settings==2.2.1",
"pydantic==2.6.3", "pydantic==2.7.2",
"python-socketio==5.11.1", "python-socketio==5.11.1",
"uvicorn[standard]==0.28.0", "uvicorn[standard]==0.28.0",

View File

@ -7,9 +7,10 @@ def main():
# Change working directory to the repo root # Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from invokeai.app.api_app import custom_openapi from invokeai.app.api_app import app
from invokeai.app.util.custom_openapi import get_openapi_func
schema = custom_openapi() schema = get_openapi_func(app)()
json.dump(schema, sys.stdout, indent=2) json.dump(schema, sys.stdout, indent=2)

View File

@ -1,5 +1,6 @@
import pytest import pytest
from pydantic import TypeAdapter from pydantic import TypeAdapter
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
def test_graph_can_generate_schema(): def test_graph_can_generate_schema():
# Not throwing on this line is sufficient # Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
_ = Graph.model_json_schema() models_json_schema([(Graph, "serialization")])