mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix race condition causing hangs during model install unit tests (#5994)
* fix race condition causing hangs during model install unit tests
* remove extraneous sanity checks
---------
Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent
c87497fd54
commit
74a51571a0
@ -386,6 +386,17 @@ class EventServiceBase:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def emit_model_install_downloads_done(self, source: str) -> None:
    """
    Announce that every part of a model has finished downloading.

    Emitted once, after all parts are downloaded but before the probing and
    registration start.

    :param source: Source of the model; local path, repo_id or url
    """
    payload = {"source": source}
    self.__emit_model_event(event_name="model_install_downloads_done", payload=payload)
|
||||||
|
|
||||||
def emit_model_install_running(self, source: str) -> None:
|
def emit_model_install_running(self, source: str) -> None:
|
||||||
"""
|
"""
|
||||||
Emit once when an install job becomes active.
|
Emit once when an install job becomes active.
|
||||||
|
@ -124,15 +124,28 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if not self._running:
|
if not self._running:
|
||||||
raise Exception("Attempt to stop the install service before it was started")
|
raise Exception("Attempt to stop the install service before it was started")
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
with self._install_queue.mutex:
|
self._clear_pending_jobs()
|
||||||
self._install_queue.queue.clear() # get rid of pending jobs
|
|
||||||
active_jobs = [x for x in self.list_jobs() if x.running]
|
|
||||||
if active_jobs:
|
|
||||||
self._logger.warning("Waiting for active install job to complete")
|
|
||||||
self.wait_for_installs()
|
|
||||||
self._download_cache.clear()
|
self._download_cache.clear()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
def _clear_pending_jobs(self) -> None:
    """Cancel every unfinished install job and empty the install queue.

    Each job returned by ``list_jobs()`` that is not in a terminal state is
    logged and passed to ``cancel_job()``. The install queue is then drained
    without blocking until it reports ``Empty``.
    """
    for job in self.list_jobs():
        if not job.in_terminal_state:
            # Must be an f-string: without the prefix the literal text
            # "{job.id}" is logged instead of the actual job id.
            self._logger.warning(f"Cancelling job {job.id}")
            self.cancel_job(job)
    while True:
        try:
            # The dequeued entry is discarded on purpose; it only needs to
            # leave the queue.
            self._install_queue.get(block=False)
        except Empty:
            break
        # Pair every successful get() with task_done() so the queue's
        # unfinished-task count returns to zero.
        self._install_queue.task_done()
|
||||||
|
|
||||||
|
def _put_in_queue(self, job: ModelInstallJob) -> None:
    """Enqueue ``job`` for installation, or cancel it if the stop event is set."""
    if not self._stop_event.is_set():
        self._install_queue.put(job)
        return
    # A stop has been requested: cancel the job rather than queueing it.
    self.cancel_job(job)
|
||||||
|
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
@ -218,7 +231,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if isinstance(source, LocalModelSource):
|
if isinstance(source, LocalModelSource):
|
||||||
install_job = self._import_local_model(source, config)
|
install_job = self._import_local_model(source, config)
|
||||||
self._install_queue.put(install_job) # synchronously install
|
self._put_in_queue(install_job) # synchronously install
|
||||||
elif isinstance(source, HFModelSource):
|
elif isinstance(source, HFModelSource):
|
||||||
install_job = self._import_from_hf(source, config)
|
install_job = self._import_from_hf(source, config)
|
||||||
elif isinstance(source, URLModelSource):
|
elif isinstance(source, URLModelSource):
|
||||||
@ -253,7 +266,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
raise TimeoutError("Timeout exceeded")
|
raise TimeoutError("Timeout exceeded")
|
||||||
return job
|
return job
|
||||||
|
|
||||||
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
|
||||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||||
"""Block until all installation jobs are done."""
|
"""Block until all installation jobs are done."""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -412,7 +424,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job = self._install_queue.get(timeout=1)
|
job = self._install_queue.get(timeout=1)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
try:
|
try:
|
||||||
if job.cancelled:
|
if job.cancelled:
|
||||||
@ -462,8 +473,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._install_completed_event.set()
|
self._install_completed_event.set()
|
||||||
self._install_queue.task_done()
|
self._install_queue.task_done()
|
||||||
|
|
||||||
self._logger.info("Install thread exiting")
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
# Internal functions that manage the models directory
|
# Internal functions that manage the models directory
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
@ -779,14 +788,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.info(f"{download_job.source}: model download complete")
|
self._logger.info(f"{download_job.source}: model download complete")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache[download_job.source]
|
||||||
self._download_cache.pop(download_job.source, None)
|
|
||||||
|
|
||||||
# are there any more active jobs left in this task?
|
# are there any more active jobs left in this task?
|
||||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
self._signal_job_downloads_done(install_job)
|
||||||
self._install_queue.put(install_job)
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
|
self._download_cache.pop(download_job.source, None)
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
|
|
||||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||||
@ -826,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||||
self._install_queue.put(install_job)
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------
|
||||||
# Internal methods that put events on the event bus
|
# Internal methods that put events on the event bus
|
||||||
@ -859,6 +868,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
id=job.id,
|
id=job.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
    """Mark ``job`` as DOWNLOADS_DONE, log it, and notify the event bus if one is set."""
    job.status = InstallStatus.DOWNLOADS_DONE
    self._logger.info(f"{job.source}: all parts of this model are downloaded")
    if not self._event_bus:
        # No event bus attached; nothing further to announce.
        return
    self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||||
|
|
||||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||||
job.status = InstallStatus.COMPLETED
|
job.status = InstallStatus.COMPLETED
|
||||||
assert job.config_out
|
assert job.config_out
|
||||||
|
@ -221,9 +221,14 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
|||||||
model_record = store.get_model(key)
|
model_record = store.get_model(key)
|
||||||
assert Path(model_record.path).exists()
|
assert Path(model_record.path).exists()
|
||||||
|
|
||||||
assert len(bus.events) == 3
|
assert len(bus.events) == 4
|
||||||
event_names = [x.event_name for x in bus.events]
|
event_names = [x.event_name for x in bus.events]
|
||||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
assert event_names == [
|
||||||
|
"model_install_downloading",
|
||||||
|
"model_install_downloads_done",
|
||||||
|
"model_install_running",
|
||||||
|
"model_install_completed",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
@ -250,7 +255,12 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
|||||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||||
assert len(bus.events) >= 3
|
assert len(bus.events) >= 3
|
||||||
event_names = {x.event_name for x in bus.events}
|
event_names = {x.event_name for x in bus.events}
|
||||||
assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"}
|
assert event_names == {
|
||||||
|
"model_install_downloading",
|
||||||
|
"model_install_downloads_done",
|
||||||
|
"model_install_running",
|
||||||
|
"model_install_completed",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user