mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph. Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports. For example, consider a setup where we have 3 invocations in the app: - Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`. - Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`. - Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`. - Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`. - A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type. This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain. This PR refactors the validation of graph nodes to resolve this issue: - `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call. - `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`. - A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors. `BaseInvocationOutput` gets the same treatment.
This commit is contained in:
parent
b06d63fb34
commit
81518ee1af
@ -8,13 +8,26 @@ import warnings
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from types import UnionType
|
from typing import (
|
||||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
TYPE_CHECKING,
|
||||||
|
Annotated,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
from typing_extensions import TypeAliasType
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
FieldKind,
|
FieldKind,
|
||||||
@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||||
|
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||||
@ -96,10 +110,14 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
return cls._output_classes
|
return cls._output_classes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_outputs_union(cls) -> UnionType:
|
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||||
"""Gets a union of all invocation outputs."""
|
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
if not cls._typeadapter:
|
||||||
return outputs_union # type: ignore [return-value]
|
InvocationOutputsUnion = TypeAliasType(
|
||||||
|
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||||
|
)
|
||||||
|
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||||
|
return cls._typeadapter
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_types(cls) -> Iterable[str]:
|
def get_output_types(cls) -> Iterable[str]:
|
||||||
@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||||
|
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_type(cls) -> str:
|
def get_type(cls) -> str:
|
||||||
@ -160,10 +179,14 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
cls._invocation_classes.add(invocation)
|
cls._invocation_classes.add(invocation)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations_union(cls) -> UnionType:
|
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||||
"""Gets a union of all invocation types."""
|
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
if not cls._typeadapter:
|
||||||
return invocations_union # type: ignore [return-value]
|
InvocationsUnion = TypeAliasType(
|
||||||
|
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||||
|
)
|
||||||
|
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||||
|
return cls._typeadapter
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||||
|
@ -417,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||||
self.clip.skipped_layers += self.skipped_layers
|
self.clip.skipped_layers += self.skipped_layers
|
||||||
|
@ -2,10 +2,15 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
@ -260,21 +265,24 @@ class CollectInvocation(BaseInvocation):
|
|||||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||||
|
|
||||||
|
|
||||||
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
|
|
||||||
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
|
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||||
description="The nodes in this graph", default_factory=dict
|
|
||||||
)
|
|
||||||
edges: list[Edge] = Field(
|
edges: list[Edge] = Field(
|
||||||
description="The connections between nodes and their fields in this graph",
|
description="The connections between nodes and their fields in this graph",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("nodes", mode="plain")
|
||||||
|
@classmethod
|
||||||
|
def validate_nodes(cls, v: dict[str, Any]):
|
||||||
|
nodes: dict[str, BaseInvocation] = {}
|
||||||
|
typeadapter = BaseInvocation.get_typeadapter()
|
||||||
|
for node_id, node in v.items():
|
||||||
|
nodes[node_id] = typeadapter.validate_python(node)
|
||||||
|
return nodes
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation) -> None:
|
def add_node(self, node: BaseInvocation) -> None:
|
||||||
"""Adds a node to a graph
|
"""Adds a node to a graph
|
||||||
|
|
||||||
@ -824,9 +832,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||||
description="The results of node executions", default_factory=dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# Errors raised when executing nodes
|
# Errors raised when executing nodes
|
||||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||||
@ -843,6 +849,15 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("results", mode="plain")
|
||||||
|
@classmethod
|
||||||
|
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||||
|
results: dict[str, BaseInvocationOutput] = {}
|
||||||
|
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||||
|
for result_id, result in v.items():
|
||||||
|
results[result_id] = typeadapter.validate_python(result)
|
||||||
|
return results
|
||||||
|
|
||||||
@field_validator("graph")
|
@field_validator("graph")
|
||||||
def graph_is_valid(cls, v: Graph):
|
def graph_is_valid(cls, v: Graph):
|
||||||
"""Validates that the graph is valid"""
|
"""Validates that the graph is valid"""
|
||||||
@ -1247,6 +1262,6 @@ class LibraryGraph(BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
GraphInvocation.model_rebuild(force=True)
|
|
||||||
Graph.model_rebuild(force=True)
|
Graph.model_rebuild(force=True)
|
||||||
|
GraphInvocation.model_rebuild(force=True)
|
||||||
GraphExecutionState.model_rebuild(force=True)
|
GraphExecutionState.model_rebuild(force=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user