feat(events): add dynamic invocation & result validators

This is required to get these event fields to deserialize correctly. If omitted, pydantic uses `BaseInvocation`/`BaseInvocationOutput`, which is not correct.

This is similar to the workaround in the `Graph` and `GraphExecutionState` classes where we need to fanagle pydantic with manual validation handling.
This commit is contained in:
psychedelicious 2024-05-28 15:35:57 +10:00 committed by Kent Keirsey
parent a4f88ff834
commit 21aa42627b

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
@ -101,6 +101,14 @@ class InvocationEventBase(QueueItemEventBase):
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""
invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
@ -176,6 +184,14 @@ class InvocationCompleteEvent(InvocationEventBase):
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput