InvokeAI/invokeai/app/services/events/events_base.py
psychedelicious 2dc752ea83 feat(events): simplify event classes
- Remove ABCs, they do not work well with pydantic
- Remove the event type classvar - unused
- Remove clever logic to require an event name - we already get validation for this during schema registration.
- Rename event bases to all end in "Base"
2024-05-27 09:06:02 +10:00

189 lines
6.9 KiB
Python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import TYPE_CHECKING, Optional
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
)
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
class EventServiceBase:
    """Basic event bus that can act as an empty stand-in when events are not needed.

    `dispatch` does nothing in this class. Every `emit_*` helper constructs the
    matching event object by calling that event class's `build(...)` with the
    helper's arguments (in the same order) and then hands the result to
    `dispatch`.
    """

    def dispatch(self, event: "EventBase") -> None:
        """Accept an event and do nothing with it."""
        return None

    # region Invocation

    def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
        """Build an `InvocationStartedEvent` and dispatch it."""
        event = InvocationStartedEvent.build(queue_item, invocation)
        self.dispatch(event)

    def emit_invocation_denoise_progress(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        step: int,
        total_steps: int,
        progress_image: "ProgressImage",
    ) -> None:
        """Build an `InvocationDenoiseProgressEvent` and dispatch it."""
        event = InvocationDenoiseProgressEvent.build(
            queue_item,
            invocation,
            step,
            total_steps,
            progress_image,
        )
        self.dispatch(event)

    def emit_invocation_complete(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        output: "BaseInvocationOutput",
    ) -> None:
        """Build an `InvocationCompleteEvent` and dispatch it."""
        event = InvocationCompleteEvent.build(queue_item, invocation, output)
        self.dispatch(event)

    def emit_invocation_error(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Build an `InvocationErrorEvent` and dispatch it."""
        event = InvocationErrorEvent.build(
            queue_item,
            invocation,
            error_type,
            error_message,
            error_traceback,
        )
        self.dispatch(event)

    # endregion

    # region Session

    def emit_session_started(self, queue_item: "SessionQueueItem") -> None:
        """Build a `SessionStartedEvent` and dispatch it."""
        event = SessionStartedEvent.build(queue_item)
        self.dispatch(event)

    def emit_session_complete(self, queue_item: "SessionQueueItem") -> None:
        """Build a `SessionCompleteEvent` and dispatch it."""
        event = SessionCompleteEvent.build(queue_item)
        self.dispatch(event)

    def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None:
        """Build a `SessionCanceledEvent` and dispatch it."""
        event = SessionCanceledEvent.build(queue_item)
        self.dispatch(event)

    # endregion

    # region Queue

    def emit_queue_item_status_changed(
        self,
        queue_item: "SessionQueueItem",
        batch_status: "BatchStatus",
        queue_status: "SessionQueueStatus",
    ) -> None:
        """Build a `QueueItemStatusChangedEvent` and dispatch it."""
        event = QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)
        self.dispatch(event)

    def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
        """Build a `BatchEnqueuedEvent` and dispatch it."""
        event = BatchEnqueuedEvent.build(enqueue_result)
        self.dispatch(event)

    def emit_queue_cleared(self, queue_id: str) -> None:
        """Build a `QueueClearedEvent` and dispatch it."""
        event = QueueClearedEvent.build(queue_id)
        self.dispatch(event)

    # endregion

    # region Download

    def emit_download_started(self, source: str, download_path: str) -> None:
        """Build a `DownloadStartedEvent` and dispatch it."""
        event = DownloadStartedEvent.build(source, download_path)
        self.dispatch(event)

    def emit_download_progress(
        self,
        source: str,
        download_path: str,
        current_bytes: int,
        total_bytes: int,
    ) -> None:
        """Build a `DownloadProgressEvent` and dispatch it."""
        event = DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes)
        self.dispatch(event)

    def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
        """Build a `DownloadCompleteEvent` and dispatch it."""
        event = DownloadCompleteEvent.build(source, download_path, total_bytes)
        self.dispatch(event)

    def emit_download_cancelled(self, source: str) -> None:
        """Build a `DownloadCancelledEvent` and dispatch it."""
        event = DownloadCancelledEvent.build(source)
        self.dispatch(event)

    def emit_download_error(self, source: str, error_type: str, error: str) -> None:
        """Build a `DownloadErrorEvent` and dispatch it."""
        event = DownloadErrorEvent.build(source, error_type, error)
        self.dispatch(event)

    # endregion

    # region Model loading

    def emit_model_load_started(
        self,
        config: "AnyModelConfig",
        submodel_type: Optional["SubModelType"] = None,
    ) -> None:
        """Build a `ModelLoadStartedEvent` and dispatch it."""
        event = ModelLoadStartedEvent.build(config, submodel_type)
        self.dispatch(event)

    def emit_model_load_complete(
        self,
        config: "AnyModelConfig",
        submodel_type: Optional["SubModelType"] = None,
    ) -> None:
        """Build a `ModelLoadCompleteEvent` and dispatch it."""
        event = ModelLoadCompleteEvent.build(config, submodel_type)
        self.dispatch(event)

    # endregion

    # region Model install

    def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
        """Build a `ModelInstallDownloadProgressEvent` and dispatch it."""
        event = ModelInstallDownloadProgressEvent.build(job)
        self.dispatch(event)

    def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
        """Build a `ModelInstallDownloadsCompleteEvent` and dispatch it."""
        event = ModelInstallDownloadsCompleteEvent.build(job)
        self.dispatch(event)

    def emit_model_install_started(self, job: "ModelInstallJob") -> None:
        """Build a `ModelInstallStartedEvent` and dispatch it."""
        event = ModelInstallStartedEvent.build(job)
        self.dispatch(event)

    def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
        """Build a `ModelInstallCompleteEvent` and dispatch it."""
        event = ModelInstallCompleteEvent.build(job)
        self.dispatch(event)

    def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
        """Build a `ModelInstallCancelledEvent` and dispatch it."""
        event = ModelInstallCancelledEvent.build(job)
        self.dispatch(event)

    def emit_model_install_error(self, job: "ModelInstallJob") -> None:
        """Build a `ModelInstallErrorEvent` and dispatch it."""
        event = ModelInstallErrorEvent.build(job)
        self.dispatch(event)

    # endregion

    # region Bulk image download

    def emit_bulk_download_started(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
    ) -> None:
        """Build a `BulkDownloadStartedEvent` and dispatch it."""
        event = BulkDownloadStartedEvent.build(
            bulk_download_id,
            bulk_download_item_id,
            bulk_download_item_name,
        )
        self.dispatch(event)

    def emit_bulk_download_complete(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
    ) -> None:
        """Build a `BulkDownloadCompleteEvent` and dispatch it."""
        event = BulkDownloadCompleteEvent.build(
            bulk_download_id,
            bulk_download_item_id,
            bulk_download_item_name,
        )
        self.dispatch(event)

    def emit_bulk_download_error(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
        error: str,
    ) -> None:
        """Build a `BulkDownloadErrorEvent` and dispatch it."""
        event = BulkDownloadErrorEvent.build(
            bulk_download_id,
            bulk_download_item_id,
            bulk_download_item_name,
            error,
        )
        self.dispatch(event)

    # endregion