mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph. Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports. For example, consider a setup where we have 3 invocations in the app: - Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`. - Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`. - Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`. - Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`. - A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type. This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain. This PR refactors the validation of graph nodes to resolve this issue: - `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call. - `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`. - A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors. `BaseInvocationOutput` gets the same treatment.
This commit is contained in:
@ -8,13 +8,26 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
@ -96,10 +110,14 @@ class BaseInvocationOutput(BaseModel):
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_outputs_union(cls) -> UnionType:
|
||||
"""Gets a union of all invocation outputs."""
|
||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||
return outputs_union # type: ignore [return-value]
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
@ -160,10 +179,14 @@ class BaseInvocation(ABC, BaseModel):
|
||||
cls._invocation_classes.add(invocation)
|
||||
|
||||
@classmethod
|
||||
def get_invocations_union(cls) -> UnionType:
|
||||
"""Gets a union of all invocation types."""
|
||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||
return invocations_union # type: ignore [return-value]
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
|
Reference in New Issue
Block a user