mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
returning the bulk_download_item_name on response for possible polling
This commit is contained in:
parent
ff53563152
commit
fc5c5b6bdd
@ -1,6 +1,6 @@
|
||||
import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
@ -13,6 +13,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -377,6 +378,7 @@ class ImagesDownloaded(BaseModel):
|
||||
response: Optional[str] = Field(
|
||||
description="If defined, the message to display to the user when images begin downloading"
|
||||
)
|
||||
bulk_download_item_name: str = Field(description="The bulk download item name of the bulk download item")
|
||||
|
||||
|
||||
@images_router.post(
|
||||
@ -384,15 +386,31 @@ class ImagesDownloaded(BaseModel):
|
||||
)
|
||||
async def download_images_from_list(
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||
image_names: Optional[list[str]] = Body(
|
||||
default=None, description="The list of names of images to download", embed=True
|
||||
),
|
||||
board_id: Optional[str] = Body(
|
||||
default=None, description="The board from which image should be downloaded from", embed=True
|
||||
),
|
||||
) -> ImagesDownloaded:
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id)
|
||||
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
||||
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
|
||||
board_name: str = (
|
||||
"" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name
|
||||
)
|
||||
|
||||
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
|
||||
background_tasks.add_task(
|
||||
ApiDependencies.invoker.services.bulk_download.handler,
|
||||
cast(list[str], image_names),
|
||||
board_id,
|
||||
bulk_download_item_id,
|
||||
)
|
||||
return ImagesDownloaded(
|
||||
response="Your images are preparing to be downloaded",
|
||||
bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip",
|
||||
)
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
@ -410,7 +428,7 @@ async def download_images_from_list(
|
||||
)
|
||||
async def get_bulk_download_item(
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a bulk download zip file"""
|
||||
try:
|
||||
|
@ -26,7 +26,7 @@ class BulkDownloadBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
|
||||
"""
|
||||
Starts a a bulk download job.
|
||||
|
||||
@ -44,6 +44,15 @@ class BulkDownloadBase(ABC):
|
||||
:return: The path to the bulk download file.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_board_name(self, board_id: str) -> str:
|
||||
"""
|
||||
Get the name of the board.
|
||||
|
||||
:param board_id: The ID of the board.
|
||||
:return: The name of the board.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
|
@ -38,7 +38,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
|
||||
@ -47,7 +47,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
"""
|
||||
|
||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
|
||||
if bulk_download_item_id is None:
|
||||
bulk_download_item_id = uuid_string() if board_id is None else board_id
|
||||
|
||||
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
||||
|
||||
@ -56,7 +57,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
image_dtos: list[ImageDTO] = []
|
||||
|
||||
if board_id:
|
||||
board_name = self._get_board_name(board_id)
|
||||
board_name = self.get_board_name(board_id)
|
||||
board_name = self._clean_string_to_path_safe(board_name)
|
||||
|
||||
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
||||
@ -79,7 +80,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
||||
raise e
|
||||
|
||||
def _get_board_name(self, board_id: str) -> str:
|
||||
def get_board_name(self, board_id: str) -> str:
|
||||
if board_id == "none":
|
||||
return "Uncategorized"
|
||||
|
||||
|
@ -7,6 +7,7 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api_app import app
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
@ -70,17 +71,32 @@ class MockApiDependencies(ApiDependencies):
|
||||
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
||||
def mock_uuid_string():
|
||||
return "test"
|
||||
|
||||
# You have to patch the function within the module it's being imported into. This is strange, but it works.
|
||||
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
|
||||
monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string)
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
||||
json_response = response.json()
|
||||
assert response.status_code == 202
|
||||
assert json_response["bulk_download_item_name"] == "test"
|
||||
|
||||
|
||||
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||
expected_board_name = "test"
|
||||
|
||||
def mock_get(*args, **kwargs):
|
||||
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
|
||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"image_names": [], "board_id": "test"})
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"board_id": "test"})
|
||||
json_response = response.json()
|
||||
assert response.status_code == 202
|
||||
assert json_response["bulk_download_item_name"] == "test.zip"
|
||||
|
||||
|
||||
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||
|
@ -125,7 +125,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
@ -151,7 +151,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "test")
|
||||
bulk_download_service.handler([], "test", None)
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
@ -173,7 +173,31 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "none")
|
||||
bulk_download_service.handler([], "none", None)
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
|
||||
|
||||
def test_handler_bulk_download__item_id_given(
|
||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
|
||||
):
|
||||
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
|
||||
|
||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
|
||||
|
||||
def mock_get_many(*args, **kwargs):
|
||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
@ -242,20 +266,6 @@ def assert_handler_success(
|
||||
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
|
||||
|
||||
|
||||
def test_stop(tmp_path: Path) -> None:
|
||||
"""Test that the stop method removes the bulk_downloads directory."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
|
||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
||||
bulk_download_service.stop()
|
||||
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
|
||||
|
||||
|
||||
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the handler emits an error event when the image is not found."""
|
||||
exception: Exception = ImageRecordNotFoundException("Image not found")
|
||||
@ -309,7 +319,7 @@ def execute_handler_test_on_error(
|
||||
):
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||
|
||||
event_bus: DummyEventService = mock_invoker.services.events
|
||||
|
||||
@ -319,6 +329,35 @@ def execute_handler_test_on_error(
|
||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||
|
||||
|
||||
def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker):
|
||||
"""Test that the get_board_name function returns the correct board name."""
|
||||
|
||||
expected_board_name = "board1"
|
||||
|
||||
def mock_get(*args, **kwargs):
|
||||
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
board_name = bulk_download_service.get_board_name("12345")
|
||||
|
||||
assert board_name == expected_board_name
|
||||
|
||||
|
||||
def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
|
||||
"""Test that the get_board_name function returns the correct board name."""
|
||||
|
||||
expected_board_name = "Uncategorized"
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service.start(mock_invoker)
|
||||
board_name = bulk_download_service.get_board_name("none")
|
||||
|
||||
assert board_name == expected_board_name
|
||||
|
||||
|
||||
def test_delete(tmp_path: Path):
|
||||
"""Test that the delete method removes the bulk download file."""
|
||||
|
||||
@ -332,8 +371,9 @@ def test_delete(tmp_path: Path):
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
|
||||
|
||||
|
||||
def test_stop(tmp_path: Path):
|
||||
"""Test that the delete method removes the bulk download file."""
|
||||
"""Test that the stop method removes the bulk download file and not any directories."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
|
||||
@ -343,7 +383,6 @@ def test_stop(tmp_path: Path):
|
||||
mock_dir: Path = tmp_path / "bulk_downloads" / "test"
|
||||
mock_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
bulk_download_service.stop()
|
||||
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
|
Loading…
Reference in New Issue
Block a user