diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index 149d47fb96..65607c436a 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -151,6 +151,8 @@ def custom_openapi() -> dict[str, Any]:
         # TODO: note that we assume the schema_key here is the TYPE.__name__
         # This could break in some cases, figure out a better way to do it
         output_type_titles[schema_key] = output_schema["title"]
+        openapi_schema["components"]["schemas"][schema_key] = output_schema
+        openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
 
     # Add Node Editor UI helper schemas
     ui_config_schemas = models_json_schema(
@@ -173,7 +175,6 @@ def custom_openapi() -> dict[str, Any]:
         outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
         invoker_schema["output"] = outputs_ref
         invoker_schema["class"] = "invocation"
-        openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
 
     # This code no longer seems to be necessary?
     # Leave it here just in case
diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py
index 3066af0e50..1b53f64222 100644
--- a/invokeai/app/services/shared/graph.py
+++ b/invokeai/app/services/shared/graph.py
@@ -2,16 +2,19 @@
 
 import copy
 import itertools
-from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
 
 import networkx as nx
 from pydantic import (
     BaseModel,
     ConfigDict,
+    GetJsonSchemaHandler,
     field_validator,
     model_validator,
 )
 from pydantic.fields import Field
+from pydantic.json_schema import JsonSchemaValue
+from pydantic_core import CoreSchema
 
 # Importing * is bad karma but needed here for node detection
 from invokeai.app.invocations import *  # noqa: F401 F403
@@ -277,12 +280,61 @@ class Graph(BaseModel):
     @field_validator("nodes", mode="plain")
     @classmethod
     def validate_nodes(cls, v: dict[str, Any]):
+        """Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
+
+        # Invocations register themselves as their python modules are executed. The union of all invocations is
+        # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
+        #
+        # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
+        # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
+        # invocations will cause a graph to fail if they are used.
+        #
+        # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
+        # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
+        #
+        # This same pattern is used in `GraphExecutionState`.
+
         nodes: dict[str, BaseInvocation] = {}
         typeadapter = BaseInvocation.get_typeadapter()
         for node_id, node in v.items():
             nodes[node_id] = typeadapter.validate_python(node)
         return nodes
 
+    @classmethod
+    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
+        # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
+        # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
+        # the generated schema as options for the `nodes` field.
+        #
+        # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
+        # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
+        # expected.
+        #
+        # You might be tempted to do something like this:
+        #
+        # ```py
+        # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
+        # delattr(cloned_model, "validate_nodes")
+        # cloned_model.model_rebuild(force=True)
+        # json_schema = handler(cloned_model.__pydantic_core_schema__)
+        # ```
+        #
+        # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
+        # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
+        #
+        # This same pattern is used in `GraphExecutionState`.
+
+        class Graph(BaseModel):
+            id: Optional[str] = Field(default=None, description="The id of this graph")
+            nodes: dict[
+                str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
+            ] = Field(description="The nodes in this graph")
+            edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
+
+        json_schema = handler(Graph.__pydantic_core_schema__)
+        json_schema = handler.resolve_ref_schema(json_schema)
+        return json_schema
+
     def add_node(self, node: BaseInvocation) -> None:
         """Adds a node to a graph
 
@@ -852,6 +904,9 @@ class GraphExecutionState(BaseModel):
     @field_validator("results", mode="plain")
     @classmethod
     def validate_results(cls, v: dict[str, BaseInvocationOutput]):
+        """Validates the results in the GES by retrieving a union of all output types and validating each result."""
+
+        # See the comment in `Graph.validate_nodes` for an explanation of this logic.
         results: dict[str, BaseInvocationOutput] = {}
         typeadapter = BaseInvocationOutput.get_typeadapter()
         for result_id, result in v.items():
@@ -864,6 +919,34 @@ class GraphExecutionState(BaseModel):
         v.validate_self()
         return v
 
+    @classmethod
+    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
+        # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
+        class GraphExecutionState(BaseModel):
+            """Tracks the state of a graph execution"""
+
+            id: str = Field(description="The id of the execution state")
+            graph: Graph = Field(description="The graph being executed")
+            execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
+            executed: set[str] = Field(description="The set of node ids that have been executed")
+            executed_history: list[str] = Field(
+                description="The list of node ids that have been executed, in order of execution"
+            )
+            results: dict[
+                str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
+            ] = Field(description="The results of node executions")
+            errors: dict[str, str] = Field(description="Errors raised when executing nodes")
+            prepared_source_mapping: dict[str, str] = Field(
+                description="The map of prepared nodes to original graph nodes"
+            )
+            source_prepared_mapping: dict[str, set[str]] = Field(
+                description="The map of original graph nodes to prepared nodes"
+            )
+
+        json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
+        json_schema = handler.resolve_ref_schema(json_schema)
+        return json_schema
+
     model_config = ConfigDict(
         json_schema_extra={
             "required": [
@@ -1260,8 +1343,3 @@ class LibraryGraph(BaseModel):
             )
 
         return values
-
-
-Graph.model_rebuild(force=True)
-GraphInvocation.model_rebuild(force=True)
-GraphExecutionState.model_rebuild(force=True)
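For readers unfamiliar with the pydantic v2 hooks used above, here is a minimal, self-contained sketch of the same pattern outside InvokeAI: a runtime-built discriminated union validated by a `mode="plain"` field validator, with the JSON Schema produced from a manually cloned model inside `__get_pydantic_json_schema__`. The `Animal`/`Dog`/`Cat`/`Zoo` names and the `__init_subclass__` registry are illustrative stand-ins only (invocations register via a decorator in the real code), not InvokeAI APIs.

```py
# Sketch of the plain-validator + cloned-schema-model pattern, under assumed stand-in names.
from typing import Annotated, Any, Literal, Optional, Union

from pydantic import BaseModel, Field, GetJsonSchemaHandler, TypeAdapter, field_validator
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema

_animal_classes: list[type] = []  # populated as subclasses are defined/imported


class Animal(BaseModel):
    type: str

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        _animal_classes.append(cls)

    @classmethod
    def get_typeadapter(cls) -> TypeAdapter[Any]:
        # The union is built from whatever subclasses exist *now*, mirroring
        # BaseInvocation.get_typeadapter() in the patch above.
        union = Annotated[Union[tuple(_animal_classes)], Field(discriminator="type")]
        return TypeAdapter(union)


class Dog(Animal):
    type: Literal["dog"] = "dog"


class Cat(Animal):
    type: Literal["cat"] = "cat"


class Zoo(BaseModel):
    id: Optional[str] = Field(default=None, description="The id of this zoo")
    animals: dict[str, Animal] = Field(default_factory=dict, description="The animals in this zoo")

    @field_validator("animals", mode="plain")
    @classmethod
    def validate_animals(cls, v: dict[str, Any]) -> dict[str, Animal]:
        # "plain" mode replaces pydantic's own validation, so the runtime-built union is
        # consulted only when a Zoo is actually validated.
        adapter = Animal.get_typeadapter()
        return {k: adapter.validate_python(item) for k, item in v.items()}

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
        # Pydantic cannot derive a JSON Schema for the plain-validated field, so a manually
        # cloned model with the concrete union stands in for schema generation, as the patch
        # does for Graph and GraphExecutionState.
        class Zoo(BaseModel):
            id: Optional[str] = Field(default=None, description="The id of this zoo")
            animals: dict[
                str, Annotated[Union[tuple(_animal_classes)], Field(discriminator="type")]
            ] = Field(default_factory=dict, description="The animals in this zoo")

        json_schema = handler(Zoo.__pydantic_core_schema__)
        return handler.resolve_ref_schema(json_schema)


if __name__ == "__main__":
    zoo = Zoo.model_validate({"animals": {"a": {"type": "dog"}, "b": {"type": "cat"}}})
    assert isinstance(zoo.animals["a"], Dog) and isinstance(zoo.animals["b"], Cat)
    # Schema generation exposes the discriminated union instead of failing on the plain validator.
    print(Zoo.model_json_schema()["properties"]["animals"])
```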