add debugging statements and a timeout to download test

This commit is contained in:
Lincoln Stein 2024-02-29 06:34:59 -05:00 committed by psychedelicious
parent dd9daf8efb
commit e18533e3b5
7 changed files with 55 additions and 14 deletions

View File

@ -221,22 +221,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
except Empty: except Empty:
continue continue
try: try:
print(f'DEBUG: job [{job.id}] started', flush=True)
job.job_started = get_iso_timestamp() job.job_started = get_iso_timestamp()
self._do_download(job) self._do_download(job)
print(f'DEBUG: job [{job.id}] download completed', flush=True)
self._signal_job_complete(job) self._signal_job_complete(job)
print(f'DEBUG: job [{job.id}] signaled completion', flush=True)
except (OSError, HTTPError) as excp: except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})" job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc() job.error = traceback.format_exc()
self._signal_job_error(job, excp) self._signal_job_error(job, excp)
print(f'DEBUG: job [{job.id}] signaled error', flush=True)
except DownloadJobCancelledException: except DownloadJobCancelledException:
self._signal_job_cancelled(job) self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job) self._cleanup_cancelled_job(job)
print(f'DEBUG: job [{job.id}] signaled cancelled', flush=True)
finally: finally:
print(f'DEBUG: job [{job.id}] signalling completion', flush=True)
job.job_ended = get_iso_timestamp() job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state self._job_completed_event.set() # signal a change to terminal state
print(f'DEBUG: job [{job.id}] set job completion event', flush=True)
self._queue.task_done() self._queue.task_done()
print(f'DEBUG: job [{job.id}] done', flush=True)
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None: def _do_download(self, job: DownloadJob) -> None:
@ -244,7 +251,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
url = job.source url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb" open_mode = "wb"
print(f'DEBUG: In _do_download [0]', flush=True)
# Make a streaming request. This will retrieve headers including # Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself # content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), headers=header, stream=True) resp = self._requests.get(str(url), headers=header, stream=True)
@ -255,6 +263,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
content_length = int(resp.headers.get("content-length", 0)) content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length job.total_bytes = content_length
print(f'DEBUG: In _do_download [1]')
if job.dest.is_dir(): if job.dest.is_dir():
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
@ -282,6 +292,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
# signal caller that the download is starting. At this point, key fields such as # signal caller that the download is starting. At this point, key fields such as
# download_path and total_bytes will be populated. We call it here because the might # download_path and total_bytes will be populated. We call it here because the might
# discover that the local file is already complete and generate a COMPLETED status. # discover that the local file is already complete and generate a COMPLETED status.
print(f'DEBUG: In _do_download [2]', flush=True)
self._signal_job_started(job) self._signal_job_started(job)
# "range not satisfiable" - local file is at least as large as the remote file # "range not satisfiable" - local file is at least as large as the remote file
@ -297,12 +308,15 @@ class DownloadQueueService(DownloadQueueServiceBase):
elif resp.status_code != 200: elif resp.status_code != 200:
raise HTTPError(resp.reason) raise HTTPError(resp.reason)
print(f'DEBUG: In _do_download [3]', flush=True)
self._logger.debug(f"{job.source}: Downloading {job.download_path}") self._logger.debug(f"{job.source}: Downloading {job.download_path}")
report_delta = job.total_bytes / 100 # report every 1% change report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0 last_report_bytes = 0
# DOWNLOAD LOOP # DOWNLOAD LOOP
with open(in_progress_path, open_mode) as file: with open(in_progress_path, open_mode) as file:
print(f'DEBUG: In _do_download loop [4]', flush=True)
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
if job.cancelled: if job.cancelled:
raise DownloadJobCancelledException("Job was cancelled at caller's request") raise DownloadJobCancelledException("Job was cancelled at caller's request")
@ -315,6 +329,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
# if we get here we are done and can rename the file to the original dest # if we get here we are done and can rename the file to the original dest
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})") self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
in_progress_path.rename(job.download_path) in_progress_path.rename(job.download_path)
print(f'DEBUG: In _do_download [5]', flush=True)
def _validate_filename(self, directory: str, filename: str) -> bool: def _validate_filename(self, directory: str, filename: str) -> bool:
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows

View File

