From 2dc752ea833381bc36d8c2e838b13e1342b3d462 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 10 Mar 2024 23:23:11 +1100 Subject: [PATCH] feat(events): simplify event classes - Remove ABCs, they do not work well with pydantic - Remove the event type classvar - unused - Remove clever logic to require an event name - we already get validation for this during schema registration. - Rename event bases to all end in "Base" --- invokeai/app/api/sockets.py | 12 +- invokeai/app/services/events/events_base.py | 6 +- invokeai/app/services/events/events_common.py | 111 ++++++------------ .../services/events/events_fastapievents.py | 6 +- .../session_processor_default.py | 4 +- 5 files changed, 51 insertions(+), 88 deletions(-) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 3245fa0921..697eda9217 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -10,14 +10,14 @@ from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent, - BulkDownloadEvent, + BulkDownloadEventBase, BulkDownloadStartedEvent, FastAPIEvent, InvocationCompleteEvent, InvocationDenoiseProgressEvent, InvocationErrorEvent, InvocationStartedEvent, - ModelEvent, + ModelEventBase, ModelInstallCancelledEvent, ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, @@ -26,7 +26,7 @@ from invokeai.app.services.events.events_common import ( ModelLoadCompleteEvent, ModelLoadStartedEvent, QueueClearedEvent, - QueueEvent, + QueueEventBase, QueueItemStatusChangedEvent, SessionCanceledEvent, SessionCompleteEvent, @@ -106,14 +106,14 @@ class SocketIO: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) - async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): + async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): event_name, payload = event await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) - async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: + async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None: event_name, payload = event await self._sio.emit(event=event_name, data=payload.model_dump()) - async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None: + async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: event_name, payload = event await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.bulk_download_id) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index b489f29fd7..6287753d93 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional from invokeai.app.services.events.events_common import ( - BaseEvent, BatchEnqueuedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent, @@ -14,6 +13,7 @@ from invokeai.app.services.events.events_common import ( DownloadErrorEvent, DownloadProgressEvent, DownloadStartedEvent, + EventBase, InvocationCompleteEvent, InvocationDenoiseProgressEvent, InvocationErrorEvent, @@ -35,7 +35,7 @@ from invokeai.app.services.events.events_common import ( if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput - from invokeai.app.services.events.events_common import BaseEvent + from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( @@ -50,7 +50,7 @@ if TYPE_CHECKING: class EventServiceBase: """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event: "BaseEvent") -> None: + def dispatch(self, event: "EventBase") -> None: pass # region: Invocation diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index f7264e2341..66152cd98f 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -1,6 +1,4 @@ -from abc import ABC -from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema @@ -22,39 +20,22 @@ if TYPE_CHECKING: from invokeai.app.services.model_install.model_install_common import ModelInstallJob -class EventType(str, Enum): - QUEUE = "queue" - MODEL = "model" - DOWNLOAD = "download" - BULK_IMAGE_DOWNLOAD = "bulk_image_download" - - -class BaseEvent(BaseModel, ABC): +class EventBase(BaseModel): """Base class for all events. All events must inherit from this class. - Events must define the following class attributes: - - `__event_name__: str`: The name of the event - - `__event_type__: EventType`: The type of the event + Events must define a class attribute `__event_name__` to identify the event. All other attributes should be defined as normal for a pydantic model. A timestamp is automatically added to the event when it is created. """ - __event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType] - __event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType] - timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) - def __init_subclass__(cls, **kwargs: ConfigDict): - for required_attr in ("__event_name__", "__event_type__"): - if getattr(cls, required_attr) is ...: - raise TypeError(f"{cls.__name__} must define {required_attr}") - model_config = ConfigDict(json_schema_serialization_defaults_required=True) -TEvent = TypeVar("TEvent", bound=BaseEvent) +TEvent = TypeVar("TEvent", bound=EventBase) FastAPIEvent: TypeAlias = tuple[str, TEvent] """ @@ -77,37 +58,28 @@ def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None: local_handler.register(event_name=event.__event_name__, _func=func) -class QueueEvent(BaseEvent, ABC): +class QueueEventBase(EventBase): """Base class for queue events""" - __event_type__ = EventType.QUEUE - __event_name__ = "queue_event" - queue_id: str = Field(description="The ID of the queue") -class QueueItemEvent(QueueEvent, ABC): +class QueueItemEventBase(QueueEventBase): """Base class for queue item events""" - __event_name__ = "queue_item_event" - item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") -class SessionEvent(QueueItemEvent, ABC): +class SessionEventBase(QueueItemEventBase): """Base class for session (aka graph execution state) events""" - __event_name__ = "session_event" - session_id: str = Field(description="The ID of the session (aka graph execution state)") -class InvocationEvent(SessionEvent, ABC): +class InvocationEventBase(SessionEventBase): """Base class for invocation events""" - __event_name__ = "invocation_event" - queue_id: str = Field(description="The ID of the queue") item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") @@ -118,7 +90,7 @@ class InvocationEvent(SessionEvent, ABC): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class InvocationStartedEvent(InvocationEvent): +class InvocationStartedEvent(InvocationEventBase): """Emitted when an invocation is started""" __event_name__ = "invocation_started" @@ -137,7 +109,7 @@ class InvocationStartedEvent(InvocationEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class InvocationDenoiseProgressEvent(InvocationEvent): +class InvocationDenoiseProgressEvent(InvocationEventBase): """Emitted at each step during denoising of an invocation.""" __event_name__ = "invocation_denoise_progress" @@ -170,7 +142,7 @@ class InvocationDenoiseProgressEvent(InvocationEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class InvocationCompleteEvent(InvocationEvent): +class InvocationCompleteEvent(InvocationEventBase): """Emitted when an invocation is complete""" __event_name__ = "invocation_complete" @@ -194,7 +166,7 @@ class InvocationCompleteEvent(InvocationEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class InvocationErrorEvent(InvocationEvent): +class InvocationErrorEvent(InvocationEventBase): """Emitted when an invocation encounters an error""" __event_name__ = "invocation_error" @@ -227,7 +199,7 @@ class InvocationErrorEvent(InvocationEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class SessionStartedEvent(SessionEvent): +class SessionStartedEvent(SessionEventBase): """Emitted when a session has started""" __event_name__ = "session_started" @@ -243,7 +215,7 @@ class SessionStartedEvent(SessionEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class SessionCompleteEvent(SessionEvent): +class SessionCompleteEvent(SessionEventBase): """Emitted when a session has completed all invocations""" __event_name__ = "session_complete" @@ -259,7 +231,7 @@ class SessionCompleteEvent(SessionEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class SessionCanceledEvent(SessionEvent): +class SessionCanceledEvent(SessionEventBase): """Emitted when a session is canceled""" __event_name__ = "session_canceled" @@ -275,7 +247,7 @@ class SessionCanceledEvent(SessionEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class QueueItemStatusChangedEvent(SessionEvent): +class QueueItemStatusChangedEvent(SessionEventBase): """Emitted when a queue item's status changes""" __event_name__ = "queue_item_status_changed" @@ -314,7 +286,7 @@ class QueueItemStatusChangedEvent(SessionEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class BatchEnqueuedEvent(QueueEvent): +class BatchEnqueuedEvent(QueueEventBase): """Emitted when a batch is enqueued""" __event_name__ = "batch_enqueued" @@ -338,7 +310,7 @@ class BatchEnqueuedEvent(QueueEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class QueueClearedEvent(QueueEvent): +class QueueClearedEvent(QueueEventBase): """Emitted when a queue is cleared""" __event_name__ = "queue_cleared" @@ -348,17 +320,14 @@ class QueueClearedEvent(QueueEvent): return cls(queue_id=queue_id) -class DownloadEvent(BaseEvent, ABC): +class DownloadEventBase(EventBase): """Base class for events associated with a download""" - __event_type__ = EventType.DOWNLOAD - __event_name__ = "download_event" - source: str = Field(description="The source of the download") @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class DownloadStartedEvent(DownloadEvent): +class DownloadStartedEvent(DownloadEventBase): """Emitted when a download is started""" __event_name__ = "download_started" @@ -371,7 +340,7 @@ class DownloadStartedEvent(DownloadEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class DownloadProgressEvent(DownloadEvent): +class DownloadProgressEvent(DownloadEventBase): """Emitted at intervals during a download""" __event_name__ = "download_progress" @@ -386,7 +355,7 @@ class DownloadProgressEvent(DownloadEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class DownloadCompleteEvent(DownloadEvent): +class DownloadCompleteEvent(DownloadEventBase): """Emitted when a download is completed""" __event_name__ = "download_complete" @@ -400,7 +369,7 @@ class DownloadCompleteEvent(DownloadEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class DownloadCancelledEvent(DownloadEvent): +class DownloadCancelledEvent(DownloadEventBase): """Emitted when a download is cancelled""" __event_name__ = "download_cancelled" @@ -411,7 +380,7 @@ class DownloadCancelledEvent(DownloadEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class DownloadErrorEvent(DownloadEvent): +class DownloadErrorEvent(DownloadEventBase): """Emitted when a download encounters an error""" __event_name__ = "download_error" @@ -424,15 +393,12 @@ class DownloadErrorEvent(DownloadEvent): return cls(source=source, error_type=error_type, error=error) -class ModelEvent(BaseEvent, ABC): +class ModelEventBase(EventBase): """Base class for events associated with a model""" - __event_type__ = EventType.MODEL - __event_name__ = "model_event" - @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelLoadStartedEvent(ModelEvent): +class ModelLoadStartedEvent(ModelEventBase): """Emitted when a model is requested""" __event_name__ = "model_load_started" @@ -446,7 +412,7 @@ class ModelLoadStartedEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelLoadCompleteEvent(ModelEvent): +class ModelLoadCompleteEvent(ModelEventBase): """Emitted when a model is requested""" __event_name__ = "model_load_complete" @@ -460,7 +426,7 @@ class ModelLoadCompleteEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelInstallDownloadProgressEvent(ModelEvent): +class ModelInstallDownloadProgressEvent(ModelEventBase): """Emitted at intervals while the install job is in progress (remote models only).""" __event_name__ = "model_install_download_progress" @@ -496,7 +462,7 @@ class ModelInstallDownloadProgressEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelInstallDownloadsCompleteEvent(ModelEvent): +class ModelInstallDownloadsCompleteEvent(ModelEventBase): """Emitted once when an install job becomes active.""" __event_name__ = "model_install_downloads_complete" @@ -510,7 +476,7 @@ class ModelInstallDownloadsCompleteEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelInstallStartedEvent(ModelEvent): +class ModelInstallStartedEvent(ModelEventBase): """Emitted once when an install job becomes active.""" __event_name__ = "model_install_started" @@ -524,7 +490,7 @@ class ModelInstallStartedEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelInstallCompleteEvent(ModelEvent): +class ModelInstallCompleteEvent(ModelEventBase): """Emitted when an install job is completed successfully.""" __event_name__ = "model_install_complete" @@ -541,7 +507,7 @@ class ModelInstallCompleteEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelInstallCancelledEvent(ModelEvent): +class ModelInstallCancelledEvent(ModelEventBase): """Emitted when an install job is cancelled.""" __event_name__ = "model_install_cancelled" @@ -555,7 +521,7 @@ class ModelInstallCancelledEvent(ModelEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class ModelInstallErrorEvent(ModelEvent): +class ModelInstallErrorEvent(ModelEventBase): """Emitted when an install job encounters an exception.""" __event_name__ = "model_install_error" @@ -572,19 +538,16 @@ class ModelInstallErrorEvent(ModelEvent): return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) -class BulkDownloadEvent(BaseEvent, ABC): +class BulkDownloadEventBase(EventBase): """Base class for events associated with a bulk image download""" - __event_type__ = EventType.BULK_IMAGE_DOWNLOAD - __event_name__ = "bulk_image_download_event" - bulk_download_id: str = Field(description="The ID of the bulk image download") bulk_download_item_id: str = Field(description="The ID of the bulk image download item") bulk_download_item_name: str = Field(description="The name of the bulk image download item") @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class BulkDownloadStartedEvent(BulkDownloadEvent): +class BulkDownloadStartedEvent(BulkDownloadEventBase): """Emitted when a bulk image download is started""" __event_name__ = "bulk_download_started" @@ -601,7 +564,7 @@ class BulkDownloadStartedEvent(BulkDownloadEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class BulkDownloadCompleteEvent(BulkDownloadEvent): +class BulkDownloadCompleteEvent(BulkDownloadEventBase): """Emitted when a bulk image download is started""" __event_name__ = "bulk_download_complete" @@ -618,7 +581,7 @@ class BulkDownloadCompleteEvent(BulkDownloadEvent): @payload_schema.register # pyright: ignore [reportUnknownMemberType] -class BulkDownloadErrorEvent(BulkDownloadEvent): +class BulkDownloadErrorEvent(BulkDownloadEventBase): """Emitted when a bulk image download is started""" __event_name__ = "bulk_download_error" diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py index a8317911cf..09f2a2f8b1 100644 --- a/invokeai/app/services/events/events_fastapievents.py +++ b/invokeai/app/services/events/events_fastapievents.py @@ -7,7 +7,7 @@ from queue import Empty, Queue from fastapi_events.dispatcher import dispatch from invokeai.app.services.events.events_common import ( - BaseEvent, + EventBase, ) from .events_base import EventServiceBase @@ -16,7 +16,7 @@ from .events_base import EventServiceBase class FastAPIEventService(EventServiceBase): def __init__(self, event_handler_id: int) -> None: self.event_handler_id = event_handler_id - self._queue = Queue[BaseEvent | None]() + self._queue = Queue[EventBase | None]() self._stop_event = threading.Event() asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) @@ -26,7 +26,7 @@ class FastAPIEventService(EventServiceBase): self._stop_event.set() self._queue.put(None) - def dispatch(self, event: BaseEvent) -> None: + def dispatch(self, event: EventBase) -> None: self._queue.put(event) async def _dispatch_from_queue(self, stop_event: threading.Event): diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index ef455d00b0..7f9ce0b41a 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -9,7 +9,7 @@ from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, FastAPIEvent, QueueClearedEvent, - QueueEvent, + QueueEventBase, QueueItemStatusChangedEvent, SessionCanceledEvent, register_events, @@ -332,7 +332,7 @@ class DefaultSessionProcessor(SessionProcessorBase): def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None: + async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None: _event_name, payload = event if ( isinstance(payload, SessionCanceledEvent)