mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(events): use builder pattern for download events
This commit is contained in:
parent
1f92e9eec2
commit
b1e2dd222e
@ -345,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
self._event_bus.emit_download_started(job)
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
@ -357,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
self._event_bus.emit_download_progress(job)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
@ -375,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
|
||||
)
|
||||
self._event_bus.emit_download_complete(job)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
@ -392,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
@ -405,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
self._event_bus.emit_download_error(job)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
|
@ -35,6 +35,7 @@ from invokeai.app.services.events.events_common import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
@ -120,25 +121,25 @@ class EventServiceBase:
|
||||
|
||||
# region Download
|
||||
|
||||
def emit_download_started(self, source: str, download_path: str) -> None:
|
||||
def emit_download_started(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is started"""
|
||||
self.dispatch(DownloadStartedEvent.build(source, download_path))
|
||||
self.dispatch(DownloadStartedEvent.build(job))
|
||||
|
||||
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
|
||||
def emit_download_progress(self, job: "DownloadJob") -> None:
|
||||
"""Emitted at intervals during a download"""
|
||||
self.dispatch(DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes))
|
||||
self.dispatch(DownloadProgressEvent.build(job))
|
||||
|
||||
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
|
||||
def emit_download_complete(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is completed"""
|
||||
self.dispatch(DownloadCompleteEvent.build(source, download_path, total_bytes))
|
||||
self.dispatch(DownloadCompleteEvent.build(job))
|
||||
|
||||
def emit_download_cancelled(self, source: str) -> None:
|
||||
def emit_download_cancelled(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is cancelled"""
|
||||
self.dispatch(DownloadCancelledEvent.build(source))
|
||||
self.dispatch(DownloadCancelledEvent.build(job))
|
||||
|
||||
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
|
||||
def emit_download_error(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download encounters an error"""
|
||||
self.dispatch(DownloadErrorEvent.build(source, error_type, error))
|
||||
self.dispatch(DownloadErrorEvent.build(job))
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -17,6 +17,7 @@ from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
|
||||
|
||||
@ -325,8 +326,9 @@ class DownloadStartedEvent(DownloadEventBase):
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
|
||||
@classmethod
|
||||
def build(cls, source: str, download_path: str) -> "DownloadStartedEvent":
|
||||
return cls(source=source, download_path=download_path)
|
||||
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix())
|
||||
|
||||
|
||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||
@ -340,8 +342,14 @@ class DownloadProgressEvent(DownloadEventBase):
|
||||
total_bytes: int = Field(description="The total number of bytes to be downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent":
|
||||
return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes)
|
||||
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||
@ -354,8 +362,9 @@ class DownloadCompleteEvent(DownloadEventBase):
|
||||
total_bytes: int = Field(description="The total number of bytes downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent":
|
||||
return cls(source=source, download_path=download_path, total_bytes=total_bytes)
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
|
||||
|
||||
|
||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||
@ -365,8 +374,8 @@ class DownloadCancelledEvent(DownloadEventBase):
|
||||
__event_name__ = "download_cancelled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, source: str) -> "DownloadCancelledEvent":
|
||||
return cls(source=source)
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
|
||||
return cls(source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||
@ -379,8 +388,10 @@ class DownloadErrorEvent(DownloadEventBase):
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent":
|
||||
return cls(source=source, error_type=error_type, error=error)
|
||||
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
|
Loading…
Reference in New Issue
Block a user