feat(nodes): JIT graph nodes validation

We use pydantic to validate a union of valid invocations when instantiating a graph.

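Previously, we constructed the union while creating the `Graph` class. This introduced a dependency on the order of imports: any invocation registered after the module defining `Graph` had been executed was missing from the union.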

For example, consider a setup where we have 3 invocations in the app:

- Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`.
- Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`.
- Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`.
- Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`.
- A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type.

This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain.

This PR refactors the validation of graph nodes to resolve this issue:

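- `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations the first time it is called, then caches it and returns the cached adapter on every later call. Note that invocations registered after that first call are not included in the cached adapter.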
- `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This is actually a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`.
- A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors.

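`BaseInvocationOutput` gets the same treatment: it gains its own cached `get_typeadapter`, and `GraphExecutionState.results` is widened to `dict[str, BaseInvocationOutput]` and validated by a "plain" field validator that uses that adapter.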
This commit is contained in:
psychedelicious 2024-02-17 11:22:08 +11:00
parent b06d63fb34
commit 81518ee1af
3 changed files with 63 additions and 25 deletions

View File

@ -8,13 +8,26 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from inspect import signature from inspect import signature
from types import UnionType from typing import (
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast TYPE_CHECKING,
Annotated,
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
)
import semver import semver
from pydantic import BaseModel, ConfigDict, Field, create_model from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldKind, FieldKind,
@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel):
""" """
_output_classes: ClassVar[set[BaseInvocationOutput]] = set() _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod @classmethod
def register_output(cls, output: BaseInvocationOutput) -> None: def register_output(cls, output: BaseInvocationOutput) -> None:
@ -96,10 +110,14 @@ class BaseInvocationOutput(BaseModel):
return cls._output_classes return cls._output_classes
@classmethod @classmethod
def get_outputs_union(cls) -> UnionType: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a union of all invocation outputs.""" """Gets a pydantic TypeAdapter for the union of all invocation output types."""
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] if not cls._typeadapter:
return outputs_union # type: ignore [return-value] InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod @classmethod
def get_output_types(cls) -> Iterable[str]: def get_output_types(cls) -> Iterable[str]:
@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel):
""" """
_invocation_classes: ClassVar[set[BaseInvocation]] = set() _invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod @classmethod
def get_type(cls) -> str: def get_type(cls) -> str:
@ -160,10 +179,14 @@ class BaseInvocation(ABC, BaseModel):
cls._invocation_classes.add(invocation) cls._invocation_classes.add(invocation)
@classmethod @classmethod
def get_invocations_union(cls) -> UnionType: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a union of all invocation types.""" """Gets a pydantic TypeAdapter for the union of all invocation types."""
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] if not cls._typeadapter:
return invocations_union # type: ignore [return-value] InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod @classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]: def get_invocations(cls) -> Iterable[BaseInvocation]:

View File

@ -417,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model.""" """Skip layers in clip text_encoder model."""
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers self.clip.skipped_layers += self.skipped_layers

View File

@ -2,10 +2,15 @@
import copy import copy
import itertools import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx import networkx as nx
from pydantic import BaseModel, ConfigDict, field_validator, model_validator from pydantic import (
BaseModel,
ConfigDict,
field_validator,
model_validator,
)
from pydantic.fields import Field from pydantic.fields import Field
# Importing * is bad karma but needed here for node detection # Importing * is bad karma but needed here for node detection
@ -260,21 +265,24 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection)) return CollectInvocationOutput(collection=copy.copy(self.collection))
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
class Graph(BaseModel): class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string) id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
description="The nodes in this graph", default_factory=dict
)
edges: list[Edge] = Field( edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph", description="The connections between nodes and their fields in this graph",
default_factory=list, default_factory=list,
) )
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
def add_node(self, node: BaseInvocation) -> None: def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph """Adds a node to a graph
@ -824,9 +832,7 @@ class GraphExecutionState(BaseModel):
) )
# The results of executed nodes # The results of executed nodes
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field( results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
description="The results of node executions", default_factory=dict
)
# Errors raised when executing nodes # Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@ -843,6 +849,15 @@ class GraphExecutionState(BaseModel):
default_factory=dict, default_factory=dict,
) )
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph") @field_validator("graph")
def graph_is_valid(cls, v: Graph): def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid""" """Validates that the graph is valid"""
@ -1247,6 +1262,6 @@ class LibraryGraph(BaseModel):
return values return values
GraphInvocation.model_rebuild(force=True)
Graph.model_rebuild(force=True) Graph.model_rebuild(force=True)
GraphInvocation.model_rebuild(force=True)
GraphExecutionState.model_rebuild(force=True) GraphExecutionState.model_rebuild(force=True)