feat(nodes): JIT graph nodes validation

We use pydantic to validate a union of valid invocations when instantiating a graph.

Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports.

For example, consider a setup where we have 3 invocations in the app:

- Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`.
- Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`.
- Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`.
- Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`.
- A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type.

This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain.

This PR refactors the validation of graph nodes to resolve this issue:

- `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call.
- `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`.
- A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors.

`BaseInvocationOutput` gets the same treatment.
This commit is contained in:
psychedelicious
2024-02-17 11:22:08 +11:00
parent af2117dc0c
commit 731860c332
3 changed files with 63 additions and 25 deletions

View File

@ -8,13 +8,26 @@ import warnings
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
)
import semver
from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import (
FieldKind,
@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel):
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
@ -96,10 +110,14 @@ class BaseInvocationOutput(BaseModel):
return cls._output_classes
@classmethod
def get_outputs_union(cls) -> UnionType:
"""Gets a union of all invocation outputs."""
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
return outputs_union # type: ignore [return-value]
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod
def get_output_types(cls) -> Iterable[str]:
@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel):
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod
def get_type(cls) -> str:
@ -160,10 +179,14 @@ class BaseInvocation(ABC, BaseModel):
cls._invocation_classes.add(invocation)
@classmethod
def get_invocations_union(cls) -> UnionType:
"""Gets a union of all invocation types."""
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
return invocations_union # type: ignore [return-value]
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]: