feat(events): use builder pattern for download events

This commit is contained in:
psychedelicious 2024-03-31 12:02:49 +11:00
parent 1f92e9eec2
commit b1e2dd222e
3 changed files with 37 additions and 37 deletions

View File

@ -345,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
self._event_bus.emit_download_started(job)
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
@ -357,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
self._event_bus.emit_download_progress(job)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
@ -375,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
self._event_bus.emit_download_complete(job)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
@ -392,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
self._event_bus.emit_download_cancelled(job)
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
@ -405,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
self._event_bus.emit_download_error(job)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")

View File

@ -35,6 +35,7 @@ from invokeai.app.services.events.events_common import (
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
@ -120,25 +121,25 @@ class EventServiceBase:
# region Download
def emit_download_started(self, source: str, download_path: str) -> None:
def emit_download_started(self, job: "DownloadJob") -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(source, download_path))
self.dispatch(DownloadStartedEvent.build(job))
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
def emit_download_progress(self, job: "DownloadJob") -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes))
self.dispatch(DownloadProgressEvent.build(job))
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
def emit_download_complete(self, job: "DownloadJob") -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(source, download_path, total_bytes))
self.dispatch(DownloadCompleteEvent.build(job))
def emit_download_cancelled(self, source: str) -> None:
def emit_download_cancelled(self, job: "DownloadJob") -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(source))
self.dispatch(DownloadCancelledEvent.build(job))
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
def emit_download_error(self, job: "DownloadJob") -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(source, error_type, error))
self.dispatch(DownloadErrorEvent.build(job))
# endregion

View File

@ -17,6 +17,7 @@ from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
@ -325,8 +326,9 @@ class DownloadStartedEvent(DownloadEventBase):
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, source: str, download_path: str) -> "DownloadStartedEvent":
return cls(source=source, download_path=download_path)
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
@ -340,8 +342,14 @@ class DownloadProgressEvent(DownloadEventBase):
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent":
return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes)
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
@ -354,8 +362,9 @@ class DownloadCompleteEvent(DownloadEventBase):
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent":
return cls(source=source, download_path=download_path, total_bytes=total_bytes)
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
@ -365,8 +374,8 @@ class DownloadCancelledEvent(DownloadEventBase):
__event_name__ = "download_cancelled"
@classmethod
def build(cls, source: str) -> "DownloadCancelledEvent":
return cls(source=source)
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
@ -379,8 +388,10 @@ class DownloadErrorEvent(DownloadEventBase):
error: str = Field(description="The error message")
@classmethod
def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent":
return cls(source=source, error_type=error_type, error=error)
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):