mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph. Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports. For example, consider a setup where we have 3 invocations in the app: - Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`. - Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`. - Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`. - Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`. - A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type. This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain. This PR refactors the validation of graph nodes to resolve this issue: - `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call. - `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`. - A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors. `BaseInvocationOutput` gets the same treatment.
This commit is contained in:
@ -2,10 +2,15 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.fields import Field
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
@ -260,21 +265,24 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
|
||||
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
)
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@ -824,9 +832,7 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The results of node executions", default_factory=dict
|
||||
)
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@ -843,6 +849,15 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
@ -1247,6 +1262,6 @@ class LibraryGraph(BaseModel):
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.model_rebuild(force=True)
|
||||
Graph.model_rebuild(force=True)
|
||||
GraphInvocation.model_rebuild(force=True)
|
||||
GraphExecutionState.model_rebuild(force=True)
|
||||
|
Reference in New Issue
Block a user