mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add tests for model install file size reporting
This commit is contained in:
parent
f29c406fed
commit
911a24479b
@ -855,8 +855,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
str(job.source),
|
str(job.source),
|
||||||
local_path=job.local_path.as_posix(),
|
local_path=job.local_path.as_posix(),
|
||||||
parts=parts,
|
parts=parts,
|
||||||
bytes=job.bytes,
|
bytes=sum(x["bytes"] for x in parts),
|
||||||
total_bytes=job.total_bytes,
|
total_bytes=sum(x["total_bytes"] for x in parts),
|
||||||
id=job.id,
|
id=job.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -875,7 +875,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
assert job.config_out is not None
|
assert job.config_out is not None
|
||||||
key = job.config_out.key
|
key = job.config_out.key
|
||||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
self._event_bus.emit_model_install_completed(source=str(job.source),
|
||||||
|
key=key,
|
||||||
|
id=job.id,
|
||||||
|
total_bytes=job.bytes)
|
||||||
|
|
||||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||||
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
||||||
|
@ -317,6 +317,11 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
|||||||
"model_install_completed",
|
"model_install_completed",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
completed_events = [x for x in bus.events if x.event_name == "model_install_completed"]
|
||||||
|
downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"]
|
||||||
|
assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"]
|
||||||
|
assert job.total_bytes == completed_events[0].payload["total_bytes"]
|
||||||
|
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
|
||||||
|
|
||||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||||
|
Loading…
Reference in New Issue
Block a user