# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import TYPE_CHECKING, Optional

from invokeai.app.services.events.events_common import (
    BatchEnqueuedEvent,
    BulkDownloadCompleteEvent,
    BulkDownloadErrorEvent,
    BulkDownloadStartedEvent,
    DownloadCancelledEvent,
    DownloadCompleteEvent,
    DownloadErrorEvent,
    DownloadProgressEvent,
    DownloadStartedEvent,
    EventBase,
    InvocationCompleteEvent,
    InvocationDenoiseProgressEvent,
    InvocationErrorEvent,
    InvocationStartedEvent,
    ModelInstallCancelledEvent,
    ModelInstallCompleteEvent,
    ModelInstallDownloadProgressEvent,
    ModelInstallDownloadsCompleteEvent,
    ModelInstallErrorEvent,
    ModelInstallStartedEvent,
    ModelLoadCompleteEvent,
    ModelLoadStartedEvent,
    QueueClearedEvent,
    QueueItemStatusChangedEvent,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

if TYPE_CHECKING:
    from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
    from invokeai.app.services.download.download_base import DownloadJob
    from invokeai.app.services.events.events_common import EventBase
    from invokeai.app.services.model_install.model_install_common import ModelInstallJob
    from invokeai.app.services.session_processor.session_processor_common import ProgressImage
    from invokeai.app.services.session_queue.session_queue_common import (
        BatchStatus,
        EnqueueBatchResult,
        SessionQueueItem,
        SessionQueueStatus,
    )
    from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType


class EventServiceBase:
    """Base event bus whose `dispatch` does nothing.

    It can be used as an empty stand-in when no event delivery is needed.
    Every `emit_*` helper builds one typed event through that event class's
    `build` constructor and hands the result to `dispatch`.
    """

    def dispatch(self, event: "EventBase") -> None:
        """Accept an event and discard it (this base implementation is a no-op)."""
        return None

    # region Invocation

    def emit_invocation_started(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
    ) -> None:
        """Build and dispatch the event emitted when an invocation is started."""
        event = InvocationStartedEvent.build(queue_item, invocation)
        self.dispatch(event)

    def emit_invocation_denoise_progress(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        intermediate_state: PipelineIntermediateState,
        progress_image: "ProgressImage",
    ) -> None:
        """Build and dispatch the event emitted at each denoising step of an invocation."""
        event = InvocationDenoiseProgressEvent.build(
            queue_item,
            invocation,
            intermediate_state,
            progress_image,
        )
        self.dispatch(event)

    def emit_invocation_complete(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        output: "BaseInvocationOutput",
    ) -> None:
        """Build and dispatch the event emitted when an invocation is complete."""
        event = InvocationCompleteEvent.build(queue_item, invocation, output)
        self.dispatch(event)

    def emit_invocation_error(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Build and dispatch the event emitted when an invocation encounters an error."""
        event = InvocationErrorEvent.build(
            queue_item,
            invocation,
            error_type,
            error_message,
            error_traceback,
        )
        self.dispatch(event)

    # endregion

    # region Queue

    def emit_queue_item_status_changed(
        self,
        queue_item: "SessionQueueItem",
        batch_status: "BatchStatus",
        queue_status: "SessionQueueStatus",
    ) -> None:
        """Build and dispatch the event emitted when a queue item's status changes."""
        event = QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)
        self.dispatch(event)

    def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
        """Build and dispatch the event emitted when a batch is enqueued."""
        event = BatchEnqueuedEvent.build(enqueue_result)
        self.dispatch(event)

    def emit_queue_cleared(self, queue_id: str) -> None:
        """Build and dispatch the event emitted when a queue is cleared."""
        event = QueueClearedEvent.build(queue_id)
        self.dispatch(event)

    # endregion

    # region Download

    def emit_download_started(self, job: "DownloadJob") -> None:
        """Build and dispatch the event emitted when a download is started."""
        event = DownloadStartedEvent.build(job)
        self.dispatch(event)

    def emit_download_progress(self, job: "DownloadJob") -> None:
        """Build and dispatch the event emitted at intervals during a download."""
        event = DownloadProgressEvent.build(job)
        self.dispatch(event)

    def emit_download_complete(self, job: "DownloadJob") -> None:
        """Build and dispatch the event emitted when a download is completed."""
        event = DownloadCompleteEvent.build(job)
        self.dispatch(event)

    def emit_download_cancelled(self, job: "DownloadJob") -> None:
        """Build and dispatch the event emitted when a download is cancelled."""
        event = DownloadCancelledEvent.build(job)
        self.dispatch(event)

    def emit_download_error(self, job: "DownloadJob") -> None:
        """Build and dispatch the event emitted when a download encounters an error."""
        event = DownloadErrorEvent.build(job)
        self.dispatch(event)

    # endregion

    # region Model loading

    def emit_model_load_started(
        self,
        config: "AnyModelConfig",
        submodel_type: Optional["SubModelType"] = None,
    ) -> None:
        """Build and dispatch the event emitted when a model load is started."""
        event = ModelLoadStartedEvent.build(config, submodel_type)
        self.dispatch(event)

    def emit_model_load_complete(
        self,
        config: "AnyModelConfig",
        submodel_type: Optional["SubModelType"] = None,
    ) -> None:
        """Build and dispatch the event emitted when a model load is complete."""
        event = ModelLoadCompleteEvent.build(config, submodel_type)
        self.dispatch(event)

    # endregion

    # region Model install

    def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
        """Build and dispatch the progress event for an install job.

        Emitted at intervals while the install job is in progress (remote models only).
        """
        event = ModelInstallDownloadProgressEvent.build(job)
        self.dispatch(event)

    def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
        """Build and dispatch a `ModelInstallDownloadsCompleteEvent` for the install job.

        NOTE(review): presumably emitted once all of the job's downloads have
        finished -- confirm against the caller.
        """
        event = ModelInstallDownloadsCompleteEvent.build(job)
        self.dispatch(event)

    def emit_model_install_started(self, job: "ModelInstallJob") -> None:
        """Build and dispatch the event emitted once when an install job is started (after any download)."""
        event = ModelInstallStartedEvent.build(job)
        self.dispatch(event)

    def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
        """Build and dispatch the event emitted when an install job is completed successfully."""
        event = ModelInstallCompleteEvent.build(job)
        self.dispatch(event)

    def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
        """Build and dispatch the event emitted when an install job is cancelled."""
        event = ModelInstallCancelledEvent.build(job)
        self.dispatch(event)

    def emit_model_install_error(self, job: "ModelInstallJob") -> None:
        """Build and dispatch the event emitted when an install job encounters an exception."""
        event = ModelInstallErrorEvent.build(job)
        self.dispatch(event)

    # endregion

    # region Bulk image download

    def emit_bulk_download_started(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
    ) -> None:
        """Build and dispatch the event emitted when a bulk image download is started."""
        event = BulkDownloadStartedEvent.build(
            bulk_download_id,
            bulk_download_item_id,
            bulk_download_item_name,
        )
        self.dispatch(event)

    def emit_bulk_download_complete(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
    ) -> None:
        """Build and dispatch the event emitted when a bulk image download is complete."""
        event = BulkDownloadCompleteEvent.build(
            bulk_download_id,
            bulk_download_item_id,
            bulk_download_item_name,
        )
        self.dispatch(event)

    def emit_bulk_download_error(
        self,
        bulk_download_id: str,
        bulk_download_item_id: str,
        bulk_download_item_name: str,
        error: str,
    ) -> None:
        """Build and dispatch the event emitted when a bulk image download has an error."""
        event = BulkDownloadErrorEvent.build(
            bulk_download_id,
            bulk_download_item_id,
            bulk_download_item_name,
            error,
        )
        self.dispatch(event)

    # endregion