From 56d2d220a812990a0b28305e2df3153331f913db Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 7 Jan 2024 21:29:42 -0500 Subject: [PATCH] implementation of bulkdownload background task --- invokeai/app/api/dependencies.py | 3 + invokeai/app/api/routers/images.py | 14 ++- .../bulk_download/bulk_download_base.py | 2 +- .../bulk_download/bulk_download_defauilt.py | 114 ++++++++++++++++++ invokeai/app/services/invocation_services.py | 3 + 5 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 invokeai/app/services/bulk_download/bulk_download_defauilt.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index a9132516a8..ab09d1e5d7 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -15,6 +15,7 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar from ..services.board_images.board_images_default import BoardImagesService from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage from ..services.boards.boards_default import BoardService +from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService from ..services.image_files.image_files_disk import DiskImageFileStorage @@ -81,6 +82,7 @@ class ApiDependencies: board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) + bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events) image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) @@ -110,6 +112,7 @@ class ApiDependencies: board_images=board_images, board_records=board_records, boards=boards, + bulk_download=bulk_download, configuration=configuration, events=events, image_files=image_files, diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index cc60ad1be8..2a8e1e7ec7 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -2,7 +2,7 @@ import io import traceback from typing import Optional -from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile +from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, ValidationError from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator @@ -372,19 +373,22 @@ async def unstar_images_in_list( except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") - class ImagesDownloaded(BaseModel): response: Optional[str] = Field( description="If defined, the message to display to the user when images begin downloading" ) -@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded) +@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202) async def download_images_from_list( + background_tasks: BackgroundTasks, image_names: list[str] = Body(description="The list of names of images to download", embed=True), board_id: Optional[str] = Body( default=None, description="The board from which image should be downloaded from", embed=True ), ) -> ImagesDownloaded: - # return ImagesDownloaded(response="Your images are downloading") - raise HTTPException(status_code=501, detail="Endpoint is not yet implemented") + if (image_names is None or len(image_names) == 0) and board_id is None: + raise HTTPException(status_code=400, detail="No images or board id specified.") + background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id) + return ImagesDownloaded(response="Your images are preparing to be downloaded") + diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 54c8771437..b788020bba 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -22,7 +22,7 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def start(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> str: + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Starts a a bulk download job. diff --git a/invokeai/app/services/bulk_download/bulk_download_defauilt.py b/invokeai/app/services/bulk_download/bulk_download_defauilt.py new file mode 100644 index 0000000000..ebeaa4be5a --- /dev/null +++ b/invokeai/app/services/bulk_download/bulk_download_defauilt.py @@ -0,0 +1,114 @@ +from pathlib import Path +from typing import Optional, Union +import uuid +from zipfile import ZipFile + +from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException +from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException +from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException +from invokeai.app.services.invoker import Invoker + +from .bulk_download_base import BulkDownloadBase + +class BulkDownloadService(BulkDownloadBase): + + __output_folder: Path + __bulk_downloads_folder: Path + __event_bus: Optional[EventServiceBase] + + def __init__(self, + output_folder: Union[str, Path], + event_bus: Optional[EventServiceBase] = None,): + """ + Initialize the downloader object. + + :param event_bus: Optional EventService object + """ + self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) + self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" + self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) + self.__event_bus = event_bus + + + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: + """ + Create a zip file containing the images specified by the given image names or board id. + + param: image_names: A list of image names to include in the zip file. + param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. + """ + bulk_download_id = str(uuid.uuid4()) + + self._signal_job_started(bulk_download_id) + try: + board_name: Union[str, None] = None + if board_id: + image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) + if board_id == "none": + board_id = "Uncategorized" + image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) + file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id) + self._signal_job_completed(bulk_download_id, file_path) + except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: + self._signal_job_failed(bulk_download_id, e) + except Exception as e: + self._signal_job_failed(bulk_download_id, e) + + def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]: + """ + Create a map of image names to their paths. + :param image_names: A list of image names. + """ + image_names_to_paths: dict[str, str] = {} + for image_name in image_names: + image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) + return image_names_to_paths + + + + def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: + """ + Create a zip file containing the images specified by the given image names or board id. + If download with the same bulk_download_id already exists, it will be overwritten. + """ + + zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") + + with ZipFile(zip_file_path, "w") as zip_file: + for image_name, image_path in image_names_to_paths.items(): + zip_file.write(image_path, arcname=image_name) + + return str(zip_file_path) + + + def _signal_job_started(self, bulk_download_id: str) -> None: + """Signal that a bulk download job has started.""" + if self.__event_bus: + assert bulk_download_id is not None + self.__event_bus.emit_bulk_download_started( + bulk_download_id=bulk_download_id, + ) + + + def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: + """Signal that a bulk download job has completed.""" + if self.__event_bus: + assert bulk_download_id is not None + assert file_path is not None + self.__event_bus.emit_bulk_download_completed( + bulk_download_id=bulk_download_id, + file_path=file_path, + ) + + def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: + """Signal that a bulk download job has failed.""" + if self.__event_bus: + assert bulk_download_id is not None + assert exception is not None + self.__event_bus.emit_bulk_download_failed( + bulk_download_id=bulk_download_id, + error=str(exception), + ) + + \ No newline at end of file diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 04fe71a3eb..a560696692 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase from .boards.boards_base import BoardServiceABC + from .bulk_download.bulk_download_base import BulkDownloadBase from .config import InvokeAIAppConfig from .download import DownloadQueueServiceBase from .events.events_base import EventServiceBase @@ -41,6 +42,7 @@ class InvocationServices: board_image_records: "BoardImageRecordStorageBase", boards: "BoardServiceABC", board_records: "BoardRecordStorageBase", + bulk_download: "BulkDownloadBase", configuration: "InvokeAIAppConfig", events: "EventServiceBase", images: "ImageServiceABC", @@ -63,6 +65,7 @@ class InvocationServices: self.board_image_records = board_image_records self.boards = boards self.board_records = board_records + self.bulk_download = bulk_download self.configuration = configuration self.events = events self.images = images