From e18533e3b51c39366b5a2f27843b9231d391cdab Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 29 Feb 2024 06:34:59 -0500 Subject: [PATCH] add debugging statements and a timeout to download test --- .../app/services/download/download_default.py | 22 ++++++++++++++--- .../model_install/model_install_default.py | 2 +- pyproject.toml | 2 ++ tests/app/routers/test_images.py | 24 +++++++++++++------ .../services/download/test_download_queue.py | 5 +++- .../model_install/test_model_install.py | 4 ++-- tests/conftest.py | 10 ++++++++ 7 files changed, 55 insertions(+), 14 deletions(-) diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 50cac80d09..c336ff4c8c 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -221,22 +221,29 @@ class DownloadQueueService(DownloadQueueServiceBase): except Empty: continue try: + print(f'DEBUG: job [{job.id}] started', flush=True) job.job_started = get_iso_timestamp() self._do_download(job) + print(f'DEBUG: job [{job.id}] download completed', flush=True) self._signal_job_complete(job) - + print(f'DEBUG: job [{job.id}] signaled completion', flush=True) except (OSError, HTTPError) as excp: job.error_type = excp.__class__.__name__ + f"({str(excp)})" job.error = traceback.format_exc() self._signal_job_error(job, excp) + print(f'DEBUG: job [{job.id}] signaled error', flush=True) except DownloadJobCancelledException: self._signal_job_cancelled(job) self._cleanup_cancelled_job(job) - + print(f'DEBUG: job [{job.id}] signaled cancelled', flush=True) + finally: + print(f'DEBUG: job [{job.id}] signalling completion', flush=True) job.job_ended = get_iso_timestamp() self._job_completed_event.set() # signal a change to terminal state + print(f'DEBUG: job [{job.id}] set job completion event', flush=True) self._queue.task_done() + print(f'DEBUG: job [{job.id}] done', flush=True) self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") def _do_download(self, job: DownloadJob) -> None: @@ -244,7 +251,8 @@ class DownloadQueueService(DownloadQueueServiceBase): url = job.source header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} open_mode = "wb" - + print(f'DEBUG: In _do_download [0]', flush=True) + # Make a streaming request. This will retrieve headers including # content-length and content-disposition, but not fetch any content itself resp = self._requests.get(str(url), headers=header, stream=True) @@ -255,6 +263,8 @@ class DownloadQueueService(DownloadQueueServiceBase): content_length = int(resp.headers.get("content-length", 0)) job.total_bytes = content_length + print(f'DEBUG: In _do_download [1]') + if job.dest.is_dir(): file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL @@ -282,6 +292,7 @@ class DownloadQueueService(DownloadQueueServiceBase): # signal caller that the download is starting. At this point, key fields such as # download_path and total_bytes will be populated. We call it here because the might # discover that the local file is already complete and generate a COMPLETED status. + print(f'DEBUG: In _do_download [2]', flush=True) self._signal_job_started(job) # "range not satisfiable" - local file is at least as large as the remote file @@ -297,12 +308,15 @@ class DownloadQueueService(DownloadQueueServiceBase): elif resp.status_code != 200: raise HTTPError(resp.reason) + print(f'DEBUG: In _do_download [3]', flush=True) + self._logger.debug(f"{job.source}: Downloading {job.download_path}") report_delta = job.total_bytes / 100 # report every 1% change last_report_bytes = 0 # DOWNLOAD LOOP with open(in_progress_path, open_mode) as file: + print(f'DEBUG: In _do_download loop [4]', flush=True) for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): if job.cancelled: raise DownloadJobCancelledException("Job was cancelled at caller's request") @@ -315,6 +329,8 @@ class DownloadQueueService(DownloadQueueServiceBase): # if we get here we are done and can rename the file to the original dest self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})") in_progress_path.rename(job.download_path) + print(f'DEBUG: In _do_download [5]', flush=True) + def _validate_filename(self, directory: str, filename: str) -> bool: pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index a48cf92b99..fe8124923f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -284,7 +284,7 @@ class ModelInstallService(ModelInstallServiceBase): def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()} callback = self._scan_install if install else self._scan_register - search = ModelSearch(on_model_found=callback) + search = ModelSearch(on_model_found=callback, config=self._app_config) self._models_installed.clear() search.search(scan_dir) return list(self._models_installed) diff --git a/pyproject.toml b/pyproject.toml index f460806354..661adfe4c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ dependencies = [ "pre-commit", "pytest>6.0.0", "pytest-cov", + "pytest-timeout", "pytest-datadir", "requests_testadapter", "httpx", @@ -189,6 +190,7 @@ version = { attr = "invokeai.version.__version__" } addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\"" markers = [ "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".", + "timeout: 60" ] [tool.coverage.run] branch = true diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index 5cb8cf1c37..c0da3ec51c 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -1,6 +1,8 @@ +import os from pathlib import Path from typing import Any +import pytest from fastapi import BackgroundTasks from fastapi.testclient import TestClient @@ -9,7 +11,11 @@ from invokeai.app.api_app import app from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.invoker import Invoker -client = TestClient(app) + +@pytest.fixture(autouse=True, scope="module") +def client(invokeai_root_dir: Path) -> TestClient: + os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix() + return TestClient(app) class MockApiDependencies(ApiDependencies): @@ -19,7 +25,7 @@ class MockApiDependencies(ApiDependencies): self.invoker = invoker -def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None: prepare_download_images_test(monkeypatch, mock_invoker) response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) @@ -28,7 +34,9 @@ def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> N assert json_response["bulk_download_item_name"] == "test.zip" -def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_download_images_from_board_id_empty_image_name_list( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: expected_board_name = "test" def mock_get(*args, **kwargs): @@ -56,7 +64,9 @@ def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> Non monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task) -def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_download_images_with_empty_image_list_and_no_board_id( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: prepare_download_images_test(monkeypatch, mock_invoker) response = client.post("/api/v1/images/download", json={"image_names": []}) @@ -64,7 +74,7 @@ def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, assert response.status_code == 400 -def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None: mock_file: Path = tmp_path / "test.zip" mock_file.write_text("contents") @@ -82,7 +92,7 @@ def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: assert response.content == b"contents" -def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None: monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) def mock_add_task(*args, **kwargs): @@ -96,7 +106,7 @@ def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invok def test_get_bulk_download_image_image_deleted_after_response( - monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path + monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient ) -> None: mock_file: Path = tmp_path / "test.zip" mock_file.write_text("contents") diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 9c1826170e..77d77836ec 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -166,7 +166,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: # assert re.search("division by zero", captured.err) queue.stop() - +@pytest.mark.timeout(timeout=15, method="thread") def test_cancel(tmp_path: Path, session: Session) -> None: event_bus = TestEventService() @@ -182,6 +182,9 @@ def test_cancel(tmp_path: Path, session: Session) -> None: nonlocal cancelled cancelled = True + def handler(signum, frame): + raise TimeoutError("Join took too long to return") + job = queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), dest=tmp_path, diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 80b106c5cb..c0a6e7b2b6 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -195,7 +195,7 @@ def test_delete_register( with pytest.raises(UnknownModelException): store.get_model(key) - +@pytest.mark.timeout(timeout=20, method="thread") def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) @@ -220,7 +220,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: event_names = [x.event_name for x in bus.events] assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"] - +@pytest.mark.timeout(timeout=20, method="thread") def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) diff --git a/tests/conftest.py b/tests/conftest.py index a483b7529a..06d29b05be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. import logging +import shutil +from pathlib import Path import pytest @@ -58,3 +60,11 @@ def mock_services() -> InvocationServices: @pytest.fixture() def mock_invoker(mock_services: InvocationServices) -> Invoker: return Invoker(services=mock_services) + + +@pytest.fixture(scope="module") +def invokeai_root_dir(tmp_path_factory) -> Path: + root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root" + temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root" + shutil.copytree(root_template, temp_dir) + return temp_dir