mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(app): openapi schema generation
Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why - I think it's a pydantic bug in a remote edge case that we are hitting. After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before. - Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. Actually, they are kinda the same, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`. - Revise the custom openapi function to work with the new models. - Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema. - Update makefile scripts.
This commit is contained in:
@ -2,11 +2,12 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
@ -277,73 +278,46 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
return BaseInvocation.get_typeadapter().core_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
for i in BaseInvocation.get_invocations():
|
||||
oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
return BaseInvocationOutput.get_typeadapter().core_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
for i in BaseInvocationOutput.get_outputs():
|
||||
oneOf.append({"$ref": f"#/components/schemas/{i.__name__}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||
|
||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||
#
|
||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||
# invocations will cause a graph to fail if they are used.
|
||||
#
|
||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||
# the generated schema as options for the `nodes` field.
|
||||
#
|
||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||
# expected.
|
||||
#
|
||||
# You might be tempted to do something like this:
|
||||
#
|
||||
# ```py
|
||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||
# delattr(cloned_model, "validate_nodes")
|
||||
# cloned_model.model_rebuild(force=True)
|
||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||
# ```
|
||||
#
|
||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||
nodes: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The nodes in this graph")
|
||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||
|
||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@ -774,7 +748,7 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@ -791,52 +765,12 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||
|
||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state")
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution"
|
||||
)
|
||||
results: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions")
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
description="The map of prepared nodes to original graph nodes"
|
||||
)
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||
description="The map of original graph nodes to prepared nodes"
|
||||
)
|
||||
|
||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
|
Reference in New Issue
Block a user