add tests for model install file size reporting

This commit is contained in:
Lincoln Stein 2024-05-16 07:18:33 -04:00
parent f29c406fed
commit 911a24479b
2 changed files with 11 additions and 3 deletions

View File

@ -855,8 +855,8 @@ class ModelInstallService(ModelInstallServiceBase):
str(job.source), str(job.source),
local_path=job.local_path.as_posix(), local_path=job.local_path.as_posix(),
parts=parts, parts=parts,
bytes=job.bytes, bytes=sum(x["bytes"] for x in parts),
total_bytes=job.total_bytes, total_bytes=sum(x["total_bytes"] for x in parts),
id=job.id, id=job.id,
) )
@ -875,7 +875,10 @@ class ModelInstallService(ModelInstallServiceBase):
assert job.local_path is not None assert job.local_path is not None
assert job.config_out is not None assert job.config_out is not None
key = job.config_out.key key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id) self._event_bus.emit_model_install_completed(source=str(job.source),
key=key,
id=job.id,
total_bytes=job.bytes)
def _signal_job_errored(self, job: ModelInstallJob) -> None: def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")

View File

@ -317,6 +317,11 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
"model_install_completed", "model_install_completed",
} }
completed_events = [x for x in bus.events if x.event_name == "model_install_completed"]
downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"]
assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"]
assert job.total_bytes == completed_events[0].payload["total_bytes"]
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))