From ce687a28692ccdcb8ebdcbc427ab0e3f2338d810 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 20 Mar 2024 22:11:52 -0400 Subject: [PATCH] after stopping install and download services, wait for thread exit --- .../app/services/download/download_default.py | 2 + .../model_install/model_install_default.py | 35 +++----- .../model_install/test_model_install.py | 89 ++++++++++--------- .../model_manager/model_manager_fixtures.py | 22 ++--- 4 files changed, 62 insertions(+), 86 deletions(-) diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 843351a259..bc32422a5b 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase): self._queue.queue.clear() self.join() # wait for all active jobs to finish self._stop_event.set() + for thread in self._worker_pool: + thread.join() self._worker_pool.clear() def submit_download_job( diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index e4926ce3bd..c344e8f18e 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -93,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase): self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._running = False self._session = session + self._install_thread: Optional[threading.Thread] = None self._next_job_id = 0 @property @@ -127,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase): self._stop_event.set() self._clear_pending_jobs() self._download_cache.clear() + assert self._install_thread is not None + self._install_thread.join() self._running = False def _clear_pending_jobs(self) -> None: @@ -269,19 +272,14 @@ class ModelInstallService(ModelInstallServiceBase): def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" - self.printf("wait_for_installs(): ENTERING") start = time.time() while len(self._download_cache) > 0: if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") - self.printf( - f"wait_for_installs(): install_queue size={self._install_queue.qsize()}, download_cache={self._download_cache}" - ) self._install_queue.join() - self.printf("wait_for_installs(): EXITING") return self._install_jobs def cancel_job(self, job: ModelInstallJob) -> None: @@ -418,21 +416,21 @@ class ModelInstallService(ModelInstallServiceBase): # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- def _start_installer_thread(self) -> None: - threading.Thread(target=self._install_next_item, daemon=True).start() + self._install_thread = threading.Thread(target=self._install_next_item, daemon=True) + self._install_thread.start() self._running = True def _install_next_item(self) -> None: - done = False - while not done: + self._logger.info(f"Installer thread {threading.get_ident()} starting") + while True: if self._stop_event.is_set(): - done = True - continue + break + self._logger.info(f"Installer thread {threading.get_ident()} running") try: job = self._install_queue.get(timeout=1) except Empty: continue assert job.local_path is not None - self.printf(f"_install_next_item(source={job.source}, id={job.id}") try: if job.cancelled: self._signal_job_cancelled(job) @@ -454,8 +452,8 @@ class ModelInstallService(ModelInstallServiceBase): if job._install_tmpdir is not None: rmtree(job._install_tmpdir) self._install_completed_event.set() - self.printf("Signaling task done") self._install_queue.task_done() + self._logger.info(f"Installer thread {threading.get_ident()} exiting") def _register_or_install(self, job: ModelInstallJob) -> None: # local jobs will be in waiting state, remote jobs will be downloading state @@ -800,25 +798,16 @@ class ModelInstallService(ModelInstallServiceBase): def _download_complete_callback(self, download_job: DownloadJob) -> None: self._logger.info(f"{download_job.source}: model download complete") with self._lock: - self.printf("_LOCK") install_job = self._download_cache[download_job.source] - self.printf( - f"_download_complete_callback(source={download_job.source}, job={install_job.source}, install_job.id={install_job.id})" - ) # are there any more active jobs left in this task? if install_job.downloading and all(x.complete for x in install_job.download_parts): - self.printf(f"_enqueuing job {install_job.id}") self._signal_job_downloads_done(install_job) self._put_in_queue(install_job) - self.printf(f"_enqueued job {install_job.id}") # Let other threads know that the number of downloads has changed - self.printf(f"popping {download_job.source}") self._download_cache.pop(download_job.source, None) self._downloads_changed_event.set() - self.printf("downloads_changed_event is set") - self.printf("_UNLOCK") def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: @@ -927,7 +916,3 @@ class ModelInstallService(ModelInstallServiceBase): if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): return HuggingFaceMetadataFetch raise ValueError(f"Unsupported model source: '{url}'") - - @staticmethod - def printf(message: str) -> None: - print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 79895ed380..ad9b2bb7a8 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -5,6 +5,7 @@ Test the model installer import platform import uuid from pathlib import Path +from typing import Any, Dict import pytest from pydantic import ValidationError @@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In # TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test -# @pytest.mark.parametrize( -# "model_params", -# [ -# # SDXL, Lora -# { -# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", -# "name": "test_lora", -# "type": "embedding", -# }, -# # SDXL, Lora - incorrect type -# { -# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", -# "name": "test_lora", -# "type": "lora", -# }, -# ], -# ) -# @pytest.mark.timeout(timeout=40, method="thread") -# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): -# """Test whether or not type is respected on configs when passed to heuristic import.""" -# assert "name" in model_params and "type" in model_params -# config1: Dict[str, Any] = { -# "name": f"{model_params['name']}_1", -# "type": model_params["type"], -# "hash": "placeholder1", -# } -# config2: Dict[str, Any] = { -# "name": f"{model_params['name']}_2", -# "type": ModelType(model_params["type"]), -# "hash": "placeholder2", -# } -# assert "repo_id" in model_params -# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) -# mm2_installer.wait_for_job(install_job1, timeout=20) -# if model_params["type"] != "embedding": -# assert install_job1.errored -# assert install_job1.error_type == "InvalidModelConfigException" -# return -# assert install_job1.complete -# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out +@pytest.mark.parametrize( + "model_params", + [ + # SDXL, Lora + { + "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", + "name": "test_lora", + "type": "embedding", + }, + # SDXL, Lora - incorrect type + { + "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", + "name": "test_lora", + "type": "lora", + }, + ], +) +@pytest.mark.timeout(timeout=40, method="thread") +def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): + """Test whether or not type is respected on configs when passed to heuristic import.""" + assert "name" in model_params and "type" in model_params + config1: Dict[str, Any] = { + "name": f"{model_params['name']}_1", + "type": model_params["type"], + "hash": "placeholder1", + } + config2: Dict[str, Any] = { + "name": f"{model_params['name']}_2", + "type": ModelType(model_params["type"]), + "hash": "placeholder2", + } + assert "repo_id" in model_params + install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) + mm2_installer.wait_for_job(install_job1, timeout=20) + if model_params["type"] != "embedding": + assert install_job1.errored + assert install_job1.error_type == "InvalidModelConfigException" + return + assert install_job1.complete + assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out -# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) -# mm2_installer.wait_for_job(install_job2, timeout=20) -# assert install_job2.complete -# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out + install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) + mm2_installer.wait_for_job(install_job2, timeout=20) + assert install_job2.complete + assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 8d4ccf196c..9d13838a04 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -2,13 +2,11 @@ import os import shutil -import time from pathlib import Path from typing import Any, Dict, List import pytest from pydantic import BaseModel -from pytest import FixtureRequest from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession @@ -99,15 +97,11 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: @pytest.fixture -def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: +def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase: download_queue = DownloadQueueService(requests_session=mm2_session) download_queue.start() - - def stop_queue() -> None: - download_queue.stop() - - request.addfinalizer(stop_queue) - return download_queue + yield download_queue + download_queue.stop() @pytest.fixture @@ -130,7 +124,6 @@ def mm2_installer( mm2_app_config: InvokeAIAppConfig, mm2_download_queue: DownloadQueueServiceBase, mm2_session: Session, - request: FixtureRequest, ) -> ModelInstallServiceBase: logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) @@ -145,13 +138,8 @@ def mm2_installer( session=mm2_session, ) installer.start() - - def stop_installer() -> None: - installer.stop() - time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message - - request.addfinalizer(stop_installer) - return installer + yield installer + installer.stop() @pytest.fixture