after stopping install and download services, wait for thread exit

This commit is contained in:
Lincoln Stein 2024-03-20 22:11:52 -04:00
parent e452c6171b
commit ce687a2869
4 changed files with 62 additions and 86 deletions

View File

@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._queue.queue.clear()
self.join() # wait for all active jobs to finish
self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear()
def submit_download_job(

View File

@ -93,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
self._next_job_id = 0
@property
@ -127,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._stop_event.set()
self._clear_pending_jobs()
self._download_cache.clear()
assert self._install_thread is not None
self._install_thread.join()
self._running = False
def _clear_pending_jobs(self) -> None:
@ -269,19 +272,14 @@ class ModelInstallService(ModelInstallServiceBase):
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
self.printf("wait_for_installs(): ENTERING")
start = time.time()
while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
self.printf(
f"wait_for_installs(): install_queue size={self._install_queue.qsize()}, download_cache={self._download_cache}"
)
self._install_queue.join()
self.printf("wait_for_installs(): EXITING")
return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None:
@ -418,21 +416,21 @@ class ModelInstallService(ModelInstallServiceBase):
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
self._install_thread.start()
self._running = True
def _install_next_item(self) -> None:
done = False
while not done:
self._logger.info(f"Installer thread {threading.get_ident()} starting")
while True:
if self._stop_event.is_set():
done = True
continue
break
self._logger.info(f"Installer thread {threading.get_ident()} running")
try:
job = self._install_queue.get(timeout=1)
except Empty:
continue
assert job.local_path is not None
self.printf(f"_install_next_item(source={job.source}, id={job.id}")
try:
if job.cancelled:
self._signal_job_cancelled(job)
@ -454,8 +452,8 @@ class ModelInstallService(ModelInstallServiceBase):
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self.printf("Signaling task done")
self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state
@ -800,25 +798,16 @@ class ModelInstallService(ModelInstallServiceBase):
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
with self._lock:
self.printf("_LOCK")
install_job = self._download_cache[download_job.source]
self.printf(
f"_download_complete_callback(source={download_job.source}, job={install_job.source}, install_job.id={install_job.id})"
)
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
self.printf(f"_enqueuing job {install_job.id}")
self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job)
self.printf(f"_enqueued job {install_job.id}")
# Let other threads know that the number of downloads has changed
self.printf(f"popping {download_job.source}")
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set()
self.printf("downloads_changed_event is set")
self.printf("_UNLOCK")
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
@ -927,7 +916,3 @@ class ModelInstallService(ModelInstallServiceBase):
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
@staticmethod
def printf(message: str) -> None:
print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True)

View File

@ -5,6 +5,7 @@ Test the model installer
import platform
import uuid
from pathlib import Path
from typing import Any, Dict
import pytest
from pydantic import ValidationError
@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
# @pytest.mark.parametrize(
# "model_params",
# [
# # SDXL, Lora
# {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora",
# "type": "embedding",
# },
# # SDXL, Lora - incorrect type
# {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora",
# "type": "lora",
# },
# ],
# )
# @pytest.mark.timeout(timeout=40, method="thread")
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
# """Test whether or not type is respected on configs when passed to heuristic import."""
# assert "name" in model_params and "type" in model_params
# config1: Dict[str, Any] = {
# "name": f"{model_params['name']}_1",
# "type": model_params["type"],
# "hash": "placeholder1",
# }
# config2: Dict[str, Any] = {
# "name": f"{model_params['name']}_2",
# "type": ModelType(model_params["type"]),
# "hash": "placeholder2",
# }
# assert "repo_id" in model_params
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
# mm2_installer.wait_for_job(install_job1, timeout=20)
# if model_params["type"] != "embedding":
# assert install_job1.errored
# assert install_job1.error_type == "InvalidModelConfigException"
# return
# assert install_job1.complete
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
@pytest.mark.parametrize(
"model_params",
[
# SDXL, Lora
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": "embedding",
},
# SDXL, Lora - incorrect type
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": "lora",
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params
config1: Dict[str, Any] = {
"name": f"{model_params['name']}_1",
"type": model_params["type"],
"hash": "placeholder1",
}
config2: Dict[str, Any] = {
"name": f"{model_params['name']}_2",
"type": ModelType(model_params["type"]),
"hash": "placeholder2",
}
assert "repo_id" in model_params
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
mm2_installer.wait_for_job(install_job1, timeout=20)
if model_params["type"] != "embedding":
assert install_job1.errored
assert install_job1.error_type == "InvalidModelConfigException"
return
assert install_job1.complete
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
# mm2_installer.wait_for_job(install_job2, timeout=20)
# assert install_job2.complete
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
mm2_installer.wait_for_job(install_job2, timeout=20)
assert install_job2.complete
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out

View File

@ -2,13 +2,11 @@
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
@ -99,16 +97,12 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
@pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
def stop_queue() -> None:
yield download_queue
download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
@ -130,7 +124,6 @@ def mm2_installer(
mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
@ -145,13 +138,8 @@ def mm2_installer(
session=mm2_session,
)
installer.start()
def stop_installer() -> None:
yield installer
installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
request.addfinalizer(stop_installer)
return installer
@pytest.fixture