mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
084cf26ed6
There's no longer any need for session-scoped events now that we have the session queue. Session started/completed/canceled events map 1-to-1 to queue item status events, but queue item status events also have an event for the failed state. We can simplify queue and processor handling substantially by removing session events and instead using queue item events.

- Remove the session-scoped events entirely.
- Remove all event handling from the session queue. The processor still needs to respond to some events from the queue: `QueueClearedEvent`, `BatchEnqueuedEvent` and `QueueItemStatusChangedEvent`.
- Pass an `is_canceled` callback to the invocation context instead of the cancel event.
- Update processor logic to ensure the local instance of the current queue item is synced with the instance in the database. This prevents race conditions and ensures lifecycle callbacks do not receive a stale queue item.
- Update docstrings and comments.
- Add a `complete_queue_item` method to the session queue service as an explicit way to mark a queue item as successfully completed. Previously, the queue listened for session complete events to do this.

Closes #6442
196 lines · 7.6 KiB · Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from invokeai.app.services.events.events_common import (
|
|
BatchEnqueuedEvent,
|
|
BulkDownloadCompleteEvent,
|
|
BulkDownloadErrorEvent,
|
|
BulkDownloadStartedEvent,
|
|
DownloadCancelledEvent,
|
|
DownloadCompleteEvent,
|
|
DownloadErrorEvent,
|
|
DownloadProgressEvent,
|
|
DownloadStartedEvent,
|
|
EventBase,
|
|
InvocationCompleteEvent,
|
|
InvocationDenoiseProgressEvent,
|
|
InvocationErrorEvent,
|
|
InvocationStartedEvent,
|
|
ModelInstallCancelledEvent,
|
|
ModelInstallCompleteEvent,
|
|
ModelInstallDownloadProgressEvent,
|
|
ModelInstallDownloadsCompleteEvent,
|
|
ModelInstallErrorEvent,
|
|
ModelInstallStartedEvent,
|
|
ModelLoadCompleteEvent,
|
|
ModelLoadStartedEvent,
|
|
QueueClearedEvent,
|
|
QueueItemStatusChangedEvent,
|
|
)
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
|
from invokeai.app.services.download.download_base import DownloadJob
|
|
from invokeai.app.services.events.events_common import EventBase
|
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
|
from invokeai.app.services.session_queue.session_queue_common import (
|
|
BatchStatus,
|
|
EnqueueBatchResult,
|
|
SessionQueueItem,
|
|
SessionQueueStatus,
|
|
)
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
|
|
|
|
|
class EventServiceBase:
    """Basic event bus, to have an empty stand-in when not needed.

    Every ``emit_*`` method builds the matching event object via its ``build``
    classmethod and routes it through :meth:`dispatch`. In this base class
    ``dispatch`` discards the event, so an instance can stand in wherever an
    event service is required but nothing needs to be delivered. A subclass
    only needs to override ``dispatch`` to deliver events; the ``emit_*``
    helpers all funnel through it unchanged.
    """

    def dispatch(self, event: "EventBase") -> None:
        """Deliver ``event``. The base implementation is a no-op and drops it."""

    # region Invocation

    def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
        """Emitted when an invocation is started"""
        self.dispatch(InvocationStartedEvent.build(queue_item, invocation))

    def emit_invocation_denoise_progress(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        # Quoted like every other project type here, so defining this method
        # does not require the name to be resolvable at runtime.
        intermediate_state: "PipelineIntermediateState",
        progress_image: "ProgressImage",
    ) -> None:
        """Emitted at each step during denoising of an invocation.

        ``intermediate_state`` and ``progress_image`` are forwarded unchanged to
        ``InvocationDenoiseProgressEvent.build``.
        """
        self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))

    def emit_invocation_complete(
        self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
    ) -> None:
        """Emitted when an invocation is complete"""
        self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))

    def emit_invocation_error(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Emitted when an invocation encounters an error.

        The three ``error_*`` strings are passed through verbatim to
        ``InvocationErrorEvent.build``.
        """
        self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))

    # endregion

    # region Queue

    def emit_queue_item_status_changed(
        self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
    ) -> None:
        """Emitted when a queue item's status changes.

        The batch and queue status snapshots supplied by the caller are bundled
        into the event alongside the queue item.
        """
        self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))

    def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
        """Emitted when a batch is enqueued"""
        self.dispatch(BatchEnqueuedEvent.build(enqueue_result))

    def emit_queue_cleared(self, queue_id: str) -> None:
        """Emitted when a queue is cleared"""
        self.dispatch(QueueClearedEvent.build(queue_id))

    # endregion

    # region Download

    def emit_download_started(self, job: "DownloadJob") -> None:
        """Emitted when a download is started"""
        self.dispatch(DownloadStartedEvent.build(job))

    def emit_download_progress(self, job: "DownloadJob") -> None:
        """Emitted at intervals during a download"""
        self.dispatch(DownloadProgressEvent.build(job))

    def emit_download_complete(self, job: "DownloadJob") -> None:
        """Emitted when a download is completed"""
        self.dispatch(DownloadCompleteEvent.build(job))

    def emit_download_cancelled(self, job: "DownloadJob") -> None:
        """Emitted when a download is cancelled"""
        self.dispatch(DownloadCancelledEvent.build(job))

    def emit_download_error(self, job: "DownloadJob") -> None:
        """Emitted when a download encounters an error"""
        self.dispatch(DownloadErrorEvent.build(job))

    # endregion

    # region Model loading

    def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
        """Emitted when a model load is started.

        ``submodel_type`` is ``None`` unless a specific submodel is being loaded.
        """
        self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))

    def emit_model_load_complete(
        self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
    ) -> None:
        """Emitted when a model load is complete.

        ``submodel_type`` is ``None`` unless a specific submodel was loaded.
        """
        self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))

    # endregion

    # region Model install

    def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
        """Emitted at intervals while the install job is in progress (remote models only)."""
        self.dispatch(ModelInstallDownloadProgressEvent.build(job))

    def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
        """Emitted when all downloads for an install job are complete, before the install itself starts."""
        self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))

    def emit_model_install_started(self, job: "ModelInstallJob") -> None:
        """Emitted once when an install job is started (after any download)."""
        self.dispatch(ModelInstallStartedEvent.build(job))

    def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
        """Emitted when an install job is completed successfully."""
        self.dispatch(ModelInstallCompleteEvent.build(job))

    def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
        """Emitted when an install job is cancelled."""
        self.dispatch(ModelInstallCancelledEvent.build(job))

    def emit_model_install_error(self, job: "ModelInstallJob") -> None:
        """Emitted when an install job encounters an exception."""
        self.dispatch(ModelInstallErrorEvent.build(job))

    # endregion

    # region Bulk image download

    def emit_bulk_download_started(
        self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
    ) -> None:
        """Emitted when a bulk image download is started"""
        self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))

    def emit_bulk_download_complete(
        self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
    ) -> None:
        """Emitted when a bulk image download is complete"""
        self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))

    def emit_bulk_download_error(
        self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
    ) -> None:
        """Emitted when a bulk image download has an error.

        ``error`` is a caller-supplied description that is forwarded verbatim.
        """
        self.dispatch(
            BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
        )

    # endregion
|