mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): add model_install_download_started
event
Previously, we used `model_install_download_progress` for both download starting and progressing. When handling this event, we don't know which actual thing it represents. Add `model_install_download_started` event to explicitly represent a model download started event.
This commit is contained in:
parent
1bc98abc76
commit
fb694b3e17
@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import (
|
|||||||
ModelInstallCompleteEvent,
|
ModelInstallCompleteEvent,
|
||||||
ModelInstallDownloadProgressEvent,
|
ModelInstallDownloadProgressEvent,
|
||||||
ModelInstallDownloadsCompleteEvent,
|
ModelInstallDownloadsCompleteEvent,
|
||||||
|
ModelInstallDownloadStartedEvent,
|
||||||
ModelInstallErrorEvent,
|
ModelInstallErrorEvent,
|
||||||
ModelInstallStartedEvent,
|
ModelInstallStartedEvent,
|
||||||
ModelLoadCompleteEvent,
|
ModelLoadCompleteEvent,
|
||||||
@ -144,6 +145,10 @@ class EventServiceBase:
|
|||||||
|
|
||||||
# region Model install
|
# region Model install
|
||||||
|
|
||||||
|
def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
|
||||||
|
"""Emitted at intervals while the install job is started (remote models only)."""
|
||||||
|
self.dispatch(ModelInstallDownloadStartedEvent.build(job))
|
||||||
|
|
||||||
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
|
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
|
||||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
|
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
|
||||||
|
@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
|
|||||||
return cls(config=config, submodel_type=submodel_type)
|
return cls(config=config, submodel_type=submodel_type)
|
||||||
|
|
||||||
|
|
||||||
|
@payload_schema.register
|
||||||
|
class ModelInstallDownloadStartedEvent(ModelEventBase):
|
||||||
|
"""Event model for model_install_download_started"""
|
||||||
|
|
||||||
|
__event_name__ = "model_install_download_started"
|
||||||
|
|
||||||
|
id: int = Field(description="The ID of the install job")
|
||||||
|
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||||
|
local_path: str = Field(description="Where model is downloading to")
|
||||||
|
bytes: int = Field(description="Number of bytes downloaded so far")
|
||||||
|
total_bytes: int = Field(description="Total size of download, including all files")
|
||||||
|
parts: list[dict[str, int | str]] = Field(
|
||||||
|
description="Progress of downloading URLs that comprise the model, if any"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
|
||||||
|
parts: list[dict[str, str | int]] = [
|
||||||
|
{
|
||||||
|
"url": str(x.source),
|
||||||
|
"local_path": str(x.download_path),
|
||||||
|
"bytes": x.bytes,
|
||||||
|
"total_bytes": x.total_bytes,
|
||||||
|
}
|
||||||
|
for x in job.download_parts
|
||||||
|
]
|
||||||
|
return cls(
|
||||||
|
id=job.id,
|
||||||
|
source=str(job.source),
|
||||||
|
local_path=job.local_path.as_posix(),
|
||||||
|
parts=parts,
|
||||||
|
bytes=job.bytes,
|
||||||
|
total_bytes=job.total_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register
|
@payload_schema.register
|
||||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||||
"""Event model for model_install_download_progress"""
|
"""Event model for model_install_download_progress"""
|
||||||
|
@ -822,7 +822,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
install_job.download_parts = download_job.download_parts
|
install_job.download_parts = download_job.download_parts
|
||||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||||
install_job.total_bytes = download_job.total_bytes
|
install_job.total_bytes = download_job.total_bytes
|
||||||
self._signal_job_downloading(install_job)
|
self._signal_job_download_started(install_job)
|
||||||
|
|
||||||
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@ -874,6 +874,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
self._event_bus.emit_model_install_started(job)
|
self._event_bus.emit_model_install_started(job)
|
||||||
|
|
||||||
|
def _signal_job_download_started(self, job: ModelInstallJob) -> None:
|
||||||
|
if self._event_bus:
|
||||||
|
assert job._multifile_job is not None
|
||||||
|
assert job.bytes is not None
|
||||||
|
assert job.total_bytes is not None
|
||||||
|
self._event_bus.emit_model_install_download_started(job)
|
||||||
|
|
||||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
assert job._multifile_job is not None
|
assert job._multifile_job is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user