mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): fix OpenAPI schema generation
The change to `Graph.nodes` and `GraphExecutionState.results` validation requires some fanagling to get the OpenAPI schema generation to work. See new comments for a details.
This commit is contained in:
parent
731860c332
commit
b79ae3a101
@ -151,6 +151,8 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||||
# This could break in some cases, figure out a better way to do it
|
# This could break in some cases, figure out a better way to do it
|
||||||
output_type_titles[schema_key] = output_schema["title"]
|
output_type_titles[schema_key] = output_schema["title"]
|
||||||
|
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||||
|
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||||
|
|
||||||
# Add Node Editor UI helper schemas
|
# Add Node Editor UI helper schemas
|
||||||
ui_config_schemas = models_json_schema(
|
ui_config_schemas = models_json_schema(
|
||||||
@ -173,7 +175,6 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
invoker_schema["class"] = "invocation"
|
invoker_schema["class"] = "invocation"
|
||||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
|
||||||
|
|
||||||
# This code no longer seems to be necessary?
|
# This code no longer seems to be necessary?
|
||||||
# Leave it here just in case
|
# Leave it here just in case
|
||||||
|
@ -2,16 +2,19 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
|
GetJsonSchemaHandler,
|
||||||
field_validator,
|
field_validator,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
|
from pydantic_core import CoreSchema
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from invokeai.app.invocations import * # noqa: F401 F403
|
from invokeai.app.invocations import * # noqa: F401 F403
|
||||||
@ -277,12 +280,61 @@ class Graph(BaseModel):
|
|||||||
@field_validator("nodes", mode="plain")
|
@field_validator("nodes", mode="plain")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_nodes(cls, v: dict[str, Any]):
|
def validate_nodes(cls, v: dict[str, Any]):
|
||||||
|
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||||
|
|
||||||
|
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||||
|
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||||
|
#
|
||||||
|
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||||
|
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||||
|
# invocations will cause a graph to fail if they are used.
|
||||||
|
#
|
||||||
|
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||||
|
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||||
|
#
|
||||||
|
# This same pattern is used in `GraphExecutionState`.
|
||||||
|
|
||||||
nodes: dict[str, BaseInvocation] = {}
|
nodes: dict[str, BaseInvocation] = {}
|
||||||
typeadapter = BaseInvocation.get_typeadapter()
|
typeadapter = BaseInvocation.get_typeadapter()
|
||||||
for node_id, node in v.items():
|
for node_id, node in v.items():
|
||||||
nodes[node_id] = typeadapter.validate_python(node)
|
nodes[node_id] = typeadapter.validate_python(node)
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||||
|
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||||
|
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||||
|
# the generated schema as options for the `nodes` field.
|
||||||
|
#
|
||||||
|
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||||
|
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||||
|
# expected.
|
||||||
|
#
|
||||||
|
# You might be tempted to do something like this:
|
||||||
|
#
|
||||||
|
# ```py
|
||||||
|
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||||
|
# delattr(cloned_model, "validate_nodes")
|
||||||
|
# cloned_model.model_rebuild(force=True)
|
||||||
|
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||||
|
# ```
|
||||||
|
#
|
||||||
|
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||||
|
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||||
|
#
|
||||||
|
# This same pattern is used in `GraphExecutionState`.
|
||||||
|
|
||||||
|
class Graph(BaseModel):
|
||||||
|
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||||
|
nodes: dict[
|
||||||
|
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||||
|
] = Field(description="The nodes in this graph")
|
||||||
|
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||||
|
|
||||||
|
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||||
|
json_schema = handler.resolve_ref_schema(json_schema)
|
||||||
|
return json_schema
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation) -> None:
|
def add_node(self, node: BaseInvocation) -> None:
|
||||||
"""Adds a node to a graph
|
"""Adds a node to a graph
|
||||||
|
|
||||||
@ -852,6 +904,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
@field_validator("results", mode="plain")
|
@field_validator("results", mode="plain")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||||
|
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||||
|
|
||||||
|
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||||
results: dict[str, BaseInvocationOutput] = {}
|
results: dict[str, BaseInvocationOutput] = {}
|
||||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||||
for result_id, result in v.items():
|
for result_id, result in v.items():
|
||||||
@ -864,6 +919,34 @@ class GraphExecutionState(BaseModel):
|
|||||||
v.validate_self()
|
v.validate_self()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||||
|
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||||
|
class GraphExecutionState(BaseModel):
|
||||||
|
"""Tracks the state of a graph execution"""
|
||||||
|
|
||||||
|
id: str = Field(description="The id of the execution state")
|
||||||
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||||
|
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||||
|
executed_history: list[str] = Field(
|
||||||
|
description="The list of node ids that have been executed, in order of execution"
|
||||||
|
)
|
||||||
|
results: dict[
|
||||||
|
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||||
|
] = Field(description="The results of node executions")
|
||||||
|
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||||
|
prepared_source_mapping: dict[str, str] = Field(
|
||||||
|
description="The map of prepared nodes to original graph nodes"
|
||||||
|
)
|
||||||
|
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||||
|
description="The map of original graph nodes to prepared nodes"
|
||||||
|
)
|
||||||
|
|
||||||
|
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||||
|
json_schema = handler.resolve_ref_schema(json_schema)
|
||||||
|
return json_schema
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"required": [
|
"required": [
|
||||||
@ -1260,8 +1343,3 @@ class LibraryGraph(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
Graph.model_rebuild(force=True)
|
|
||||||
GraphInvocation.model_rebuild(force=True)
|
|
||||||
GraphExecutionState.model_rebuild(force=True)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user