mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
using temp directory for downloads
This commit is contained in:
committed by
psychedelicious
parent
d0f3571e59
commit
f15aa562c2
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
@ -86,9 +87,28 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
return Invoker(services=mock_services)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
|
||||
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
|
||||
|
||||
class MockTemporaryDirectory(TemporaryDirectory):
|
||||
def __init__(self):
|
||||
super().__init__(dir=tmp_path)
|
||||
self.name = tmp_path
|
||||
|
||||
def mock_TemporaryDirectory(*args, **kwargs):
|
||||
return MockTemporaryDirectory()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory
|
||||
)
|
||||
|
||||
|
||||
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||
"""Test get_path when the file exists."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
|
||||
# Create a directory at tmp_path/bulk_downloads
|
||||
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
||||
test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
@ -97,7 +117,6 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||
test_file_path: Path = test_bulk_downloads_dir / "test.zip"
|
||||
test_file_path.touch()
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
assert bulk_download_service.get_path("test.zip") == str(test_file_path)
|
||||
|
||||
|
||||
@ -164,7 +183,6 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
|
||||
|
||||
def mock_get_many(*args, **kwargs):
|
||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||
@ -175,6 +193,8 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "none", None)
|
||||
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
@ -188,7 +208,6 @@ def test_handler_bulk_download__item_id_given(
|
||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
|
||||
|
||||
def mock_get_many(*args, **kwargs):
|
||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||
@ -199,6 +218,8 @@ def test_handler_bulk_download__item_id_given(
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
||||
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
@ -341,7 +362,7 @@ def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
board_name = bulk_download_service.get_board_name("12345")
|
||||
board_name = bulk_download_service.get_clean_board_name("12345")
|
||||
|
||||
assert board_name == expected_board_name
|
||||
|
||||
@ -353,7 +374,7 @@ def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
board_name = bulk_download_service.get_board_name("none")
|
||||
board_name = bulk_download_service.get_clean_board_name("none")
|
||||
|
||||
assert board_name == expected_board_name
|
||||
|
||||
@ -385,6 +406,4 @@ def test_stop(tmp_path: Path):
|
||||
|
||||
bulk_download_service.stop()
|
||||
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
assert mock_dir.exists()
|
||||
assert len(os.listdir(tmp_path / "bulk_downloads")) == 1
|
||||
assert not (tmp_path / "bulk_downloads").exists()
|
||||
|
Reference in New Issue
Block a user