mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fb694b3e17
Previously, we used `model_install_download_progress` for both download starting and progressing. When handling this event, we don't know which actual thing it represents. Add `model_install_download_started` event to explicitly represent a model download started event.
200 lines
7.8 KiB
Python
200 lines
7.8 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from invokeai.app.services.events.events_common import (
|
|
BatchEnqueuedEvent,
|
|
BulkDownloadCompleteEvent,
|
|
BulkDownloadErrorEvent,
|
|
BulkDownloadStartedEvent,
|
|
DownloadCancelledEvent,
|
|
DownloadCompleteEvent,
|
|
DownloadErrorEvent,
|
|
DownloadProgressEvent,
|
|
DownloadStartedEvent,
|
|
EventBase,
|
|
InvocationCompleteEvent,
|
|
InvocationDenoiseProgressEvent,
|
|
InvocationErrorEvent,
|
|
InvocationStartedEvent,
|
|
ModelInstallCancelledEvent,
|
|
ModelInstallCompleteEvent,
|
|
ModelInstallDownloadProgressEvent,
|
|
ModelInstallDownloadsCompleteEvent,
|
|
ModelInstallDownloadStartedEvent,
|
|
ModelInstallErrorEvent,
|
|
ModelInstallStartedEvent,
|
|
ModelLoadCompleteEvent,
|
|
ModelLoadStartedEvent,
|
|
QueueClearedEvent,
|
|
QueueItemStatusChangedEvent,
|
|
)
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
|
from invokeai.app.services.download.download_base import DownloadJob
|
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
|
from invokeai.app.services.session_queue.session_queue_common import (
|
|
BatchStatus,
|
|
EnqueueBatchResult,
|
|
SessionQueueItem,
|
|
SessionQueueStatus,
|
|
)
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
|
|
|
|
|
class EventServiceBase:
    """Basic event bus, to have an empty stand-in when not needed.

    `dispatch` is a no-op in this base class. Every `emit_*` helper builds the
    matching event with that event class's `build` factory and hands the result
    to `dispatch`, so a subclass only has to override `dispatch` to deliver events.
    """

    def dispatch(self, event: "EventBase") -> None:
        """Deliver `event`. The base implementation intentionally does nothing."""

    # region Invocation

    def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
        """Emitted when an invocation is started"""
        self.dispatch(InvocationStartedEvent.build(queue_item, invocation))

    def emit_invocation_denoise_progress(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        intermediate_state: "PipelineIntermediateState",
        progress_image: "ProgressImage",
    ) -> None:
        """Emitted at each step during denoising of an invocation."""
        self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))

    def emit_invocation_complete(
        self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
    ) -> None:
        """Emitted when an invocation is complete"""
        self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))

    def emit_invocation_error(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Emitted when an invocation encounters an error"""
        self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))

    # endregion

    # region Queue

    def emit_queue_item_status_changed(
        self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
    ) -> None:
        """Emitted when a queue item's status changes"""
        self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))

    def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
        """Emitted when a batch is enqueued"""
        self.dispatch(BatchEnqueuedEvent.build(enqueue_result))

    def emit_queue_cleared(self, queue_id: str) -> None:
        """Emitted when a queue is cleared"""
        self.dispatch(QueueClearedEvent.build(queue_id))

    # endregion

    # region Download

    def emit_download_started(self, job: "DownloadJob") -> None:
        """Emitted when a download is started"""
        self.dispatch(DownloadStartedEvent.build(job))

    def emit_download_progress(self, job: "DownloadJob") -> None:
        """Emitted at intervals during a download"""
        self.dispatch(DownloadProgressEvent.build(job))

    def emit_download_complete(self, job: "DownloadJob") -> None:
        """Emitted when a download is completed"""
        self.dispatch(DownloadCompleteEvent.build(job))

    def emit_download_cancelled(self, job: "DownloadJob") -> None:
        """Emitted when a download is cancelled"""
        self.dispatch(DownloadCancelledEvent.build(job))

    def emit_download_error(self, job: "DownloadJob") -> None:
        """Emitted when a download encounters an error"""
        self.dispatch(DownloadErrorEvent.build(job))

    # endregion

    # region Model loading

    def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
        """Emitted when a model load is started."""
        self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))

    def emit_model_load_complete(
        self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
    ) -> None:
        """Emitted when a model load is complete."""
        self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))

    # endregion

    # region Model install

    def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
        """Emitted when the download for an install job is started (remote models only)."""
        # Separate from the progress event so handlers can tell "started" apart from "progressing".
        self.dispatch(ModelInstallDownloadStartedEvent.build(job))

    def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
        """Emitted at intervals while the install job is in progress (remote models only)."""
        self.dispatch(ModelInstallDownloadProgressEvent.build(job))

    def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
        """Emitted when the downloads for an install job are complete."""
        self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))

    def emit_model_install_started(self, job: "ModelInstallJob") -> None:
        """Emitted once when an install job is started (after any download)."""
        self.dispatch(ModelInstallStartedEvent.build(job))

    def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
        """Emitted when an install job is completed successfully."""
        self.dispatch(ModelInstallCompleteEvent.build(job))

    def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
        """Emitted when an install job is cancelled."""
        self.dispatch(ModelInstallCancelledEvent.build(job))

    def emit_model_install_error(self, job: "ModelInstallJob") -> None:
        """Emitted when an install job encounters an exception."""
        self.dispatch(ModelInstallErrorEvent.build(job))

    # endregion

    # region Bulk image download

    def emit_bulk_download_started(
        self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
    ) -> None:
        """Emitted when a bulk image download is started"""
        self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))

    def emit_bulk_download_complete(
        self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
    ) -> None:
        """Emitted when a bulk image download is complete"""
        self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))

    def emit_bulk_download_error(
        self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
    ) -> None:
        """Emitted when a bulk image download has an error"""
        self.dispatch(
            BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
        )

    # endregion