mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor: remove all session events
There's no longer any need for session-scoped events now that we have the session queue. Session started/completed/canceled events map 1-to-1 to queue item status events, and queue item status events additionally cover the failed state. We can simplify queue and processor handling substantially by removing session events and using queue item events instead.
- Remove the session-scoped events entirely.
- Remove all event handling from the session queue. The processor still needs to respond to some events from the queue: `QueueClearedEvent`, `BatchEnqueuedEvent` and `QueueItemStatusChangedEvent`.
- Pass an `is_canceled` callback to the invocation context instead of the cancel event.
- Update processor logic to ensure the local instance of the current queue item is synced with the instance in the database. This prevents race conditions and ensures lifecycle callbacks do not receive stale queue item data.
- Update docstrings and comments.
- Add a `complete_queue_item` method to the session queue service as an explicit way to mark a queue item as successfully completed. Previously, the queue listened for session complete events to do this.
Closes #6442
This commit is contained in:
parent
8592f5c6e1
commit
084cf26ed6
@ -34,9 +34,6 @@ from invokeai.app.services.events.events_common import (
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
register_events,
|
||||
)
|
||||
|
||||
@ -54,9 +51,6 @@ QUEUE_EVENTS = {
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
SessionStartedEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionCanceledEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
BatchEnqueuedEvent,
|
||||
QueueClearedEvent,
|
||||
|
@ -28,9 +28,6 @@ from invokeai.app.services.events.events_common import (
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
@ -90,22 +87,6 @@ class EventServiceBase:
|
||||
|
||||
# endregion
|
||||
|
||||
# region Session
|
||||
|
||||
def emit_session_started(self, queue_item: "SessionQueueItem") -> None:
|
||||
"""Emitted when a session has started"""
|
||||
self.dispatch(SessionStartedEvent.build(queue_item))
|
||||
|
||||
def emit_session_complete(self, queue_item: "SessionQueueItem") -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.dispatch(SessionCompleteEvent.build(queue_item))
|
||||
|
||||
def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.dispatch(SessionCanceledEvent.build(queue_item))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Queue
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
|
@ -88,15 +88,10 @@ class QueueItemEventBase(QueueEventBase):
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
|
||||
|
||||
class SessionEventBase(QueueItemEventBase):
|
||||
"""Base class for session (aka graph execution state) events"""
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
|
||||
|
||||
class InvocationEventBase(SessionEventBase):
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
"""Base class for invocation events"""
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
@ -231,51 +226,6 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
)
|
||||
|
||||
|
||||
class SessionStartedEvent(SessionEventBase):
|
||||
"""Event model for session_started"""
|
||||
|
||||
__event_name__ = "session_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
)
|
||||
|
||||
|
||||
class SessionCompleteEvent(SessionEventBase):
|
||||
"""Event model for session_complete"""
|
||||
|
||||
__event_name__ = "session_complete"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
)
|
||||
|
||||
|
||||
class SessionCanceledEvent(SessionEventBase):
|
||||
"""Event model for session_canceled"""
|
||||
|
||||
__event_name__ = "session_canceled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
)
|
||||
|
||||
|
||||
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
"""Event model for queue_item_status_changed"""
|
||||
|
||||
|
@ -10,7 +10,6 @@ from invokeai.app.services.events.events_common import (
|
||||
FastAPIEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
@ -64,6 +63,11 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
|
||||
def _is_canceled(self) -> bool:
    """Report whether this runner's cancel event has been set.

    The bound method doubles as the `is_canceled` callback handed to the invocation context builder, so
    long-running nodes (for example the denoising step callback) can poll it to find out whether the session
    has been canceled.

    Returns:
        True if the cancel event is set, False otherwise.
    """
    cancel_event = self._cancel_event
    return cancel_event.is_set()
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
|
||||
|
||||
@ -87,13 +91,19 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
break
|
||||
|
||||
if invocation is None or self._cancel_event.is_set():
|
||||
if invocation is None or self._is_canceled():
|
||||
break
|
||||
|
||||
self.run_node(invocation, queue_item)
|
||||
|
||||
# The session is complete if all invocations have been run or there is an error on the session.
|
||||
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
|
||||
# use the cancel event to check if the session is canceled.
|
||||
if (
|
||||
queue_item.session.is_complete()
|
||||
or self._is_canceled()
|
||||
or queue_item.status in ["failed", "canceled", "completed"]
|
||||
):
|
||||
break
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
@ -112,7 +122,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._services,
|
||||
cancel_event=self._cancel_event,
|
||||
is_canceled=self._is_canceled,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
@ -126,16 +136,12 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
|
||||
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
|
||||
# correctly in the next iteration of the session runner loop.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
|
||||
# handle cancellation.
|
||||
pass
|
||||
except Exception as e:
|
||||
error_type = e.__class__.__name__
|
||||
@ -156,8 +162,6 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||
)
|
||||
|
||||
self._services.events.emit_session_started(queue_item)
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=queue_item.session_id)
|
||||
@ -186,9 +190,10 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# while the session is running.
|
||||
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
|
||||
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
|
||||
# Send complete event. The events service will receive this and update the queue item's status.
|
||||
self._services.events.emit_session_complete(queue_item=queue_item)
|
||||
# The queue item may have been canceled or failed while the session was running. We should only complete it
|
||||
# if it is not already canceled or failed.
|
||||
if queue_item.status not in ["canceled", "failed"]:
|
||||
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
|
||||
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
@ -251,6 +256,12 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
self._services.logger.error(error_traceback)
|
||||
|
||||
# Fail the queue item
|
||||
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
queue_item = self._services.session_queue.fail_queue_item(
|
||||
queue_item.item_id, error_type, error_message, error_traceback
|
||||
)
|
||||
|
||||
# Send error event
|
||||
self._services.events.emit_invocation_error(
|
||||
queue_item=queue_item,
|
||||
@ -295,7 +306,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
|
||||
register_events(SessionCanceledEvent, self._on_session_canceled)
|
||||
register_events(QueueClearedEvent, self._on_queue_cleared)
|
||||
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
|
||||
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
|
||||
@ -333,11 +343,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_session_canceled(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
|
||||
if self._queue_item and self._queue_item.item_id == event[1].item_id:
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
|
||||
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
|
||||
self._cancel_event.set()
|
||||
@ -348,6 +353,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
|
||||
#
|
||||
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
|
||||
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
|
||||
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
|
||||
if event[1].status == "canceled":
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
@ -441,10 +455,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.error(error_traceback)
|
||||
|
||||
if queue_item is not None:
|
||||
# Update the queue item with the completed session
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
# Fail the queue item
|
||||
self._invoker.services.session_queue.fail_queue_item(
|
||||
# Update the queue item with the completed session & fail it
|
||||
queue_item = self._invoker.services.session_queue.set_queue_item_session(
|
||||
queue_item.item_id, queue_item.session
|
||||
)
|
||||
queue_item = self._invoker.services.session_queue.fail_queue_item(
|
||||
item_id=queue_item.item_id,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
|
@ -73,6 +73,11 @@ class SessionQueueBase(ABC):
|
||||
"""Gets the status of a batch"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
    """Mark a session queue item as completed.

    Args:
        item_id: The ID of the queue item to complete.

    Returns:
        The queue item, as a `SessionQueueItem`.
    """
|
||||
|
||||
@abstractmethod
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Cancels a session queue item"""
|
||||
|
@ -2,13 +2,6 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
FastAPIEvent,
|
||||
InvocationErrorEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@ -46,10 +39,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
|
||||
register_events(InvocationErrorEvent, self._handle_error_event)
|
||||
register_events(SessionCompleteEvent, self._handle_complete_event)
|
||||
register_events(SessionCanceledEvent, self._handle_cancel_event)
|
||||
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
@ -59,38 +48,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
|
||||
try:
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(event[1].item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
self._set_queue_item_status(item_id=event[1].item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
|
||||
try:
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
self._set_queue_item_status(
|
||||
item_id=event[1].item_id,
|
||||
status="failed",
|
||||
error_type=event[1].error_type,
|
||||
error_message=event[1].error_message,
|
||||
error_traceback=event[1].error_traceback,
|
||||
)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
|
||||
try:
|
||||
queue_item = self.get_queue_item(event[1].item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
self._set_queue_item_status(item_id=event[1].item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
@ -400,10 +357,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
self.__invoker.services.events.emit_session_canceled(queue_item)
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
return queue_item
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
    """Set the status of the queue item `item_id` to "completed" and return the item.

    The status is written unconditionally: this method does not look at the item's current status first, so a
    caller that must not overwrite a canceled or failed item has to check before calling.

    Args:
        item_id: The ID of the queue item to complete.

    Returns:
        Whatever `_set_queue_item_status` returns for the item (annotated as `SessionQueueItem`).
    """
    return self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
|
||||
def fail_queue_item(
|
||||
@ -413,16 +371,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(queue_item)
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
@ -458,7 +413,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
@ -502,7 +456,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
|
@ -1,7 +1,6 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@ -449,10 +448,10 @@ class ConfigInterface(InvocationContextInterface):
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
def __init__(
|
||||
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
|
||||
self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
|
||||
) -> None:
|
||||
super().__init__(services, data)
|
||||
self._cancel_event = cancel_event
|
||||
self._is_canceled = is_canceled
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current session has been canceled.
|
||||
@ -460,7 +459,7 @@ class UtilInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._cancel_event.is_set()
|
||||
return self._is_canceled()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
"""
|
||||
@ -535,7 +534,7 @@ class InvocationContext:
|
||||
def build_invocation_context(
|
||||
services: InvocationServices,
|
||||
data: InvocationContextData,
|
||||
cancel_event: threading.Event,
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> InvocationContext:
|
||||
"""Builds the invocation context for a specific invocation execution.
|
||||
|
||||
@ -552,7 +551,7 @@ def build_invocation_context(
|
||||
tensors = TensorsInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data)
|
||||
config = ConfigInterface(services=services, data=data)
|
||||
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
|
||||
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
|
||||
conditioning = ConditioningInterface(services=services, data=data)
|
||||
boards = BoardsInterface(services=services, data=data)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user