feat(events): simplify event classes

- Remove ABCs, they do not work well with pydantic
- Remove the event type classvar - unused
- Remove clever logic to require an event name - we already get validation for this during schema registration.
- Rename event bases to all end in "Base"
This commit is contained in:
psychedelicious 2024-03-10 23:23:11 +11:00
parent 1b9bbaa5a4
commit 2dc752ea83
5 changed files with 51 additions and 88 deletions

View File

@ -10,14 +10,14 @@ from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent, BatchEnqueuedEvent,
BulkDownloadCompleteEvent, BulkDownloadCompleteEvent,
BulkDownloadErrorEvent, BulkDownloadErrorEvent,
BulkDownloadEvent, BulkDownloadEventBase,
BulkDownloadStartedEvent, BulkDownloadStartedEvent,
FastAPIEvent, FastAPIEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent, InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelEvent, ModelEventBase,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
ModelInstallCompleteEvent, ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent, ModelInstallDownloadProgressEvent,
@ -26,7 +26,7 @@ from invokeai.app.services.events.events_common import (
ModelLoadCompleteEvent, ModelLoadCompleteEvent,
ModelLoadStartedEvent, ModelLoadStartedEvent,
QueueClearedEvent, QueueClearedEvent,
QueueEvent, QueueEventBase,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
SessionCanceledEvent, SessionCanceledEvent,
SessionCompleteEvent, SessionCompleteEvent,
@ -106,14 +106,14 @@ class SocketIO:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
event_name, payload = event event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
event_name, payload = event event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump()) await self._sio.emit(event=event_name, data=payload.model_dump())
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None: async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
event_name, payload = event event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.bulk_download_id) await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.bulk_download_id)

View File

