mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(events): add dynamic invocation & result validators
This is required to get these event fields to deserialize correctly. If omitted, pydantic uses `BaseInvocation`/`BaseInvocationOutput`, which is not correct. This is similar to the workaround in the `Graph` and `GraphExecutionState` classes where we need to fanagle pydantic with manual validation handling.
This commit is contained in:
parent
a4f88ff834
commit
21aa42627b
@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P
|
|||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
@ -101,6 +101,14 @@ class InvocationEventBase(QueueItemEventBase):
|
|||||||
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
|
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
|
||||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||||
|
|
||||||
|
@field_validator("invocation", mode="plain")
|
||||||
|
@classmethod
|
||||||
|
def validate_invocation(cls, v: Any):
|
||||||
|
"""Validates the invocation using the dynamic type adapter."""
|
||||||
|
|
||||||
|
invocation = BaseInvocation.get_typeadapter().validate_python(v)
|
||||||
|
return invocation
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register
|
@payload_schema.register
|
||||||
class InvocationStartedEvent(InvocationEventBase):
|
class InvocationStartedEvent(InvocationEventBase):
|
||||||
@ -176,6 +184,14 @@ class InvocationCompleteEvent(InvocationEventBase):
|
|||||||
|
|
||||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
||||||
|
|
||||||
|
@field_validator("result", mode="plain")
|
||||||
|
@classmethod
|
||||||
|
def validate_results(cls, v: Any):
|
||||||
|
"""Validates the invocation result using the dynamic type adapter."""
|
||||||
|
|
||||||
|
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||||
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(
|
def build(
|
||||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
||||||
|
Loading…
Reference in New Issue
Block a user