diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index cf49cc0626..bb578c23e8 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import ( ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, ModelInstallDownloadsCompleteEvent, + ModelInstallDownloadStartedEvent, ModelInstallErrorEvent, ModelInstallStartedEvent, ModelLoadCompleteEvent, @@ -144,6 +145,10 @@ class EventServiceBase: # region Model install + def emit_model_install_download_started(self, job: "ModelInstallJob") -> None: + """Emitted at intervals while the install job is started (remote models only).""" + self.dispatch(ModelInstallDownloadStartedEvent.build(job)) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: """Emitted at intervals while the install job is in progress (remote models only).""" self.dispatch(ModelInstallDownloadProgressEvent.build(job)) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0adcaa2ab1..c6a867fb08 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase): return cls(config=config, submodel_type=submodel_type) +@payload_schema.register +class ModelInstallDownloadStartedEvent(ModelEventBase): + """Event model for model_install_download_started""" + + __event_name__ = "model_install_download_started" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent": + parts: list[dict[str, str | int]] = [ + { + "url": str(x.source), + "local_path": str(x.download_path), + "bytes": x.bytes, + "total_bytes": x.total_bytes, + } + for x in job.download_parts + ] + return cls( + id=job.id, + source=str(job.source), + local_path=job.local_path.as_posix(), + parts=parts, + bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + @payload_schema.register class ModelInstallDownloadProgressEvent(ModelEventBase): """Event model for model_install_download_progress""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 0a2e2d798a..dd1b44d899 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -822,7 +822,7 @@ class ModelInstallService(ModelInstallServiceBase): install_job.download_parts = download_job.download_parts install_job.bytes = sum(x.bytes for x in download_job.download_parts) install_job.total_bytes = download_job.total_bytes - self._signal_job_downloading(install_job) + self._signal_job_download_started(install_job) def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: @@ -874,6 +874,13 @@ class ModelInstallService(ModelInstallServiceBase): if self._event_bus: self._event_bus.emit_model_install_started(job) + def _signal_job_download_started(self, job: ModelInstallJob) -> None: + if self._event_bus: + assert job._multifile_job is not None + assert job.bytes is not None + assert job.total_bytes is not None + self._event_bus.emit_model_install_download_started(job) + def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: assert job._multifile_job is not None