mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
after stopping install and download services, wait for thread exit
This commit is contained in:
parent
e452c6171b
commit
ce687a2869
@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
self._queue.queue.clear()
|
||||
self.join() # wait for all active jobs to finish
|
||||
self._stop_event.set()
|
||||
for thread in self._worker_pool:
|
||||
thread.join()
|
||||
self._worker_pool.clear()
|
||||
|
||||
def submit_download_job(
|
||||
|
@ -93,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._install_thread: Optional[threading.Thread] = None
|
||||
self._next_job_id = 0
|
||||
|
||||
@property
|
||||
@ -127,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._stop_event.set()
|
||||
self._clear_pending_jobs()
|
||||
self._download_cache.clear()
|
||||
assert self._install_thread is not None
|
||||
self._install_thread.join()
|
||||
self._running = False
|
||||
|
||||
def _clear_pending_jobs(self) -> None:
|
||||
@ -269,19 +272,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
self.printf("wait_for_installs(): ENTERING")
|
||||
start = time.time()
|
||||
while len(self._download_cache) > 0:
|
||||
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._downloads_changed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
self.printf(
|
||||
f"wait_for_installs(): install_queue size={self._install_queue.qsize()}, download_cache={self._download_cache}"
|
||||
)
|
||||
self._install_queue.join()
|
||||
|
||||
self.printf("wait_for_installs(): EXITING")
|
||||
return self._install_jobs
|
||||
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
@ -418,21 +416,21 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
def _start_installer_thread(self) -> None:
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
|
||||
self._install_thread.start()
|
||||
self._running = True
|
||||
|
||||
def _install_next_item(self) -> None:
|
||||
done = False
|
||||
while not done:
|
||||
self._logger.info(f"Installer thread {threading.get_ident()} starting")
|
||||
while True:
|
||||
if self._stop_event.is_set():
|
||||
done = True
|
||||
continue
|
||||
break
|
||||
self._logger.info(f"Installer thread {threading.get_ident()} running")
|
||||
try:
|
||||
job = self._install_queue.get(timeout=1)
|
||||
except Empty:
|
||||
continue
|
||||
assert job.local_path is not None
|
||||
self.printf(f"_install_next_item(source={job.source}, id={job.id}")
|
||||
try:
|
||||
if job.cancelled:
|
||||
self._signal_job_cancelled(job)
|
||||
@ -454,8 +452,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if job._install_tmpdir is not None:
|
||||
rmtree(job._install_tmpdir)
|
||||
self._install_completed_event.set()
|
||||
self.printf("Signaling task done")
|
||||
self._install_queue.task_done()
|
||||
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
|
||||
|
||||
def _register_or_install(self, job: ModelInstallJob) -> None:
|
||||
# local jobs will be in waiting state, remote jobs will be downloading state
|
||||
@ -800,25 +798,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"{download_job.source}: model download complete")
|
||||
with self._lock:
|
||||
self.printf("_LOCK")
|
||||
install_job = self._download_cache[download_job.source]
|
||||
self.printf(
|
||||
f"_download_complete_callback(source={download_job.source}, job={install_job.source}, install_job.id={install_job.id})"
|
||||
)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
self.printf(f"_enqueuing job {install_job.id}")
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job)
|
||||
self.printf(f"_enqueued job {install_job.id}")
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self.printf(f"popping {download_job.source}")
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
self.printf("downloads_changed_event is set")
|
||||
self.printf("_UNLOCK")
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
@ -927,7 +916,3 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
||||
@staticmethod
|
||||
def printf(message: str) -> None:
|
||||
print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True)
|
||||
|
@ -5,6 +5,7 @@ Test the model installer
|
||||
import platform
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
||||
|
||||
|
||||
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
||||
# @pytest.mark.parametrize(
|
||||
# "model_params",
|
||||
# [
|
||||
# # SDXL, Lora
|
||||
# {
|
||||
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||
# "name": "test_lora",
|
||||
# "type": "embedding",
|
||||
# },
|
||||
# # SDXL, Lora - incorrect type
|
||||
# {
|
||||
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||
# "name": "test_lora",
|
||||
# "type": "lora",
|
||||
# },
|
||||
# ],
|
||||
# )
|
||||
# @pytest.mark.timeout(timeout=40, method="thread")
|
||||
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
# """Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
# assert "name" in model_params and "type" in model_params
|
||||
# config1: Dict[str, Any] = {
|
||||
# "name": f"{model_params['name']}_1",
|
||||
# "type": model_params["type"],
|
||||
# "hash": "placeholder1",
|
||||
# }
|
||||
# config2: Dict[str, Any] = {
|
||||
# "name": f"{model_params['name']}_2",
|
||||
# "type": ModelType(model_params["type"]),
|
||||
# "hash": "placeholder2",
|
||||
# }
|
||||
# assert "repo_id" in model_params
|
||||
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||
# mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||
# if model_params["type"] != "embedding":
|
||||
# assert install_job1.errored
|
||||
# assert install_job1.error_type == "InvalidModelConfigException"
|
||||
# return
|
||||
# assert install_job1.complete
|
||||
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||
@pytest.mark.parametrize(
|
||||
"model_params",
|
||||
[
|
||||
# SDXL, Lora
|
||||
{
|
||||
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||
"name": "test_lora",
|
||||
"type": "embedding",
|
||||
},
|
||||
# SDXL, Lora - incorrect type
|
||||
{
|
||||
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||
"name": "test_lora",
|
||||
"type": "lora",
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(timeout=40, method="thread")
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
assert "name" in model_params and "type" in model_params
|
||||
config1: Dict[str, Any] = {
|
||||
"name": f"{model_params['name']}_1",
|
||||
"type": model_params["type"],
|
||||
"hash": "placeholder1",
|
||||
}
|
||||
config2: Dict[str, Any] = {
|
||||
"name": f"{model_params['name']}_2",
|
||||
"type": ModelType(model_params["type"]),
|
||||
"hash": "placeholder2",
|
||||
}
|
||||
assert "repo_id" in model_params
|
||||
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||
if model_params["type"] != "embedding":
|
||||
assert install_job1.errored
|
||||
assert install_job1.error_type == "InvalidModelConfigException"
|
||||
return
|
||||
assert install_job1.complete
|
||||
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||
|
||||
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||
# mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||
# assert install_job2.complete
|
||||
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||
assert install_job2.complete
|
||||
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||
|
@ -2,13 +2,11 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pytest import FixtureRequest
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
@ -99,16 +97,12 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
|
||||
def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
def stop_queue() -> None:
|
||||
yield download_queue
|
||||
download_queue.stop()
|
||||
|
||||
request.addfinalizer(stop_queue)
|
||||
return download_queue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
@ -130,7 +124,6 @@ def mm2_installer(
|
||||
mm2_app_config: InvokeAIAppConfig,
|
||||
mm2_download_queue: DownloadQueueServiceBase,
|
||||
mm2_session: Session,
|
||||
request: FixtureRequest,
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
@ -145,13 +138,8 @@ def mm2_installer(
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
|
||||
def stop_installer() -> None:
|
||||
yield installer
|
||||
installer.stop()
|
||||
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
|
||||
|
||||
request.addfinalizer(stop_installer)
|
||||
return installer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
Loading…
Reference in New Issue
Block a user