From b1301e1cbc5c7e5994b8825990af4aa66f6121b4 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 28 Jan 2024 01:23:38 -0500 Subject: [PATCH] returning the bulk_download_item_name on response for possible polling --- invokeai/app/api/routers/images.py | 28 +++++-- .../bulk_download/bulk_download_base.py | 11 ++- .../bulk_download/bulk_download_default.py | 9 ++- tests/app/routers/test_images.py | 22 +++++- .../bulk_download/test_bulk_download.py | 79 ++++++++++++++----- 5 files changed, 116 insertions(+), 33 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index d11c89c749..c12556aed6 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,6 +1,6 @@ import io import traceback -from typing import Optional +from typing import Optional, cast from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse @@ -13,6 +13,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator +from invokeai.app.util.misc import uuid_string from ..dependencies import ApiDependencies @@ -377,6 +378,7 @@ class ImagesDownloaded(BaseModel): response: Optional[str] = Field( description="If defined, the message to display to the user when images begin downloading" ) + bulk_download_item_name: str = Field(description="The bulk download item name of the bulk download item") @images_router.post( @@ -384,15 +386,31 @@ class ImagesDownloaded(BaseModel): ) async def download_images_from_list( background_tasks: BackgroundTasks, - image_names: list[str] = Body(description="The list of names of images to download", embed=True), + image_names: Optional[list[str]] = Body( + default=None, description="The list of names of images to download", embed=True + ), board_id: Optional[str] = Body( default=None, description="The board from which image should be downloaded from", embed=True ), ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id) - return ImagesDownloaded(response="Your images are preparing to be downloaded") + bulk_download_item_id: str = uuid_string() if board_id is None else board_id + board_name: str = ( + "" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name + ) + + # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries + background_tasks.add_task( + ApiDependencies.invoker.services.bulk_download.handler, + cast(list[str], image_names), + board_id, + bulk_download_item_id, + ) + return ImagesDownloaded( + response="Your images are preparing to be downloaded", + bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip", + ) @images_router.api_route( @@ -410,7 +428,7 @@ async def download_images_from_list( ) async def get_bulk_download_item( background_tasks: BackgroundTasks, - bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), + bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"), ) -> FileResponse: """Gets a bulk download zip file""" try: diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index a1071f254a..d6b0e62211 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -26,7 +26,7 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def handler(self, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: """ Starts a a bulk download job. @@ -44,6 +44,15 @@ class BulkDownloadBase(ABC): :return: The path to the bulk download file. """ + @abstractmethod + def get_board_name(self, board_id: str) -> str: + """ + Get the name of the board. + + :param board_id: The ID of the board. + :return: The name of the board. + """ + @abstractmethod def stop(self, *args, **kwargs) -> None: """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 87966ad622..be70dea2c1 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -38,7 +38,7 @@ class BulkDownloadService(BulkDownloadBase): self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - def handler(self, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -47,7 +47,8 @@ class BulkDownloadService(BulkDownloadBase): """ bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id: str = uuid_string() if board_id is None else board_id + if bulk_download_item_id is None: + bulk_download_item_id = uuid_string() if board_id is None else board_id self._signal_job_started(bulk_download_id, bulk_download_item_id) @@ -56,7 +57,7 @@ class BulkDownloadService(BulkDownloadBase): image_dtos: list[ImageDTO] = [] if board_id: - board_name = self._get_board_name(board_id) + board_name = self.get_board_name(board_id) board_name = self._clean_string_to_path_safe(board_name) # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images @@ -79,7 +80,7 @@ class BulkDownloadService(BulkDownloadBase): self.__invoker.services.logger.error("Problem bulk downloading images.") raise e - def _get_board_name(self, board_id: str) -> str: + def get_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index 040ae01914..a709daf24e 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -7,6 +7,7 @@ from fastapi.testclient import TestClient from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api_app import app +from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService from invokeai.app.services.config.config_default import InvokeAIAppConfig @@ -70,17 +71,32 @@ class MockApiDependencies(ApiDependencies): def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: prepare_download_images_test(monkeypatch, mock_invoker) - response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) + def mock_uuid_string(): + return "test" + # You have to patch the function within the module it's being imported into. This is strange, but it works. + # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/ + monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string) + + response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) + json_response = response.json() assert response.status_code == 202 + assert json_response["bulk_download_item_name"] == "test" def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: + expected_board_name = "test" + + def mock_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get) prepare_download_images_test(monkeypatch, mock_invoker) - response = client.post("/api/v1/images/download", json={"image_names": [], "board_id": "test"}) - + response = client.post("/api/v1/images/download", json={"board_id": "test"}) + json_response = response.json() assert response.status_code == 202 + assert json_response["bulk_download_item_name"] == "test.zip" def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 184519866a..7909c44214 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -125,7 +125,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([mock_image_dto.image_name], None) + bulk_download_service.handler([mock_image_dto.image_name], None, None) assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events @@ -151,7 +151,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([], "test") + bulk_download_service.handler([], "test", None) assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events @@ -173,7 +173,31 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([], "none") + bulk_download_service.handler([], "none", None) + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def test_handler_bulk_download__item_id_given( + tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker +): + """Test that the handler creates the zip file correctly when given a pregenerated bulk download item id.""" + + _, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip" + + def mock_get_many(*args, **kwargs): + return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) + + monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events @@ -242,20 +266,6 @@ def assert_handler_success( assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) -def test_stop(tmp_path: Path) -> None: - """Test that the stop method removes the bulk_downloads directory.""" - - bulk_download_service = BulkDownloadService(tmp_path) - - mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" - mock_file.write_text("contents") - - bulk_download_service.stop() - - assert (tmp_path / "bulk_downloads").exists() - assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 - - def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): """Test that the handler emits an error event when the image is not found.""" exception: Exception = ImageRecordNotFoundException("Image not found") @@ -309,7 +319,7 @@ def execute_handler_test_on_error( ): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([mock_image_dto.image_name], None) + bulk_download_service.handler([mock_image_dto.image_name], None, None) event_bus: DummyEventService = mock_invoker.services.events @@ -319,6 +329,35 @@ def execute_handler_test_on_error( assert event_bus.events[1].payload["error"] == error.__str__() +def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker): + """Test that the get_board_name function returns the correct board name.""" + + expected_board_name = "board1" + + def mock_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + board_name = bulk_download_service.get_board_name("12345") + + assert board_name == expected_board_name + + +def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker): + """Test that the get_board_name function returns the correct board name.""" + + expected_board_name = "Uncategorized" + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + board_name = bulk_download_service.get_board_name("none") + + assert board_name == expected_board_name + + def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" @@ -332,8 +371,9 @@ def test_delete(tmp_path: Path): assert (tmp_path / "bulk_downloads").exists() assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 + def test_stop(tmp_path: Path): - """Test that the delete method removes the bulk download file.""" + """Test that the stop method removes the bulk download file and not any directories.""" bulk_download_service = BulkDownloadService(tmp_path) @@ -343,7 +383,6 @@ def test_stop(tmp_path: Path): mock_dir: Path = tmp_path / "bulk_downloads" / "test" mock_dir.mkdir(parents=True, exist_ok=True) - bulk_download_service.stop() assert (tmp_path / "bulk_downloads").exists()