feat(events): simplify event classes

- Remove ABCs, they do not work well with pydantic
- Remove the event type classvar - unused
- Remove clever logic to require an event name - we already get validation for this during schema registration.
- Rename event bases to all end in "Base"
This commit is contained in:
psychedelicious 2024-03-10 23:23:11 +11:00
parent e9043ff060
commit 63e4b224b2
5 changed files with 51 additions and 88 deletions

View File

@ -10,14 +10,14 @@ from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
@ -26,7 +26,7 @@ from invokeai.app.services.events.events_common import (
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
@ -106,14 +106,14 @@ class SocketIO:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]):
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None:
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump())
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None:
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.bulk_download_id)

View File

@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Optional
from invokeai.app.services.events.events_common import (
BaseEvent,
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
@ -14,6 +13,7 @@ from invokeai.app.services.events.events_common import (
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
@ -35,7 +35,7 @@ from invokeai.app.services.events.events_common import (
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import BaseEvent
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
@ -50,7 +50,7 @@ if TYPE_CHECKING:
class EventServiceBase:
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event: "BaseEvent") -> None:
def dispatch(self, event: "EventBase") -> None:
pass
# region: Invocation

View File

@ -1,6 +1,4 @@
from abc import ABC
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
@ -22,39 +20,22 @@ if TYPE_CHECKING:
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventType(str, Enum):
QUEUE = "queue"
MODEL = "model"
DOWNLOAD = "download"
BULK_IMAGE_DOWNLOAD = "bulk_image_download"
class BaseEvent(BaseModel, ABC):
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define the following class attributes:
- `__event_name__: str`: The name of the event
- `__event_type__: EventType`: The type of the event
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType]
__event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
def __init_subclass__(cls, **kwargs: ConfigDict):
for required_attr in ("__event_name__", "__event_type__"):
if getattr(cls, required_attr) is ...:
raise TypeError(f"{cls.__name__} must define {required_attr}")
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
TEvent = TypeVar("TEvent", bound=BaseEvent)
TEvent = TypeVar("TEvent", bound=EventBase)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
@ -77,37 +58,28 @@ def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
local_handler.register(event_name=event.__event_name__, _func=func)
class QueueEvent(BaseEvent, ABC):
class QueueEventBase(EventBase):
"""Base class for queue events"""
__event_type__ = EventType.QUEUE
__event_name__ = "queue_event"
queue_id: str = Field(description="The ID of the queue")
class QueueItemEvent(QueueEvent, ABC):
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
__event_name__ = "queue_item_event"
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class SessionEvent(QueueItemEvent, ABC):
class SessionEventBase(QueueItemEventBase):
"""Base class for session (aka graph execution state) events"""
__event_name__ = "session_event"
session_id: str = Field(description="The ID of the session (aka graph execution state)")
class InvocationEvent(SessionEvent, ABC):
class InvocationEventBase(SessionEventBase):
"""Base class for invocation events"""
__event_name__ = "invocation_event"
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
@ -118,7 +90,7 @@ class InvocationEvent(SessionEvent, ABC):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationStartedEvent(InvocationEvent):
class InvocationStartedEvent(InvocationEventBase):
"""Emitted when an invocation is started"""
__event_name__ = "invocation_started"
@ -137,7 +109,7 @@ class InvocationStartedEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationDenoiseProgressEvent(InvocationEvent):
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Emitted at each step during denoising of an invocation."""
__event_name__ = "invocation_denoise_progress"
@ -170,7 +142,7 @@ class InvocationDenoiseProgressEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationCompleteEvent(InvocationEvent):
class InvocationCompleteEvent(InvocationEventBase):
"""Emitted when an invocation is complete"""
__event_name__ = "invocation_complete"
@ -194,7 +166,7 @@ class InvocationCompleteEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationErrorEvent(InvocationEvent):
class InvocationErrorEvent(InvocationEventBase):
"""Emitted when an invocation encounters an error"""
__event_name__ = "invocation_error"
@ -220,7 +192,7 @@ class InvocationErrorEvent(InvocationEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionStartedEvent(SessionEvent):
class SessionStartedEvent(SessionEventBase):
"""Emitted when a session has started"""
__event_name__ = "session_started"
@ -236,7 +208,7 @@ class SessionStartedEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCompleteEvent(SessionEvent):
class SessionCompleteEvent(SessionEventBase):
"""Emitted when a session has completed all invocations"""
__event_name__ = "session_complete"
@ -252,7 +224,7 @@ class SessionCompleteEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCanceledEvent(SessionEvent):
class SessionCanceledEvent(SessionEventBase):
"""Emitted when a session is canceled"""
__event_name__ = "session_canceled"
@ -268,7 +240,7 @@ class SessionCanceledEvent(SessionEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueItemStatusChangedEvent(QueueItemEvent):
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Emitted when a queue item's status changes"""
__event_name__ = "queue_item_status_changed"
@ -302,7 +274,7 @@ class QueueItemStatusChangedEvent(QueueItemEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BatchEnqueuedEvent(QueueEvent):
class BatchEnqueuedEvent(QueueEventBase):
"""Emitted when a batch is enqueued"""
__event_name__ = "batch_enqueued"
@ -326,7 +298,7 @@ class BatchEnqueuedEvent(QueueEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueClearedEvent(QueueEvent):
class QueueClearedEvent(QueueEventBase):
"""Emitted when a queue is cleared"""
__event_name__ = "queue_cleared"
@ -336,17 +308,14 @@ class QueueClearedEvent(QueueEvent):
return cls(queue_id=queue_id)
class DownloadEvent(BaseEvent, ABC):
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
__event_type__ = EventType.DOWNLOAD
__event_name__ = "download_event"
source: str = Field(description="The source of the download")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadStartedEvent(DownloadEvent):
class DownloadStartedEvent(DownloadEventBase):
"""Emitted when a download is started"""
__event_name__ = "download_started"
@ -359,7 +328,7 @@ class DownloadStartedEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadProgressEvent(DownloadEvent):
class DownloadProgressEvent(DownloadEventBase):
"""Emitted at intervals during a download"""
__event_name__ = "download_progress"
@ -374,7 +343,7 @@ class DownloadProgressEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCompleteEvent(DownloadEvent):
class DownloadCompleteEvent(DownloadEventBase):
"""Emitted when a download is completed"""
__event_name__ = "download_complete"
@ -388,7 +357,7 @@ class DownloadCompleteEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCancelledEvent(DownloadEvent):
class DownloadCancelledEvent(DownloadEventBase):
"""Emitted when a download is cancelled"""
__event_name__ = "download_cancelled"
@ -399,7 +368,7 @@ class DownloadCancelledEvent(DownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadErrorEvent(DownloadEvent):
class DownloadErrorEvent(DownloadEventBase):
"""Emitted when a download encounters an error"""
__event_name__ = "download_error"
@ -412,15 +381,12 @@ class DownloadErrorEvent(DownloadEvent):
return cls(source=source, error_type=error_type, error=error)
class ModelEvent(BaseEvent, ABC):
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
__event_type__ = EventType.MODEL
__event_name__ = "model_event"
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadStartedEvent(ModelEvent):
class ModelLoadStartedEvent(ModelEventBase):
"""Emitted when a model is requested"""
__event_name__ = "model_load_started"
@ -434,7 +400,7 @@ class ModelLoadStartedEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadCompleteEvent(ModelEvent):
class ModelLoadCompleteEvent(ModelEventBase):
"""Emitted when a model is requested"""
__event_name__ = "model_load_complete"
@ -448,7 +414,7 @@ class ModelLoadCompleteEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadProgressEvent(ModelEvent):
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Emitted at intervals while the install job is in progress (remote models only)."""
__event_name__ = "model_install_download_progress"
@ -484,7 +450,7 @@ class ModelInstallDownloadProgressEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadsCompleteEvent(ModelEvent):
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
@ -498,7 +464,7 @@ class ModelInstallDownloadsCompleteEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallStartedEvent(ModelEvent):
class ModelInstallStartedEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_started"
@ -512,7 +478,7 @@ class ModelInstallStartedEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCompleteEvent(ModelEvent):
class ModelInstallCompleteEvent(ModelEventBase):
"""Emitted when an install job is completed successfully."""
__event_name__ = "model_install_complete"
@ -529,7 +495,7 @@ class ModelInstallCompleteEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCancelledEvent(ModelEvent):
class ModelInstallCancelledEvent(ModelEventBase):
"""Emitted when an install job is cancelled."""
__event_name__ = "model_install_cancelled"
@ -543,7 +509,7 @@ class ModelInstallCancelledEvent(ModelEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallErrorEvent(ModelEvent):
class ModelInstallErrorEvent(ModelEventBase):
"""Emitted when an install job encounters an exception."""
__event_name__ = "model_install_error"
@ -560,19 +526,16 @@ class ModelInstallErrorEvent(ModelEvent):
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEvent(BaseEvent, ABC):
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
__event_type__ = EventType.BULK_IMAGE_DOWNLOAD
__event_name__ = "bulk_image_download_event"
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadStartedEvent(BulkDownloadEvent):
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_started"
@ -589,7 +552,7 @@ class BulkDownloadStartedEvent(BulkDownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadCompleteEvent(BulkDownloadEvent):
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_complete"
@ -606,7 +569,7 @@ class BulkDownloadCompleteEvent(BulkDownloadEvent):
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadErrorEvent(BulkDownloadEvent):
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_error"

View File

@ -7,7 +7,7 @@ from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
BaseEvent,
EventBase,
)
from .events_base import EventServiceBase
@ -16,7 +16,7 @@ from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[BaseEvent | None]()
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
@ -26,7 +26,7 @@ class FastAPIEventService(EventServiceBase):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: BaseEvent) -> None:
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):

View File

@ -9,7 +9,7 @@ from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueEvent,
QueueEventBase,
SessionCanceledEvent,
register_events,
)
@ -71,7 +71,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None:
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
_event_name, payload = event
if (
isinstance(payload, SessionCanceledEvent)