mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add debugging statements and a timeout to download test
This commit is contained in:
parent
dd9daf8efb
commit
e18533e3b5
@ -221,22 +221,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
print(f'DEBUG: job [{job.id}] started', flush=True)
|
||||||
job.job_started = get_iso_timestamp()
|
job.job_started = get_iso_timestamp()
|
||||||
self._do_download(job)
|
self._do_download(job)
|
||||||
|
print(f'DEBUG: job [{job.id}] download completed', flush=True)
|
||||||
self._signal_job_complete(job)
|
self._signal_job_complete(job)
|
||||||
|
print(f'DEBUG: job [{job.id}] signaled completion', flush=True)
|
||||||
except (OSError, HTTPError) as excp:
|
except (OSError, HTTPError) as excp:
|
||||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||||
job.error = traceback.format_exc()
|
job.error = traceback.format_exc()
|
||||||
self._signal_job_error(job, excp)
|
self._signal_job_error(job, excp)
|
||||||
|
print(f'DEBUG: job [{job.id}] signaled error', flush=True)
|
||||||
except DownloadJobCancelledException:
|
except DownloadJobCancelledException:
|
||||||
self._signal_job_cancelled(job)
|
self._signal_job_cancelled(job)
|
||||||
self._cleanup_cancelled_job(job)
|
self._cleanup_cancelled_job(job)
|
||||||
|
print(f'DEBUG: job [{job.id}] signaled cancelled', flush=True)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
print(f'DEBUG: job [{job.id}] signalling completion', flush=True)
|
||||||
job.job_ended = get_iso_timestamp()
|
job.job_ended = get_iso_timestamp()
|
||||||
self._job_completed_event.set() # signal a change to terminal state
|
self._job_completed_event.set() # signal a change to terminal state
|
||||||
|
print(f'DEBUG: job [{job.id}] set job completion event', flush=True)
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
|
print(f'DEBUG: job [{job.id}] done', flush=True)
|
||||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||||
|
|
||||||
def _do_download(self, job: DownloadJob) -> None:
|
def _do_download(self, job: DownloadJob) -> None:
|
||||||
@ -244,7 +251,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
url = job.source
|
url = job.source
|
||||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||||
open_mode = "wb"
|
open_mode = "wb"
|
||||||
|
print(f'DEBUG: In _do_download [0]', flush=True)
|
||||||
|
|
||||||
# Make a streaming request. This will retrieve headers including
|
# Make a streaming request. This will retrieve headers including
|
||||||
# content-length and content-disposition, but not fetch any content itself
|
# content-length and content-disposition, but not fetch any content itself
|
||||||
resp = self._requests.get(str(url), headers=header, stream=True)
|
resp = self._requests.get(str(url), headers=header, stream=True)
|
||||||
@ -255,6 +263,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
content_length = int(resp.headers.get("content-length", 0))
|
content_length = int(resp.headers.get("content-length", 0))
|
||||||
job.total_bytes = content_length
|
job.total_bytes = content_length
|
||||||
|
|
||||||
|
print(f'DEBUG: In _do_download [1]')
|
||||||
|
|
||||||
if job.dest.is_dir():
|
if job.dest.is_dir():
|
||||||
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
|
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
|
||||||
|
|
||||||
@ -282,6 +292,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
# signal caller that the download is starting. At this point, key fields such as
|
# signal caller that the download is starting. At this point, key fields such as
|
||||||
# download_path and total_bytes will be populated. We call it here because the might
|
# download_path and total_bytes will be populated. We call it here because the might
|
||||||
# discover that the local file is already complete and generate a COMPLETED status.
|
# discover that the local file is already complete and generate a COMPLETED status.
|
||||||
|
print(f'DEBUG: In _do_download [2]', flush=True)
|
||||||
self._signal_job_started(job)
|
self._signal_job_started(job)
|
||||||
|
|
||||||
# "range not satisfiable" - local file is at least as large as the remote file
|
# "range not satisfiable" - local file is at least as large as the remote file
|
||||||
@ -297,12 +308,15 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
raise HTTPError(resp.reason)
|
raise HTTPError(resp.reason)
|
||||||
|
|
||||||
|
print(f'DEBUG: In _do_download [3]', flush=True)
|
||||||
|
|
||||||
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
|
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
|
||||||
report_delta = job.total_bytes / 100 # report every 1% change
|
report_delta = job.total_bytes / 100 # report every 1% change
|
||||||
last_report_bytes = 0
|
last_report_bytes = 0
|
||||||
|
|
||||||
# DOWNLOAD LOOP
|
# DOWNLOAD LOOP
|
||||||
with open(in_progress_path, open_mode) as file:
|
with open(in_progress_path, open_mode) as file:
|
||||||
|
print(f'DEBUG: In _do_download loop [4]', flush=True)
|
||||||
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
||||||
if job.cancelled:
|
if job.cancelled:
|
||||||
raise DownloadJobCancelledException("Job was cancelled at caller's request")
|
raise DownloadJobCancelledException("Job was cancelled at caller's request")
|
||||||
@ -315,6 +329,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
# if we get here we are done and can rename the file to the original dest
|
# if we get here we are done and can rename the file to the original dest
|
||||||
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
|
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
|
||||||
in_progress_path.rename(job.download_path)
|
in_progress_path.rename(job.download_path)
|
||||||
|
print(f'DEBUG: In _do_download [5]', flush=True)
|
||||||
|
|
||||||
|
|
||||||
def _validate_filename(self, directory: str, filename: str) -> bool:
|
def _validate_filename(self, directory: str, filename: str) -> bool:
|
||||||
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
|
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
|
||||||
|
@ -284,7 +284,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||||
callback = self._scan_install if install else self._scan_register
|
callback = self._scan_install if install else self._scan_register
|
||||||
search = ModelSearch(on_model_found=callback)
|
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||||
self._models_installed.clear()
|
self._models_installed.clear()
|
||||||
search.search(scan_dir)
|
search.search(scan_dir)
|
||||||
return list(self._models_installed)
|
return list(self._models_installed)
|
||||||
|
@ -118,6 +118,7 @@ dependencies = [
|
|||||||
"pre-commit",
|
"pre-commit",
|
||||||
"pytest>6.0.0",
|
"pytest>6.0.0",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
|
"pytest-timeout",
|
||||||
"pytest-datadir",
|
"pytest-datadir",
|
||||||
"requests_testadapter",
|
"requests_testadapter",
|
||||||
"httpx",
|
"httpx",
|
||||||
@ -189,6 +190,7 @@ version = { attr = "invokeai.version.__version__" }
|
|||||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
||||||
markers = [
|
markers = [
|
||||||
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
|
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
|
||||||
|
"timeout: 60"
|
||||||
]
|
]
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
branch = true
|
branch = true
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import BackgroundTasks
|
from fastapi import BackgroundTasks
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
@ -9,7 +11,11 @@ from invokeai.app.api_app import app
|
|||||||
from invokeai.app.services.board_records.board_records_common import BoardRecord
|
from invokeai.app.services.board_records.board_records_common import BoardRecord
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
client = TestClient(app)
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
def client(invokeai_root_dir: Path) -> TestClient:
|
||||||
|
os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix()
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
class MockApiDependencies(ApiDependencies):
|
class MockApiDependencies(ApiDependencies):
|
||||||
@ -19,7 +25,7 @@ class MockApiDependencies(ApiDependencies):
|
|||||||
self.invoker = invoker
|
self.invoker = invoker
|
||||||
|
|
||||||
|
|
||||||
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
||||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||||
|
|
||||||
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
||||||
@ -28,7 +34,9 @@ def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> N
|
|||||||
assert json_response["bulk_download_item_name"] == "test.zip"
|
assert json_response["bulk_download_item_name"] == "test.zip"
|
||||||
|
|
||||||
|
|
||||||
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_download_images_from_board_id_empty_image_name_list(
|
||||||
|
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
|
||||||
|
) -> None:
|
||||||
expected_board_name = "test"
|
expected_board_name = "test"
|
||||||
|
|
||||||
def mock_get(*args, **kwargs):
|
def mock_get(*args, **kwargs):
|
||||||
@ -56,7 +64,9 @@ def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> Non
|
|||||||
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
||||||
|
|
||||||
|
|
||||||
def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_download_images_with_empty_image_list_and_no_board_id(
|
||||||
|
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
|
||||||
|
) -> None:
|
||||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||||
|
|
||||||
response = client.post("/api/v1/images/download", json={"image_names": []})
|
response = client.post("/api/v1/images/download", json={"image_names": []})
|
||||||
@ -64,7 +74,7 @@ def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any,
|
|||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
||||||
mock_file: Path = tmp_path / "test.zip"
|
mock_file: Path = tmp_path / "test.zip"
|
||||||
mock_file.write_text("contents")
|
mock_file.write_text("contents")
|
||||||
|
|
||||||
@ -82,7 +92,7 @@ def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker:
|
|||||||
assert response.content == b"contents"
|
assert response.content == b"contents"
|
||||||
|
|
||||||
|
|
||||||
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
||||||
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
||||||
|
|
||||||
def mock_add_task(*args, **kwargs):
|
def mock_add_task(*args, **kwargs):
|
||||||
@ -96,7 +106,7 @@ def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invok
|
|||||||
|
|
||||||
|
|
||||||
def test_get_bulk_download_image_image_deleted_after_response(
|
def test_get_bulk_download_image_image_deleted_after_response(
|
||||||
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path
|
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient
|
||||||
) -> None:
|
) -> None:
|
||||||
mock_file: Path = tmp_path / "test.zip"
|
mock_file: Path = tmp_path / "test.zip"
|
||||||
mock_file.write_text("contents")
|
mock_file.write_text("contents")
|
||||||
|
@ -166,7 +166,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
|||||||
# assert re.search("division by zero", captured.err)
|
# assert re.search("division by zero", captured.err)
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=15, method="thread")
|
||||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||||
event_bus = TestEventService()
|
event_bus = TestEventService()
|
||||||
|
|
||||||
@ -182,6 +182,9 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
|||||||
nonlocal cancelled
|
nonlocal cancelled
|
||||||
cancelled = True
|
cancelled = True
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Join took too long to return")
|
||||||
|
|
||||||
job = queue.download(
|
job = queue.download(
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
dest=tmp_path,
|
dest=tmp_path,
|
||||||
|
@ -195,7 +195,7 @@ def test_delete_register(
|
|||||||
with pytest.raises(UnknownModelException):
|
with pytest.raises(UnknownModelException):
|
||||||
store.get_model(key)
|
store.get_model(key)
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||||
|
|
||||||
@ -220,7 +220,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
|||||||
event_names = [x.event_name for x in bus.events]
|
event_names = [x.event_name for x in bus.events]
|
||||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -58,3 +60,11 @@ def mock_services() -> InvocationServices:
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||||
return Invoker(services=mock_services)
|
return Invoker(services=mock_services)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def invokeai_root_dir(tmp_path_factory) -> Path:
|
||||||
|
root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root"
|
||||||
|
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
|
||||||
|
shutil.copytree(root_template, temp_dir)
|
||||||
|
return temp_dir
|
||||||
|
Loading…
Reference in New Issue
Block a user