diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 4c9d9bda13..180f0f1a8c 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -345,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase): f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}" ) if self._event_bus: - assert job.download_path - self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix()) + self._event_bus.emit_download_started(job) def _signal_job_progress(self, job: DownloadJob) -> None: if job.on_progress: @@ -357,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase): f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}" ) if self._event_bus: - assert job.download_path - self._event_bus.emit_download_progress( - str(job.source), - download_path=job.download_path.as_posix(), - current_bytes=job.bytes, - total_bytes=job.total_bytes, - ) + self._event_bus.emit_download_progress(job) def _signal_job_complete(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.COMPLETED @@ -375,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase): f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}" ) if self._event_bus: - assert job.download_path - self._event_bus.emit_download_complete( - str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes - ) + self._event_bus.emit_download_complete(job) def _signal_job_cancelled(self, job: DownloadJob) -> None: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: @@ -392,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase): f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}" ) if self._event_bus: - self._event_bus.emit_download_cancelled(str(job.source)) + self._event_bus.emit_download_cancelled(job) def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None: job.status = DownloadJobStatus.ERROR @@ -405,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase): f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}" ) if self._event_bus: - assert job.error_type - assert job.error - self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error) + self._event_bus.emit_download_error(job) def _cleanup_cancelled_job(self, job: DownloadJob) -> None: self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}") diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index eb64856bb5..b055d4249f 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -35,6 +35,7 @@ from invokeai.app.services.events.events_common import ( if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput + from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.session_processor.session_processor_common import ProgressImage @@ -120,25 +121,25 @@ class EventServiceBase: # region Download - def emit_download_started(self, source: str, download_path: str) -> None: + def emit_download_started(self, job: "DownloadJob") -> None: """Emitted when a download is started""" - self.dispatch(DownloadStartedEvent.build(source, download_path)) + self.dispatch(DownloadStartedEvent.build(job)) - def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: + def emit_download_progress(self, job: "DownloadJob") -> None: """Emitted at intervals during a download""" - self.dispatch(DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes)) + self.dispatch(DownloadProgressEvent.build(job)) - def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: + def emit_download_complete(self, job: "DownloadJob") -> None: """Emitted when a download is completed""" - self.dispatch(DownloadCompleteEvent.build(source, download_path, total_bytes)) + self.dispatch(DownloadCompleteEvent.build(job)) - def emit_download_cancelled(self, source: str) -> None: + def emit_download_cancelled(self, job: "DownloadJob") -> None: """Emitted when a download is cancelled""" - self.dispatch(DownloadCancelledEvent.build(source)) + self.dispatch(DownloadCancelledEvent.build(job)) - def emit_download_error(self, source: str, error_type: str, error: str) -> None: + def emit_download_error(self, job: "DownloadJob") -> None: """Emitted when a download encounters an error""" - self.dispatch(DownloadErrorEvent.build(source, error_type, error)) + self.dispatch(DownloadErrorEvent.build(job)) # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 2d75f0a2a4..e007b5a0d3 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -17,6 +17,7 @@ from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType if TYPE_CHECKING: + from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob @@ -325,8 +326,9 @@ class DownloadStartedEvent(DownloadEventBase): download_path: str = Field(description="The local path where the download is saved") @classmethod - def build(cls, source: str, download_path: str) -> "DownloadStartedEvent": - return cls(source=source, download_path=download_path) + def build(cls, job: "DownloadJob") -> "DownloadStartedEvent": + assert job.download_path + return cls(source=str(job.source), download_path=job.download_path.as_posix()) @payload_schema.register # pyright: ignore [reportUnknownMemberType] @@ -340,8 +342,14 @@ class DownloadProgressEvent(DownloadEventBase): total_bytes: int = Field(description="The total number of bytes to be downloaded") @classmethod - def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent": - return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes) + def build(cls, job: "DownloadJob") -> "DownloadProgressEvent": + assert job.download_path + return cls( + source=str(job.source), + download_path=job.download_path.as_posix(), + current_bytes=job.bytes, + total_bytes=job.total_bytes, + ) @payload_schema.register # pyright: ignore [reportUnknownMemberType] @@ -354,8 +362,9 @@ class DownloadCompleteEvent(DownloadEventBase): total_bytes: int = Field(description="The total number of bytes downloaded") @classmethod - def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent": - return cls(source=source, download_path=download_path, total_bytes=total_bytes) + def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent": + assert job.download_path + return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes) @payload_schema.register # pyright: ignore [reportUnknownMemberType] @@ -365,8 +374,8 @@ class DownloadCancelledEvent(DownloadEventBase): __event_name__ = "download_cancelled" @classmethod - def build(cls, source: str) -> "DownloadCancelledEvent": - return cls(source=source) + def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent": + return cls(source=str(job.source)) @payload_schema.register # pyright: ignore [reportUnknownMemberType] @@ -379,8 +388,10 @@ class DownloadErrorEvent(DownloadEventBase): error: str = Field(description="The error message") @classmethod - def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent": - return cls(source=source, error_type=error_type, error=error) + def build(cls, job: "DownloadJob") -> "DownloadErrorEvent": + assert job.error_type + assert job.error + return cls(source=str(job.source), error_type=job.error_type, error=job.error) class ModelEventBase(EventBase):