@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from invokeai.app.services.events.events_common import ( from invokeai.app.services.events.events_common import (
BaseEvent,
BatchEnqueuedEvent, BatchEnqueuedEvent,
BulkDownloadCompleteEvent, BulkDownloadCompleteEvent,
BulkDownloadErrorEvent, BulkDownloadErrorEvent,
@ -14,6 +13,7 @@ from invokeai.app.services.events.events_common import (
DownloadErrorEvent, DownloadErrorEvent,
DownloadProgressEvent, DownloadProgressEvent,
DownloadStartedEvent, DownloadStartedEvent,
EventBase,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent, InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
@ -35,7 +35,7 @@ from invokeai.app.services.events.events_common import (
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import BaseEvent from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
@ -50,7 +50,7 @@ if TYPE_CHECKING:
class EventServiceBase: class EventServiceBase:
"""Basic event bus, to have an empty stand-in when not needed""" """Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event: "BaseEvent") -> None: def dispatch(self, event: "EventBase") -> None:
pass pass
# region: Invocation # region: Invocation

View File

@ -1,6 +1,4 @@
from abc import ABC from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema from fastapi_events.registry.payload_schema import registry as payload_schema
@ -22,39 +20,22 @@ if TYPE_CHECKING:
from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventType(str, Enum): class EventBase(BaseModel):
QUEUE = "queue"
MODEL = "model"
DOWNLOAD = "download"
BULK_IMAGE_DOWNLOAD = "bulk_image_download"
class BaseEvent(BaseModel, ABC):
"""Base class for all events. All events must inherit from this class. """Base class for all events. All events must inherit from this class.
Events must define the following class attributes: Events must define a class attribute `__event_name__` to identify the event.
- `__event_name__: str`: The name of the event
- `__event_type__: EventType`: The type of the event
All other attributes should be defined as normal for a pydantic model. All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created. A timestamp is automatically added to the event when it is created.
""" """
__event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType]
__event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
def __init_subclass__(cls, **kwargs: ConfigDict):
for required_attr in ("__event_name__", "__event_type__"):
if getattr(cls, required_attr) is ...:
raise TypeError(f"{cls.__name__} must define {required_attr}")
model_config = ConfigDict(json_schema_serialization_defaults_required=True) model_config = ConfigDict(json_schema_serialization_defaults_required=True)
TEvent = TypeVar("TEvent", bound=BaseEvent) TEvent = TypeVar("TEvent", bound=EventBase)
FastAPIEvent: TypeAlias = tuple[str, TEvent] FastAPIEvent: TypeAlias = tuple[str, TEvent]
""" """
@ -77,37 +58,28 @@ def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
local_handler.register(event_name=event.__event_name__, _func=func) local_handler.register(event_name=event.__event_name__, _func=func)
class QueueEvent(BaseEvent, ABC): class QueueEventBase(EventBase):
"""Base class for queue events""" """Base class for queue events"""
__event_type__ = EventType.QUEUE
__event_name__ = "queue_event"
queue_id: str = Field(description="The ID of the queue") queue_id: str = Field(description="The ID of the queue")
class QueueItemEvent(QueueEvent, ABC): class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events""" """Base class for queue item events"""
__event_name__ = "queue_item_event"
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
class SessionEvent(QueueItemEvent, ABC): class SessionEventBase(QueueItemEventBase):
"""Base class for session (aka graph execution state) events""" """Base class for session (aka graph execution state) events"""
__event_name__ = "session_event"
session_id: str = Field(description="The ID of the session (aka graph execution state)") session_id: str = Field(description="The ID of the session (aka graph execution state)")
class InvocationEvent(SessionEvent, ABC): class InvocationEventBase(SessionEventBase):
"""Base class for invocation events""" """Base class for invocation events"""
__event_name__ = "invocation_event"
queue_id: str = Field(description="The ID of the queue") queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
@ -118,7 +90,7 @@ class InvocationEvent(SessionEvent, ABC):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationStartedEvent(InvocationEvent): class InvocationStartedEvent(InvocationEventBase):
"""Emitted when an invocation is started""" """Emitted when an invocation is started"""
__event_name__ = "invocation_started" __event_name__ = "invocation_started"
@ -137,7 +109,7 @@ class InvocationStartedEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationDenoiseProgressEvent(InvocationEvent): class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Emitted at each step during denoising of an invocation.""" """Emitted at each step during denoising of an invocation."""
__event_name__ = "invocation_denoise_progress" __event_name__ = "invocation_denoise_progress"
@ -170,7 +142,7 @@ class InvocationDenoiseProgressEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationCompleteEvent(InvocationEvent): class InvocationCompleteEvent(InvocationEventBase):
"""Emitted when an invocation is complete""" """Emitted when an invocation is complete"""
__event_name__ = "invocation_complete" __event_name__ = "invocation_complete"
@ -194,7 +166,7 @@ class InvocationCompleteEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationErrorEvent(InvocationEvent): class InvocationErrorEvent(InvocationEventBase):
"""Emitted when an invocation encounters an error""" """Emitted when an invocation encounters an error"""
__event_name__ = "invocation_error" __event_name__ = "invocation_error"
@ -227,7 +199,7 @@ class InvocationErrorEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionStartedEvent(SessionEvent): class SessionStartedEvent(SessionEventBase):
"""Emitted when a session has started""" """Emitted when a session has started"""
__event_name__ = "session_started" __event_name__ = "session_started"
@ -243,7 +215,7 @@ class SessionStartedEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCompleteEvent(SessionEvent): class SessionCompleteEvent(SessionEventBase):
"""Emitted when a session has completed all invocations""" """Emitted when a session has completed all invocations"""
__event_name__ = "session_complete" __event_name__ = "session_complete"
@ -259,7 +231,7 @@ class SessionCompleteEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCanceledEvent(SessionEvent): class SessionCanceledEvent(SessionEventBase):
"""Emitted when a session is canceled""" """Emitted when a session is canceled"""
__event_name__ = "session_canceled" __event_name__ = "session_canceled"
@ -275,7 +247,7 @@ class SessionCanceledEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueItemStatusChangedEvent(SessionEvent): class QueueItemStatusChangedEvent(SessionEventBase):
"""Emitted when a queue item's status changes""" """Emitted when a queue item's status changes"""
__event_name__ = "queue_item_status_changed" __event_name__ = "queue_item_status_changed"
@ -314,7 +286,7 @@ class QueueItemStatusChangedEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BatchEnqueuedEvent(QueueEvent): class BatchEnqueuedEvent(QueueEventBase):
"""Emitted when a batch is enqueued""" """Emitted when a batch is enqueued"""
__event_name__ = "batch_enqueued" __event_name__ = "batch_enqueued"
@ -338,7 +310,7 @@ class BatchEnqueuedEvent(QueueEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueClearedEvent(QueueEvent): class QueueClearedEvent(QueueEventBase):
"""Emitted when a queue is cleared""" """Emitted when a queue is cleared"""
__event_name__ = "queue_cleared" __event_name__ = "queue_cleared"
@ -348,17 +320,14 @@ class QueueClearedEvent(QueueEvent):
return cls(queue_id=queue_id) return cls(queue_id=queue_id)
class DownloadEvent(BaseEvent, ABC): class DownloadEventBase(EventBase):
"""Base class for events associated with a download""" """Base class for events associated with a download"""
__event_type__ = EventType.DOWNLOAD
__event_name__ = "download_event"
source: str = Field(description="The source of the download") source: str = Field(description="The source of the download")
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadStartedEvent(DownloadEvent): class DownloadStartedEvent(DownloadEventBase):
"""Emitted when a download is started""" """Emitted when a download is started"""
__event_name__ = "download_started" __event_name__ = "download_started"
@ -371,7 +340,7 @@ class DownloadStartedEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadProgressEvent(DownloadEvent): class DownloadProgressEvent(DownloadEventBase):
"""Emitted at intervals during a download""" """Emitted at intervals during a download"""
__event_name__ = "download_progress" __event_name__ = "download_progress"
@ -386,7 +355,7 @@ class DownloadProgressEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCompleteEvent(DownloadEvent): class DownloadCompleteEvent(DownloadEventBase):
"""Emitted when a download is completed""" """Emitted when a download is completed"""
__event_name__ = "download_complete" __event_name__ = "download_complete"
@ -400,7 +369,7 @@ class DownloadCompleteEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCancelledEvent(DownloadEvent): class DownloadCancelledEvent(DownloadEventBase):
"""Emitted when a download is cancelled""" """Emitted when a download is cancelled"""
__event_name__ = "download_cancelled" __event_name__ = "download_cancelled"
@ -411,7 +380,7 @@ class DownloadCancelledEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadErrorEvent(DownloadEvent): class DownloadErrorEvent(DownloadEventBase):
"""Emitted when a download encounters an error""" """Emitted when a download encounters an error"""
__event_name__ = "download_error" __event_name__ = "download_error"
@ -424,15 +393,12 @@ class DownloadErrorEvent(DownloadEvent):
return cls(source=source, error_type=error_type, error=error) return cls(source=source, error_type=error_type, error=error)
class ModelEvent(BaseEvent, ABC): class ModelEventBase(EventBase):
"""Base class for events associated with a model""" """Base class for events associated with a model"""
__event_type__ = EventType.MODEL
__event_name__ = "model_event"
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadStartedEvent(ModelEvent): class ModelLoadStartedEvent(ModelEventBase):
"""Emitted when a model is requested""" """Emitted when a model is requested"""
__event_name__ = "model_load_started" __event_name__ = "model_load_started"
@ -446,7 +412,7 @@ class ModelLoadStartedEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadCompleteEvent(ModelEvent): class ModelLoadCompleteEvent(ModelEventBase):
"""Emitted when a model is requested""" """Emitted when a model is requested"""
__event_name__ = "model_load_complete" __event_name__ = "model_load_complete"
@ -460,7 +426,7 @@ class ModelLoadCompleteEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadProgressEvent(ModelEvent): class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Emitted at intervals while the install job is in progress (remote models only).""" """Emitted at intervals while the install job is in progress (remote models only)."""
__event_name__ = "model_install_download_progress" __event_name__ = "model_install_download_progress"
@ -496,7 +462,7 @@ class ModelInstallDownloadProgressEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadsCompleteEvent(ModelEvent): class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active.""" """Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete" __event_name__ = "model_install_downloads_complete"
@ -510,7 +476,7 @@ class ModelInstallDownloadsCompleteEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallStartedEvent(ModelEvent): class ModelInstallStartedEvent(ModelEventBase):
"""Emitted once when an install job becomes active.""" """Emitted once when an install job becomes active."""
__event_name__ = "model_install_started" __event_name__ = "model_install_started"
@ -524,7 +490,7 @@ class ModelInstallStartedEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCompleteEvent(ModelEvent): class ModelInstallCompleteEvent(ModelEventBase):
"""Emitted when an install job is completed successfully.""" """Emitted when an install job is completed successfully."""
__event_name__ = "model_install_complete" __event_name__ = "model_install_complete"
@ -541,7 +507,7 @@ class ModelInstallCompleteEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCancelledEvent(ModelEvent): class ModelInstallCancelledEvent(ModelEventBase):
"""Emitted when an install job is cancelled.""" """Emitted when an install job is cancelled."""
__event_name__ = "model_install_cancelled" __event_name__ = "model_install_cancelled"
@ -555,7 +521,7 @@ class ModelInstallCancelledEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallErrorEvent(ModelEvent): class ModelInstallErrorEvent(ModelEventBase):
"""Emitted when an install job encounters an exception.""" """Emitted when an install job encounters an exception."""
__event_name__ = "model_install_error" __event_name__ = "model_install_error"
@ -572,19 +538,16 @@ class ModelInstallErrorEvent(ModelEvent):
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEvent(BaseEvent, ABC): class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download""" """Base class for events associated with a bulk image download"""
__event_type__ = EventType.BULK_IMAGE_DOWNLOAD
__event_name__ = "bulk_image_download_event"
bulk_download_id: str = Field(description="The ID of the bulk image download") bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item") bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item") bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadStartedEvent(BulkDownloadEvent): class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Emitted when a bulk image download is started""" """Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_started" __event_name__ = "bulk_download_started"
@ -601,7 +564,7 @@ class BulkDownloadStartedEvent(BulkDownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadCompleteEvent(BulkDownloadEvent): class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Emitted when a bulk image download is started""" """Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_complete" __event_name__ = "bulk_download_complete"
@ -618,7 +581,7 @@ class BulkDownloadCompleteEvent(BulkDownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType] @payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadErrorEvent(BulkDownloadEvent): class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Emitted when a bulk image download is started""" """Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_error" __event_name__ = "bulk_download_error"

View File

@ -7,7 +7,7 @@ from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import ( from invokeai.app.services.events.events_common import (
BaseEvent, EventBase,
) )
from .events_base import EventServiceBase from .events_base import EventServiceBase
@ -16,7 +16,7 @@ from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase): class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None: def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id self.event_handler_id = event_handler_id
self._queue = Queue[BaseEvent | None]() self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event() self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
@ -26,7 +26,7 @@ class FastAPIEventService(EventServiceBase):
self._stop_event.set() self._stop_event.set()
self._queue.put(None) self._queue.put(None)
def dispatch(self, event: BaseEvent) -> None: def dispatch(self, event: EventBase) -> None:
self._queue.put(event) self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event): async def _dispatch_from_queue(self, stop_event: threading.Event):

View File

@ -9,7 +9,7 @@ from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent, BatchEnqueuedEvent,
FastAPIEvent, FastAPIEvent,
QueueClearedEvent, QueueClearedEvent,
QueueEvent, QueueEventBase,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
SessionCanceledEvent, SessionCanceledEvent,
register_events, register_events,
@ -332,7 +332,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None: def _poll_now(self) -> None:
self._poll_now_event.set() self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None: async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
_event_name, payload = event _event_name, payload = event
if ( if (
isinstance(payload, SessionCanceledEvent) isinstance(payload, SessionCanceledEvent)