returning the bulk_download_item_name on response for possible polling

This commit is contained in:
Stefan Tobler 2024-01-28 01:23:38 -05:00 committed by psychedelicious
parent ff53563152
commit fc5c5b6bdd
5 changed files with 116 additions and 33 deletions

View File

@ -1,6 +1,6 @@
import io import io
import traceback import traceback
from typing import Optional from typing import Optional, cast
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@ -13,6 +13,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from invokeai.app.util.misc import uuid_string
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -377,6 +378,7 @@ class ImagesDownloaded(BaseModel):
response: Optional[str] = Field( response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading" description="If defined, the message to display to the user when images begin downloading"
) )
bulk_download_item_name: str = Field(description="The bulk download item name of the bulk download item")
@images_router.post( @images_router.post(
@ -384,15 +386,31 @@ class ImagesDownloaded(BaseModel):
) )
async def download_images_from_list( async def download_images_from_list(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
image_names: list[str] = Body(description="The list of names of images to download", embed=True), image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
),
board_id: Optional[str] = Body( board_id: Optional[str] = Body(
default=None, description="The board from which image should be downloaded from", embed=True default=None, description="The board from which image should be downloaded from", embed=True
), ),
) -> ImagesDownloaded: ) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None: if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=400, detail="No images or board id specified.")
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id) bulk_download_item_id: str = uuid_string() if board_id is None else board_id
return ImagesDownloaded(response="Your images are preparing to be downloaded") board_name: str = (
"" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name
)
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler,
cast(list[str], image_names),
board_id,
bulk_download_item_id,
)
return ImagesDownloaded(
response="Your images are preparing to be downloaded",
bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip",
)
@images_router.api_route( @images_router.api_route(
@ -410,7 +428,7 @@ async def download_images_from_list(
) )
async def get_bulk_download_item( async def get_bulk_download_item(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a bulk download zip file""" """Gets a bulk download zip file"""
try: try:

View File

@ -26,7 +26,7 @@ class BulkDownloadBase(ABC):
""" """
@abstractmethod @abstractmethod
def handler(self, image_names: list[str], board_id: Optional[str]) -> None: def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
""" """
Starts a a bulk download job. Starts a a bulk download job.
@ -44,6 +44,15 @@ class BulkDownloadBase(ABC):
:return: The path to the bulk download file. :return: The path to the bulk download file.
""" """
@abstractmethod
def get_board_name(self, board_id: str) -> str:
"""
Get the name of the board.
:param board_id: The ID of the board.
:return: The name of the board.
"""
@abstractmethod @abstractmethod
def stop(self, *args, **kwargs) -> None: def stop(self, *args, **kwargs) -> None:
""" """

View File

@ -38,7 +38,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(self, image_names: list[str], board_id: Optional[str]) -> None: def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
""" """
Create a zip file containing the images specified by the given image names or board id. Create a zip file containing the images specified by the given image names or board id.
@ -47,7 +47,8 @@ class BulkDownloadService(BulkDownloadBase):
""" """
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id: str = uuid_string() if board_id is None else board_id if bulk_download_item_id is None:
bulk_download_item_id = uuid_string() if board_id is None else board_id
self._signal_job_started(bulk_download_id, bulk_download_item_id) self._signal_job_started(bulk_download_id, bulk_download_item_id)
@ -56,7 +57,7 @@ class BulkDownloadService(BulkDownloadBase):
image_dtos: list[ImageDTO] = [] image_dtos: list[ImageDTO] = []
if board_id: if board_id:
board_name = self._get_board_name(board_id) board_name = self.get_board_name(board_id)
board_name = self._clean_string_to_path_safe(board_name) board_name = self._clean_string_to_path_safe(board_name)
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
@ -79,7 +80,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.") self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e raise e
def _get_board_name(self, board_id: str) -> str: def get_board_name(self, board_id: str) -> str:
if board_id == "none": if board_id == "none":
return "Uncategorized" return "Uncategorized"

View File

@ -7,6 +7,7 @@ from fastapi.testclient import TestClient
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
@ -70,17 +71,32 @@ class MockApiDependencies(ApiDependencies):
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
prepare_download_images_test(monkeypatch, mock_invoker) prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) def mock_uuid_string():
return "test"
# You have to patch the function within the module it's being imported into. This is strange, but it works.
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
json_response = response.json()
assert response.status_code == 202 assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test"
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
expected_board_name = "test"
def mock_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
prepare_download_images_test(monkeypatch, mock_invoker) prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": [], "board_id": "test"}) response = client.post("/api/v1/images/download", json={"board_id": "test"})
json_response = response.json()
assert response.status_code == 202 assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test.zip"
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:

View File

@ -125,7 +125,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None) bulk_download_service.handler([mock_image_dto.image_name], None, None)
assert_handler_success( assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
@ -151,7 +151,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "test") bulk_download_service.handler([], "test", None)
assert_handler_success( assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
@ -173,7 +173,31 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none") bulk_download_service.handler([], "none", None)
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
)
def test_handler_bulk_download__item_id_given(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
):
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
_, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
assert_handler_success( assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
@ -242,20 +266,6 @@ def assert_handler_success(
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
def test_stop(tmp_path: Path) -> None:
"""Test that the stop method removes the bulk_downloads directory."""
bulk_download_service = BulkDownloadService(tmp_path)
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents")
bulk_download_service.stop()
assert (tmp_path / "bulk_downloads").exists()
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler emits an error event when the image is not found.""" """Test that the handler emits an error event when the image is not found."""
exception: Exception = ImageRecordNotFoundException("Image not found") exception: Exception = ImageRecordNotFoundException("Image not found")
@ -309,7 +319,7 @@ def execute_handler_test_on_error(
): ):
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None) bulk_download_service.handler([mock_image_dto.image_name], None, None)
event_bus: DummyEventService = mock_invoker.services.events event_bus: DummyEventService = mock_invoker.services.events
@ -319,6 +329,35 @@ def execute_handler_test_on_error(
assert event_bus.events[1].payload["error"] == error.__str__() assert event_bus.events[1].payload["error"] == error.__str__()
def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker):
"""Test that the get_board_name function returns the correct board name."""
expected_board_name = "board1"
def mock_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_board_name("12345")
assert board_name == expected_board_name
def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
"""Test that the get_board_name function returns the correct board name."""
expected_board_name = "Uncategorized"
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_board_name("none")
assert board_name == expected_board_name
def test_delete(tmp_path: Path): def test_delete(tmp_path: Path):
"""Test that the delete method removes the bulk download file.""" """Test that the delete method removes the bulk download file."""
@ -332,8 +371,9 @@ def test_delete(tmp_path: Path):
assert (tmp_path / "bulk_downloads").exists() assert (tmp_path / "bulk_downloads").exists()
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
def test_stop(tmp_path: Path): def test_stop(tmp_path: Path):
"""Test that the delete method removes the bulk download file.""" """Test that the stop method removes the bulk download file and not any directories."""
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
@ -343,7 +383,6 @@ def test_stop(tmp_path: Path):
mock_dir: Path = tmp_path / "bulk_downloads" / "test" mock_dir: Path = tmp_path / "bulk_downloads" / "test"
mock_dir.mkdir(parents=True, exist_ok=True) mock_dir.mkdir(parents=True, exist_ok=True)
bulk_download_service.stop() bulk_download_service.stop()
assert (tmp_path / "bulk_downloads").exists() assert (tmp_path / "bulk_downloads").exists()