mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
9bd78823a3
Our event handling and implementation have a couple of pain points:

- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be written by hand. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly. This allows us to eliminate a layer of indirection and some unpleasant complexity:

- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are passed as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
189 lines
6.9 KiB
Python
189 lines
6.9 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from invokeai.app.services.events.events_common import (
|
|
BaseEvent,
|
|
BatchEnqueuedEvent,
|
|
BulkDownloadCompleteEvent,
|
|
BulkDownloadErrorEvent,
|
|
BulkDownloadStartedEvent,
|
|
DownloadCancelledEvent,
|
|
DownloadCompleteEvent,
|
|
DownloadErrorEvent,
|
|
DownloadProgressEvent,
|
|
DownloadStartedEvent,
|
|
InvocationCompleteEvent,
|
|
InvocationDenoiseProgressEvent,
|
|
InvocationErrorEvent,
|
|
InvocationStartedEvent,
|
|
ModelInstallCancelledEvent,
|
|
ModelInstallCompleteEvent,
|
|
ModelInstallDownloadProgressEvent,
|
|
ModelInstallDownloadsCompleteEvent,
|
|
ModelInstallErrorEvent,
|
|
ModelInstallStartedEvent,
|
|
ModelLoadCompleteEvent,
|
|
ModelLoadStartedEvent,
|
|
QueueClearedEvent,
|
|
QueueItemStatusChangedEvent,
|
|
SessionCanceledEvent,
|
|
SessionCompleteEvent,
|
|
SessionStartedEvent,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
|
from invokeai.app.services.events.events_common import BaseEvent
|
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
|
from invokeai.app.services.session_queue.session_queue_common import (
|
|
BatchStatus,
|
|
EnqueueBatchResult,
|
|
SessionQueueItem,
|
|
SessionQueueStatus,
|
|
)
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
|
|
|
|
|
class EventServiceBase:
    """Stand-in event bus that builds events and then drops them.

    Each ``emit_*`` helper constructs a typed event model through that event
    class's ``build`` classmethod and forwards the result to :meth:`dispatch`.
    In this base class :meth:`dispatch` does nothing, so an instance can be used
    wherever an event service is required but no events need to be delivered.
    A subclass can override :meth:`dispatch` to deliver the events somewhere.
    """

    def dispatch(self, event: "BaseEvent") -> None:
        """Accept ``event`` and discard it; the base implementation does nothing."""
        return None

    # region Invocation

    def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
        """Dispatch an ``InvocationStartedEvent`` built from ``queue_item`` and ``invocation``."""
        event = InvocationStartedEvent.build(queue_item, invocation)
        self.dispatch(event)

    def emit_invocation_denoise_progress(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        step: int,
        total_steps: int,
        progress_image: "ProgressImage",
    ) -> None:
        """Dispatch an ``InvocationDenoiseProgressEvent``.

        All arguments are forwarded positionally, in signature order, to
        ``InvocationDenoiseProgressEvent.build``.
        """
        event = InvocationDenoiseProgressEvent.build(queue_item, invocation, step, total_steps, progress_image)
        self.dispatch(event)

    def emit_invocation_complete(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        output: "BaseInvocationOutput",
    ) -> None:
        """Dispatch an ``InvocationCompleteEvent`` built from the queue item, invocation and ``output``."""
        event = InvocationCompleteEvent.build(queue_item, invocation, output)
        self.dispatch(event)

    def emit_invocation_error(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Dispatch an ``InvocationErrorEvent``.

        The three error strings are passed through unchanged, after the queue
        item and invocation, to ``InvocationErrorEvent.build``.
        """
        event = InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback)
        self.dispatch(event)

    # endregion

    # region Session

    def emit_session_started(self, queue_item: "SessionQueueItem") -> None:
        """Dispatch a ``SessionStartedEvent`` built from ``queue_item``."""
        event = SessionStartedEvent.build(queue_item)
        self.dispatch(event)

    def emit_session_complete(self, queue_item: "SessionQueueItem") -> None:
        """Dispatch a ``SessionCompleteEvent`` built from ``queue_item``."""
        event = SessionCompleteEvent.build(queue_item)
        self.dispatch(event)

    def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None:
        """Dispatch a ``SessionCanceledEvent`` built from ``queue_item``."""
        event = SessionCanceledEvent.build(queue_item)
        self.dispatch(event)

    # endregion

    # region Queue

    def emit_queue_item_status_changed(
        self,
        queue_item: "SessionQueueItem",
        batch_status: "BatchStatus",
        queue_status: "SessionQueueStatus",
    ) -> None:
        """Dispatch a ``QueueItemStatusChangedEvent`` built from the item, batch status and queue status."""
        event = QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)
        self.dispatch(event)

    def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
        """Dispatch a ``BatchEnqueuedEvent`` built from ``enqueue_result``."""
        event = BatchEnqueuedEvent.build(enqueue_result)
        self.dispatch(event)

    def emit_queue_cleared(self, queue_id: str) -> None:
        """Dispatch a ``QueueClearedEvent`` for the queue identified by ``queue_id``."""
        event = QueueClearedEvent.build(queue_id)
        self.dispatch(event)

    # endregion

    # region Download

    def emit_download_started(self, source: str, download_path: str) -> None:
        """Dispatch a ``DownloadStartedEvent`` built from ``source`` and ``download_path``."""
        event = DownloadStartedEvent.build(source, download_path)
        self.dispatch(event)

    def emit_download_progress(
        self,
        source: str,
        download_path: str,
        current_bytes: int,
        total_bytes: int,
    ) -> None:
        """Dispatch a ``DownloadProgressEvent`` carrying the byte counters for ``source``."""
        event = DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes)
        self.dispatch(event)

    def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
        """Dispatch a ``DownloadCompleteEvent`` built from ``source``, ``download_path`` and ``total_bytes``."""
        event = DownloadCompleteEvent.build(source, download_path, total_bytes)
        self.dispatch(event)

    def emit_download_cancelled(self, source: str) -> None:
        """Dispatch a ``DownloadCancelledEvent`` for ``source``."""
        event = DownloadCancelledEvent.build(source)
        self.dispatch(event)

    def emit_download_error(self, source: str, error_type: str, error: str) -> None:
        """Dispatch a ``DownloadErrorEvent`` for ``source`` with the given error type and text."""
        event = DownloadErrorEvent.build(source, error_type, error)
        self.dispatch(event)

    # endregion

    # region Model loading

    def emit_model_load_started(
        self,
        config: "AnyModelConfig",
        submodel_type: Optional["SubModelType"] = None,
    ) -> None:
        """Dispatch a ``ModelLoadStartedEvent`` for ``config``.

        ``submodel_type`` defaults to ``None`` and is forwarded as-is to ``build``.
        """
        event = ModelLoadStartedEvent.build(config, submodel_type)
        self.dispatch(event)

    def emit_model_load_complete(
        self,
        config: "AnyModelConfig",
        submodel_type: Optional["SubModelType"] = None,
    ) -> None:
        """Dispatch a ``ModelLoadCompleteEvent`` for ``config``.

        ``submodel_type`` defaults to ``None`` and is forwarded as-is to ``build``.
        """
        event = ModelLoadCompleteEvent.build(config, submodel_type)
        self.dispatch(event)

    # endregion

    # region Model install

    def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
        """Dispatch a ``ModelInstallDownloadProgressEvent`` built from ``job``."""
        event = ModelInstallDownloadProgressEvent.build(job)
        self.dispatch(event)

    def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
        """Dispatch a ``ModelInstallDownloadsCompleteEvent`` built from ``job``."""
        event = ModelInstallDownloadsCompleteEvent.build(job)
        self.dispatch(event)

    def emit_model_install_started(self, job: "ModelInstallJob") -> None:
        """Dispatch a ``ModelInstallStartedEvent`` built from ``job``."""
        event = ModelInstallStartedEvent.build(job)
        self.dispatch(event)

    def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
        """Dispatch a ``ModelInstallCompleteEvent`` built from ``job``."""
        event = ModelInstallCompleteEvent.build(job)
        self.dispatch(event)

    def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
        """Dispatch a ``ModelInstallCancelledEvent`` built from ``job``."""
        event = ModelInstallCancelledEvent.build(job)
        self.dispatch(event)

    def emit_model_install_error(self, job: "ModelInstallJob") -> None:
        """Dispatch a ``ModelInstallErrorEvent`` built from ``job``."""
        event = ModelInstallErrorEvent.build(job)
        self.dispatch(event)

    # endregion

    # region Bulk image download

    def emit_bulk_download_started(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
    ) -> None:
        """Dispatch a ``BulkDownloadStartedEvent`` built from the bulk-download identifiers."""
        event = BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
        self.dispatch(event)

    def emit_bulk_download_complete(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
    ) -> None:
        """Dispatch a ``BulkDownloadCompleteEvent`` built from the bulk-download identifiers."""
        event = BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
        self.dispatch(event)

    def emit_bulk_download_error(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
        error: str,
    ) -> None:
        """Dispatch a ``BulkDownloadErrorEvent`` built from the bulk-download identifiers and ``error``."""
        event = BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
        self.dispatch(event)

    # endregion
|