mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix race condition causing hangs during model install unit tests (#5994)
* fix race condition causing hangs during model install unit tests * remove extraneous sanity checks --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent
c87497fd54
commit
74a51571a0
@ -386,6 +386,17 @@ class EventServiceBase:
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_downloads_done(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when all parts are downloaded, but before the probing and registration start.
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloads_done",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
|
@ -124,15 +124,28 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if not self._running:
|
||||
raise Exception("Attempt to stop the install service before it was started")
|
||||
self._stop_event.set()
|
||||
with self._install_queue.mutex:
|
||||
self._install_queue.queue.clear() # get rid of pending jobs
|
||||
active_jobs = [x for x in self.list_jobs() if x.running]
|
||||
if active_jobs:
|
||||
self._logger.warning("Waiting for active install job to complete")
|
||||
self.wait_for_installs()
|
||||
self._clear_pending_jobs()
|
||||
self._download_cache.clear()
|
||||
self._running = False
|
||||
|
||||
def _clear_pending_jobs(self) -> None:
|
||||
for job in self.list_jobs():
|
||||
if not job.in_terminal_state:
|
||||
self._logger.warning("Cancelling job {job.id}")
|
||||
self.cancel_job(job)
|
||||
while True:
|
||||
try:
|
||||
job = self._install_queue.get(block=False)
|
||||
self._install_queue.task_done()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def _put_in_queue(self, job: ModelInstallJob) -> None:
|
||||
if self._stop_event.is_set():
|
||||
self.cancel_job(job)
|
||||
else:
|
||||
self._install_queue.put(job)
|
||||
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
@ -218,7 +231,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
if isinstance(source, LocalModelSource):
|
||||
install_job = self._import_local_model(source, config)
|
||||
self._install_queue.put(install_job) # synchronously install
|
||||
self._put_in_queue(install_job) # synchronously install
|
||||
elif isinstance(source, HFModelSource):
|
||||
install_job = self._import_from_hf(source, config)
|
||||
elif isinstance(source, URLModelSource):
|
||||
@ -253,7 +266,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
@ -412,7 +424,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job = self._install_queue.get(timeout=1)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
assert job.local_path is not None
|
||||
try:
|
||||
if job.cancelled:
|
||||
@ -462,8 +473,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_completed_event.set()
|
||||
self._install_queue.task_done()
|
||||
|
||||
self._logger.info("Install thread exiting")
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@ -779,14 +788,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"{download_job.source}: model download complete")
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._install_queue.put(install_job)
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
@ -826,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._install_queue.put(install_job)
|
||||
self._put_in_queue(install_job)
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
@ -859,6 +868,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
id=job.id,
|
||||
)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._logger.info(f"{job.source}: all parts of this model are downloaded")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
assert job.config_out
|
||||
|
@ -221,9 +221,14 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 3
|
||||
assert len(bus.events) == 4
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
||||
assert event_names == [
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@ -250,7 +255,12 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user