diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 3243714937..5edae5342d 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -8,13 +8,26 @@ import warnings from abc import ABC, abstractmethod from enum import Enum from inspect import signature -from types import UnionType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + ClassVar, + Iterable, + Literal, + Optional, + Type, + TypeVar, + Union, + cast, +) import semver -from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from typing_extensions import TypeAliasType from invokeai.app.invocations.fields import ( FieldKind, @@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel): """ _output_classes: ClassVar[set[BaseInvocationOutput]] = set() + _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: @@ -96,10 +110,14 @@ class BaseInvocationOutput(BaseModel): return cls._output_classes @classmethod - def get_outputs_union(cls) -> UnionType: - """Gets a union of all invocation outputs.""" - outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] - return outputs_union # type: ignore [return-value] + def get_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantc TypeAdapter for the union of all invocation output types.""" + if not cls._typeadapter: + InvocationOutputsUnion = TypeAliasType( + "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] + ) + cls._typeadapter = TypeAdapter(InvocationOutputsUnion) + return cls._typeadapter @classmethod def get_output_types(cls) -> Iterable[str]: @@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel): """ _invocation_classes: ClassVar[set[BaseInvocation]] = set() + _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None @classmethod def get_type(cls) -> str: @@ -160,10 +179,14 @@ class BaseInvocation(ABC, BaseModel): cls._invocation_classes.add(invocation) @classmethod - def get_invocations_union(cls) -> UnionType: - """Gets a union of all invocation types.""" - invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] - return invocations_union # type: ignore [return-value] + def get_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantc TypeAdapter for the union of all invocation types.""" + if not cls._typeadapter: + InvocationsUnion = TypeAliasType( + "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + ) + cls._typeadapter = TypeAdapter(InvocationsUnion) + return cls._typeadapter @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 517da4375e..47be380626 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -417,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") - skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) + skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers) def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 3df230f5ee..3066af0e50 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,10 +2,15 @@ import copy import itertools -from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx -from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + field_validator, + model_validator, +) from pydantic.fields import Field # Importing * is bad karma but needed here for node detection @@ -260,21 +265,24 @@ class CollectInvocation(BaseInvocation): return CollectInvocationOutput(collection=copy.copy(self.collection)) -InvocationsUnion: Any = BaseInvocation.get_invocations_union() -InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union() - - class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( - description="The nodes in this graph", default_factory=dict - ) + nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) + @field_validator("nodes", mode="plain") + @classmethod + def validate_nodes(cls, v: dict[str, Any]): + nodes: dict[str, BaseInvocation] = {} + typeadapter = BaseInvocation.get_typeadapter() + for node_id, node in v.items(): + nodes[node_id] = typeadapter.validate_python(node) + return nodes + def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -824,9 +832,7 @@ class GraphExecutionState(BaseModel): ) # The results of executed nodes - results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field( - description="The results of node executions", default_factory=dict - ) + results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) # Errors raised when executing nodes errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) @@ -843,6 +849,15 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) + @field_validator("results", mode="plain") + @classmethod + def validate_results(cls, v: dict[str, BaseInvocationOutput]): + results: dict[str, BaseInvocationOutput] = {} + typeadapter = BaseInvocationOutput.get_typeadapter() + for result_id, result in v.items(): + results[result_id] = typeadapter.validate_python(result) + return results + @field_validator("graph") def graph_is_valid(cls, v: Graph): """Validates that the graph is valid""" @@ -1247,6 +1262,6 @@ class LibraryGraph(BaseModel): return values -GraphInvocation.model_rebuild(force=True) Graph.model_rebuild(force=True) +GraphInvocation.model_rebuild(force=True) GraphExecutionState.model_rebuild(force=True)