returning the bulk_download_item_name on response for possible polling

This commit is contained in:
Stefan Tobler
2024-01-28 01:23:38 -05:00
committed by psychedelicious
parent b5ca1643a6
commit d0f3571e59
5 changed files with 116 additions and 33 deletions

View File

@ -7,6 +7,7 @@ from fastapi.testclient import TestClient
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
@ -70,17 +71,32 @@ class MockApiDependencies(ApiDependencies):
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
def mock_uuid_string():
return "test"
# You have to patch the function within the module it's being imported into. This is strange, but it works.
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
json_response = response.json()
assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test"
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
expected_board_name = "test"
def mock_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": [], "board_id": "test"})
response = client.post("/api/v1/images/download", json={"board_id": "test"})
json_response = response.json()
assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test.zip"
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:

View File

@ -125,7 +125,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
@ -151,7 +151,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "test")
bulk_download_service.handler([], "test", None)
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
@ -173,7 +173,31 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none")
bulk_download_service.handler([], "none", None)
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
)
def test_handler_bulk_download__item_id_given(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
):
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
_, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
@ -242,20 +266,6 @@ def assert_handler_success(
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
def test_stop(tmp_path: Path) -> None:
"""Test that the stop method removes the bulk_downloads directory."""
bulk_download_service = BulkDownloadService(tmp_path)
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents")
bulk_download_service.stop()
assert (tmp_path / "bulk_downloads").exists()
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler emits an error event when the image is not found."""
exception: Exception = ImageRecordNotFoundException("Image not found")
@ -309,7 +319,7 @@ def execute_handler_test_on_error(
):
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
event_bus: DummyEventService = mock_invoker.services.events
@ -319,6 +329,35 @@ def execute_handler_test_on_error(
assert event_bus.events[1].payload["error"] == error.__str__()
def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker):
"""Test that the get_board_name function returns the correct board name."""
expected_board_name = "board1"
def mock_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_board_name("12345")
assert board_name == expected_board_name
def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
"""Test that the get_board_name function returns the correct board name."""
expected_board_name = "Uncategorized"
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_board_name("none")
assert board_name == expected_board_name
def test_delete(tmp_path: Path):
"""Test that the delete method removes the bulk download file."""
@ -332,8 +371,9 @@ def test_delete(tmp_path: Path):
assert (tmp_path / "bulk_downloads").exists()
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
def test_stop(tmp_path: Path):
"""Test that the delete method removes the bulk download file."""
"""Test that the stop method removes the bulk download file and not any directories."""
bulk_download_service = BulkDownloadService(tmp_path)
@ -343,7 +383,6 @@ def test_stop(tmp_path: Path):
mock_dir: Path = tmp_path / "bulk_downloads" / "test"
mock_dir.mkdir(parents=True, exist_ok=True)
bulk_download_service.stop()
assert (tmp_path / "bulk_downloads").exists()