@ -284,7 +284,7 @@ class ModelInstallService(ModelInstallServiceBase):
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback) search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear() self._models_installed.clear()
search.search(scan_dir) search.search(scan_dir)
return list(self._models_installed) return list(self._models_installed)

View File

@ -118,6 +118,7 @@ dependencies = [
"pre-commit", "pre-commit",
"pytest>6.0.0", "pytest>6.0.0",
"pytest-cov", "pytest-cov",
"pytest-timeout",
"pytest-datadir", "pytest-datadir",
"requests_testadapter", "requests_testadapter",
"httpx", "httpx",
@ -189,6 +190,7 @@ version = { attr = "invokeai.version.__version__" }
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\"" addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
markers = [ markers = [
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".", "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
"timeout: 60"
] ]
[tool.coverage.run] [tool.coverage.run]
branch = true branch = true

View File

@ -1,6 +1,8 @@
import os
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import pytest
from fastapi import BackgroundTasks from fastapi import BackgroundTasks
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -9,7 +11,11 @@ from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
client = TestClient(app)
@pytest.fixture(autouse=True, scope="module")
def client(invokeai_root_dir: Path) -> TestClient:
os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix()
return TestClient(app)
class MockApiDependencies(ApiDependencies): class MockApiDependencies(ApiDependencies):
@ -19,7 +25,7 @@ class MockApiDependencies(ApiDependencies):
self.invoker = invoker self.invoker = invoker
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
prepare_download_images_test(monkeypatch, mock_invoker) prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
@ -28,7 +34,9 @@ def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> N
assert json_response["bulk_download_item_name"] == "test.zip" assert json_response["bulk_download_item_name"] == "test.zip"
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_from_board_id_empty_image_name_list(
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
) -> None:
expected_board_name = "test" expected_board_name = "test"
def mock_get(*args, **kwargs): def mock_get(*args, **kwargs):
@ -56,7 +64,9 @@ def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> Non
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task) monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_with_empty_image_list_and_no_board_id(
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
) -> None:
prepare_download_images_test(monkeypatch, mock_invoker) prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": []}) response = client.post("/api/v1/images/download", json={"image_names": []})
@ -64,7 +74,7 @@ def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any,
assert response.status_code == 400 assert response.status_code == 400
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None: def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
mock_file: Path = tmp_path / "test.zip" mock_file: Path = tmp_path / "test.zip"
mock_file.write_text("contents") mock_file.write_text("contents")
@ -82,7 +92,7 @@ def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker:
assert response.content == b"contents" assert response.content == b"contents"
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
def mock_add_task(*args, **kwargs): def mock_add_task(*args, **kwargs):
@ -96,7 +106,7 @@ def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invok
def test_get_bulk_download_image_image_deleted_after_response( def test_get_bulk_download_image_image_deleted_after_response(
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient
) -> None: ) -> None:
mock_file: Path = tmp_path / "test.zip" mock_file: Path = tmp_path / "test.zip"
mock_file.write_text("contents") mock_file.write_text("contents")

View File

@ -166,7 +166,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
# assert re.search("division by zero", captured.err) # assert re.search("division by zero", captured.err)
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None: def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService() event_bus = TestEventService()
@ -182,6 +182,9 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled nonlocal cancelled
cancelled = True cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download( job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"), source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path, dest=tmp_path,

View File

@ -195,7 +195,7 @@ def test_delete_register(
with pytest.raises(UnknownModelException): with pytest.raises(UnknownModelException):
store.get_model(key) store.get_model(key)
@pytest.mark.timeout(timeout=20, method="thread")
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
@ -220,7 +220,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
event_names = [x.event_name for x in bus.events] event_names = [x.event_name for x in bus.events]
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"] assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))

View File

@ -5,6 +5,8 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
import logging import logging
import shutil
from pathlib import Path
import pytest import pytest
@ -58,3 +60,11 @@ def mock_services() -> InvocationServices:
@pytest.fixture() @pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker: def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services) return Invoker(services=mock_services)
@pytest.fixture(scope="module")
def invokeai_root_dir(tmp_path_factory) -> Path:
root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root"
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
shutil.copytree(root_template, temp_dir)
return temp_dir