From 50d3030471dd26f0d5a761f9ae4511092aaffe6e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 29 May 2024 21:02:29 +1000 Subject: [PATCH] feat(app): dynamic type adapters for invocations & outputs Keep track of whether or not the typeadapter needs to be updated. Allows for dynamic invocation and output unions. --- invokeai/app/invocations/baseinvocation.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 9545179e21..1d169f0a82 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel): _output_classes: ClassVar[set[BaseInvocationOutput]] = set() _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None + _typeadapter_needs_update: ClassVar[bool] = False @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: """Registers an invocation output.""" cls._output_classes.add(output) + cls._typeadapter_needs_update = True @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: @@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel): @classmethod def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation output types.""" - if not cls._typeadapter: + if not cls._typeadapter or cls._typeadapter_needs_update: AnyInvocationOutput = TypeAliasType( "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] ) cls._typeadapter = TypeAdapter(AnyInvocationOutput) + cls._typeadapter_needs_update = False return cls._typeadapter @classmethod @@ -168,6 +171,7 @@ class BaseInvocation(ABC, BaseModel): _invocation_classes: ClassVar[set[BaseInvocation]] = set() _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None + _typeadapter_needs_update: ClassVar[bool] = False @classmethod def get_type(cls) -> str: @@ -178,15 +182,17 @@ class BaseInvocation(ABC, BaseModel): def register_invocation(cls, invocation: BaseInvocation) -> None: """Registers an invocation.""" cls._invocation_classes.add(invocation) + cls._typeadapter_needs_update = True @classmethod def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation types.""" - if not cls._typeadapter: + if not cls._typeadapter or cls._typeadapter_needs_update: AnyInvocation = TypeAliasType( "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] ) cls._typeadapter = TypeAdapter(AnyInvocation) + cls._typeadapter_needs_update = False return cls._typeadapter @classmethod