feat(events): register_events supports single event

This commit is contained in:
psychedelicious 2024-05-26 09:41:03 +10:00
parent c0aabcd8ea
commit 368127bd25
3 changed files with 11 additions and 10 deletions

View File

@ -63,12 +63,13 @@ class FastAPIEventFunc(Protocol):
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
"""Register a function to handle a list of events.
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc) -> None:
"""Register a function to handle specific events.
:param events: A list of event classes to handle
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]

View File

@ -295,10 +295,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
register_events(events={SessionCanceledEvent}, func=self._on_session_canceled)
register_events(events={QueueClearedEvent}, func=self._on_queue_cleared)
register_events(events={BatchEnqueuedEvent}, func=self._on_batch_enqueued)
register_events(events={QueueItemStatusChangedEvent}, func=self._on_queue_item_status_changed)
register_events(SessionCanceledEvent, self._on_session_canceled)
register_events(QueueClearedEvent, self._on_queue_cleared)
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
self._thread_semaphore = BoundedSemaphore(self._thread_limit)

View File

@ -46,9 +46,9 @@ class SqliteSessionQueue(SessionQueueBase):
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
register_events(InvocationErrorEvent, self._handle_error_event)
register_events(SessionCompleteEvent, self._handle_complete_event)
register_events(SessionCanceledEvent, self._handle_cancel_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")