Fix race condition causing hangs during model install unit tests (#5994)

* fix race condition causing hangs during model install unit tests

* remove extraneous sanity checks

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-03-19 16:54:49 -04:00 committed by GitHub
parent c87497fd54
commit 74a51571a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 18 deletions

View File

@ -386,6 +386,17 @@ class EventServiceBase:
},
)
def emit_model_install_downloads_done(self, source: str) -> None:
    """
    Announce that every download part of a model install has finished.

    Emitted once, after all parts are downloaded but before probing and
    registration start.

    :param source: Source of the model; local path, repo_id or url
    """
    event_payload = {"source": source}
    self.__emit_model_event(event_name="model_install_downloads_done", payload=event_payload)
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.

View File

@ -124,15 +124,28 @@ class ModelInstallService(ModelInstallServiceBase):
if not self._running:
raise Exception("Attempt to stop the install service before it was started")
self._stop_event.set()
with self._install_queue.mutex:
self._install_queue.queue.clear() # get rid of pending jobs
active_jobs = [x for x in self.list_jobs() if x.running]
if active_jobs:
self._logger.warning("Waiting for active install job to complete")
self.wait_for_installs()
self._clear_pending_jobs()
self._download_cache.clear()
self._running = False
def _clear_pending_jobs(self) -> None:
    """
    Cancel every unfinished job and drain the install queue.

    Each job returned by ``list_jobs()`` that is not yet in a terminal state
    is logged and handed to ``cancel_job()``. Afterwards the install queue is
    emptied without blocking, marking each removed item as done so the
    queue's unfinished-task count stays balanced.
    """
    for job in self.list_jobs():
        if not job.in_terminal_state:
            # f-string is required here; without the prefix the literal text
            # "{job.id}" was logged instead of the actual job id.
            self._logger.warning(f"Cancelling job {job.id}")
            self.cancel_job(job)

    # Drain whatever is still queued; the items themselves are discarded.
    while True:
        try:
            self._install_queue.get(block=False)
        except Empty:
            break
        self._install_queue.task_done()
def _put_in_queue(self, job: ModelInstallJob) -> None:
    """Queue ``job`` for installation, or cancel it when the stop event is already set."""
    if not self._stop_event.is_set():
        self._install_queue.put(job)
        return
    # The service is stopping: do not enqueue new work, cancel it instead.
    self.cancel_job(job)
def register_path(
self,
model_path: Union[Path, str],
@ -218,7 +231,7 @@ class ModelInstallService(ModelInstallServiceBase):
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._install_queue.put(install_job) # synchronously install
self._put_in_queue(install_job) # synchronously install
elif isinstance(source, HFModelSource):
install_job = self._import_from_hf(source, config)
elif isinstance(source, URLModelSource):
@ -253,7 +266,6 @@ class ModelInstallService(ModelInstallServiceBase):
raise TimeoutError("Timeout exceeded")
return job
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
start = time.time()
@ -412,7 +424,6 @@ class ModelInstallService(ModelInstallServiceBase):
job = self._install_queue.get(timeout=1)
except Empty:
continue
assert job.local_path is not None
try:
if job.cancelled:
@ -462,8 +473,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_completed_event.set()
self._install_queue.task_done()
self._logger.info("Install thread exiting")
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
# --------------------------------------------------------------------------------------------
@ -779,14 +788,14 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{download_job.source}: model download complete")
with self._lock:
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
install_job.status = InstallStatus.DOWNLOADS_DONE
self._install_queue.put(install_job)
self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
@ -826,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase):
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._install_queue.put(install_job)
self._put_in_queue(install_job)
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
@ -859,6 +868,12 @@ class ModelInstallService(ModelInstallServiceBase):
id=job.id,
)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
    """Mark ``job`` as DOWNLOADS_DONE, log it, and notify the event bus if one is set."""
    job.status = InstallStatus.DOWNLOADS_DONE
    self._logger.info(f"{job.source}: all parts of this model are downloaded")
    bus = self._event_bus
    if not bus:
        # No event bus configured; nothing further to announce.
        return
    bus.emit_model_install_downloads_done(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
assert job.config_out

View File

@ -221,9 +221,14 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
model_record = store.get_model(key)
assert Path(model_record.path).exists()
assert len(bus.events) == 3
assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events]
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
assert event_names == [
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
]
@pytest.mark.timeout(timeout=20, method="thread")
@ -250,7 +255,12 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: