after stopping install and download services, wait for thread exit

This commit is contained in:
Lincoln Stein 2024-03-20 22:11:52 -04:00
parent e452c6171b
commit ce687a2869
4 changed files with 62 additions and 86 deletions

View File

@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._queue.queue.clear() self._queue.queue.clear()
self.join() # wait for all active jobs to finish self.join() # wait for all active jobs to finish
self._stop_event.set() self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear() self._worker_pool.clear()
def submit_download_job( def submit_download_job(

View File

@ -93,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False self._running = False
self._session = session self._session = session
self._install_thread: Optional[threading.Thread] = None
self._next_job_id = 0 self._next_job_id = 0
@property @property
@ -127,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._stop_event.set() self._stop_event.set()
self._clear_pending_jobs() self._clear_pending_jobs()
self._download_cache.clear() self._download_cache.clear()
assert self._install_thread is not None
self._install_thread.join()
self._running = False self._running = False
def _clear_pending_jobs(self) -> None: def _clear_pending_jobs(self) -> None:
@ -269,19 +272,14 @@ class ModelInstallService(ModelInstallServiceBase):
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done.""" """Block until all installation jobs are done."""
self.printf("wait_for_installs(): ENTERING")
start = time.time() start = time.time()
while len(self._download_cache) > 0: while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
self._downloads_changed_event.clear() self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout: if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded") raise TimeoutError("Timeout exceeded")
self.printf(
f"wait_for_installs(): install_queue size={self._install_queue.qsize()}, download_cache={self._download_cache}"
)
self._install_queue.join() self._install_queue.join()
self.printf("wait_for_installs(): EXITING")
return self._install_jobs return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None: def cancel_job(self, job: ModelInstallJob) -> None:
@ -418,21 +416,21 @@ class ModelInstallService(ModelInstallServiceBase):
# Internal functions that manage the installer threads # Internal functions that manage the installer threads
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start() self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
self._install_thread.start()
self._running = True self._running = True
def _install_next_item(self) -> None: def _install_next_item(self) -> None:
done = False self._logger.info(f"Installer thread {threading.get_ident()} starting")
while not done: while True:
if self._stop_event.is_set(): if self._stop_event.is_set():
done = True break
continue self._logger.info(f"Installer thread {threading.get_ident()} running")
try: try:
job = self._install_queue.get(timeout=1) job = self._install_queue.get(timeout=1)
except Empty: except Empty:
continue continue
assert job.local_path is not None assert job.local_path is not None
self.printf(f"_install_next_item(source={job.source}, id={job.id}")
try: try:
if job.cancelled: if job.cancelled:
self._signal_job_cancelled(job) self._signal_job_cancelled(job)
@ -454,8 +452,8 @@ class ModelInstallService(ModelInstallServiceBase):
if job._install_tmpdir is not None: if job._install_tmpdir is not None:
rmtree(job._install_tmpdir) rmtree(job._install_tmpdir)
self._install_completed_event.set() self._install_completed_event.set()
self.printf("Signaling task done")
self._install_queue.task_done() self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None: def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state # local jobs will be in waiting state, remote jobs will be downloading state
@ -800,25 +798,16 @@ class ModelInstallService(ModelInstallServiceBase):
def _download_complete_callback(self, download_job: DownloadJob) -> None: def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete") self._logger.info(f"{download_job.source}: model download complete")
with self._lock: with self._lock:
self.printf("_LOCK")
install_job = self._download_cache[download_job.source] install_job = self._download_cache[download_job.source]
self.printf(
f"_download_complete_callback(source={download_job.source}, job={install_job.source}, install_job.id={install_job.id})"
)
# are there any more active jobs left in this task? # are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts): if install_job.downloading and all(x.complete for x in install_job.download_parts):
self.printf(f"_enqueuing job {install_job.id}")
self._signal_job_downloads_done(install_job) self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job) self._put_in_queue(install_job)
self.printf(f"_enqueued job {install_job.id}")
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self.printf(f"popping {download_job.source}")
self._download_cache.pop(download_job.source, None) self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set() self._downloads_changed_event.set()
self.printf("downloads_changed_event is set")
self.printf("_UNLOCK")
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock: with self._lock:
@ -927,7 +916,3 @@ class ModelInstallService(ModelInstallServiceBase):
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'") raise ValueError(f"Unsupported model source: '{url}'")
@staticmethod
def printf(message: str) -> None:
print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True)

View File

@ -5,6 +5,7 @@ Test the model installer
import platform import platform
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any, Dict
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test # TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
# @pytest.mark.parametrize( @pytest.mark.parametrize(
# "model_params", "model_params",
# [ [
# # SDXL, Lora # SDXL, Lora
# { {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora", "name": "test_lora",
# "type": "embedding", "type": "embedding",
# }, },
# # SDXL, Lora - incorrect type # SDXL, Lora - incorrect type
# { {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora", "name": "test_lora",
# "type": "lora", "type": "lora",
# }, },
# ], ],
# ) )
# @pytest.mark.timeout(timeout=40, method="thread") @pytest.mark.timeout(timeout=40, method="thread")
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
# """Test whether or not type is respected on configs when passed to heuristic import.""" """Test whether or not type is respected on configs when passed to heuristic import."""
# assert "name" in model_params and "type" in model_params assert "name" in model_params and "type" in model_params
# config1: Dict[str, Any] = { config1: Dict[str, Any] = {
# "name": f"{model_params['name']}_1", "name": f"{model_params['name']}_1",
# "type": model_params["type"], "type": model_params["type"],
# "hash": "placeholder1", "hash": "placeholder1",
# } }
# config2: Dict[str, Any] = { config2: Dict[str, Any] = {
# "name": f"{model_params['name']}_2", "name": f"{model_params['name']}_2",
# "type": ModelType(model_params["type"]), "type": ModelType(model_params["type"]),
# "hash": "placeholder2", "hash": "placeholder2",
# } }
# assert "repo_id" in model_params assert "repo_id" in model_params
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
# mm2_installer.wait_for_job(install_job1, timeout=20) mm2_installer.wait_for_job(install_job1, timeout=20)
# if model_params["type"] != "embedding": if model_params["type"] != "embedding":
# assert install_job1.errored assert install_job1.errored
# assert install_job1.error_type == "InvalidModelConfigException" assert install_job1.error_type == "InvalidModelConfigException"
# return return
# assert install_job1.complete assert install_job1.complete
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
# mm2_installer.wait_for_job(install_job2, timeout=20) mm2_installer.wait_for_job(install_job2, timeout=20)
# assert install_job2.complete assert install_job2.complete
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out

View File

@ -2,13 +2,11 @@
import os import os
import shutil import shutil
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
@ -99,15 +97,11 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session) download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start() download_queue.start()
yield download_queue
def stop_queue() -> None: download_queue.stop()
download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
@pytest.fixture @pytest.fixture
@ -130,7 +124,6 @@ def mm2_installer(
mm2_app_config: InvokeAIAppConfig, mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase, mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session, mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase: ) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
@ -145,13 +138,8 @@ def mm2_installer(
session=mm2_session, session=mm2_session,
) )
installer.start() installer.start()
yield installer
def stop_installer() -> None: installer.stop()
installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
request.addfinalizer(stop_installer)
return installer
@pytest.fixture @pytest.fixture