returning the bulk_download_item_name on response for possible polling

This commit is contained in:
Stefan Tobler
2024-01-28 01:23:38 -05:00
committed by psychedelicious
parent ff53563152
commit fc5c5b6bdd
5 changed files with 116 additions and 33 deletions

View File

@ -1,6 +1,6 @@
import io
import traceback
from typing import Optional
from typing import Optional, cast
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
@ -13,6 +13,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from invokeai.app.util.misc import uuid_string
from ..dependencies import ApiDependencies
@ -377,6 +378,7 @@ class ImagesDownloaded(BaseModel):
response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading"
)
bulk_download_item_name: str = Field(description="The bulk download item name of the bulk download item")
@images_router.post(
@ -384,15 +386,31 @@ class ImagesDownloaded(BaseModel):
)
async def download_images_from_list(
background_tasks: BackgroundTasks,
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
),
board_id: Optional[str] = Body(
default=None, description="The board from which image should be downloaded from", embed=True
),
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id)
return ImagesDownloaded(response="Your images are preparing to be downloaded")
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
board_name: str = (
"" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name
)
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler,
cast(list[str], image_names),
board_id,
bulk_download_item_id,
)
return ImagesDownloaded(
response="Your images are preparing to be downloaded",
bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip",
)
@images_router.api_route(
@ -410,7 +428,7 @@ async def download_images_from_list(
)
async def get_bulk_download_item(
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
try:

View File

@ -26,7 +26,7 @@ class BulkDownloadBase(ABC):
"""
@abstractmethod
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
"""
Starts a a bulk download job.
@ -44,6 +44,15 @@ class BulkDownloadBase(ABC):
:return: The path to the bulk download file.
"""
@abstractmethod
def get_board_name(self, board_id: str) -> str:
"""
Get the name of the board.
:param board_id: The ID of the board.
:return: The name of the board.
"""
@abstractmethod
def stop(self, *args, **kwargs) -> None:
"""

View File

@ -38,7 +38,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
@ -47,7 +47,8 @@ class BulkDownloadService(BulkDownloadBase):
"""
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
if bulk_download_item_id is None:
bulk_download_item_id = uuid_string() if board_id is None else board_id
self._signal_job_started(bulk_download_id, bulk_download_item_id)
@ -56,7 +57,7 @@ class BulkDownloadService(BulkDownloadBase):
image_dtos: list[ImageDTO] = []
if board_id:
board_name = self._get_board_name(board_id)
board_name = self.get_board_name(board_id)
board_name = self._clean_string_to_path_safe(board_name)
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
@ -79,7 +80,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e
def _get_board_name(self, board_id: str) -> str:
def get_board_name(self, board_id: str) -> str:
if board_id == "none":
return "Uncategorized"