From aa132fb9e3d89874d9d69be8327e9b38390fabbe Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sat, 13 Jan 2024 23:35:33 -0500 Subject: [PATCH] reworking some of the logic to use a default room, adding endpoint to download file on complete --- invokeai/app/api/dependencies.py | 2 +- invokeai/app/api/routers/images.py | 33 +++++++++ .../bulk_download/bulk_download_base.py | 25 +++++++ .../bulk_download/bulk_download_common.py | 3 + ...d_defauilt.py => bulk_download_default.py} | 72 +++++++++++++++---- invokeai/app/services/events/events_base.py | 23 ++++-- 6 files changed, 137 insertions(+), 21 deletions(-) rename invokeai/app/services/bulk_download/{bulk_download_defauilt.py => bulk_download_default.py} (60%) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index ab09d1e5d7..aaa08a2498 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -15,7 +15,7 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar from ..services.board_images.board_images_default import BoardImagesService from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage from ..services.boards.boards_default import BoardService -from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService +from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService from ..services.image_files.image_files_disk import DiskImageFileStorage diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index e32f7fb9ee..43392dd471 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -395,3 +395,36 @@ async def download_images_from_list( ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id ) return ImagesDownloaded(response="Your images are preparing to be downloaded") + + +@images_router.api_route( + "/download/{bulk_download_item_name}", + methods=["GET"], + operation_id="get_bulk_download_item", + response_class=Response, + responses={ + 200: { + "description": "Return the complete bulk download item", + "content": {"application/zip": {}}, + }, + 404: {"description": "Image not found"}, + }, +) +async def get_bulk_download_item( + bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), +) -> FileResponse: + """Gets a bulk download zip file""" + + try: + path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) + + response = FileResponse( + path, + media_type="application/zip", + filename=bulk_download_item_name, + content_disposition_type="inline", + ) + response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" + return response + except Exception: + raise HTTPException(status_code=404) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index fc45aff280..8a9ea1f3f2 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -29,3 +29,28 @@ class BulkDownloadBase(ABC): :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. """ + + @abstractmethod + def get_path(self, bulk_download_item_id: str) -> str: + """ + Get the path to the bulk download file. + + :param bulk_download_item_id: The ID of the bulk download item. + :return: The path to the bulk download file. + """ + + @abstractmethod + def stop(self, *args, **kwargs) -> None: + """ + Stops the BulkDownloadService and cleans up all the remnants. + + This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup + operations to remove any remnants or resources associated with the service. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + None + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_common.py b/invokeai/app/services/bulk_download/bulk_download_common.py index 23a0589daf..37b80073be 100644 --- a/invokeai/app/services/bulk_download/bulk_download_common.py +++ b/invokeai/app/services/bulk_download/bulk_download_common.py @@ -1,3 +1,6 @@ +DEFAULT_BULK_DOWNLOAD_ID = "default" + + class BulkDownloadException(Exception): """Exception raised when a bulk download fails.""" diff --git a/invokeai/app/services/bulk_download/bulk_download_defauilt.py b/invokeai/app/services/bulk_download/bulk_download_default.py similarity index 60% rename from invokeai/app/services/bulk_download/bulk_download_defauilt.py rename to invokeai/app/services/bulk_download/bulk_download_default.py index 8321f5069d..561fd173a8 100644 --- a/invokeai/app/services/bulk_download/bulk_download_defauilt.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -4,7 +4,11 @@ from typing import Optional, Union from zipfile import ZipFile from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException -from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException +from invokeai.app.services.bulk_download.bulk_download_common import ( + DEFAULT_BULK_DOWNLOAD_ID, + BulkDownloadException, + BulkDownloadTargetException, +) from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException from invokeai.app.services.invoker import Invoker @@ -32,6 +36,32 @@ class BulkDownloadService(BulkDownloadBase): self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__event_bus = event_bus + def get_path(self, bulk_download_item_name: str) -> str: + """ + Get the path to the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The path to the bulk download file. + """ + path = str(self.__bulk_downloads_folder / bulk_download_item_name) + if not self.validate_path(path): + raise BulkDownloadTargetException() + return path + + def get_bulk_download_item_name(self, bulk_download_item_id: str) -> str: + """ + Get the name of the bulk download item. + + :param bulk_download_item_id: The ID of the bulk download item. + :return: The name of the bulk download item. + """ + return bulk_download_item_id + ".zip" + + def validate_path(self, path: Union[str, Path]) -> bool: + """Validates the path given for a bulk download.""" + path = path if isinstance(path, Path) else Path(path) + return path.exists() + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -39,7 +69,9 @@ class BulkDownloadService(BulkDownloadBase): param: image_names: A list of image names to include in the zip file. param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. """ - bulk_download_id = str(uuid.uuid4()) + bulk_download_id = DEFAULT_BULK_DOWNLOAD_ID + bulk_download_item_id = str(uuid.uuid4()) + self._signal_job_started(bulk_download_id, bulk_download_item_id) try: if board_id: @@ -47,12 +79,12 @@ class BulkDownloadService(BulkDownloadBase): if board_id == "none": board_id = "Uncategorized" image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) - file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id) - self._signal_job_completed(bulk_download_id, file_path) + bulk_download_item_name: str = self._create_zip_file(image_names_to_paths, bulk_download_item_id) + self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: - self._signal_job_failed(bulk_download_id, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) except Exception as e: - self._signal_job_failed(bulk_download_id, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]: """ @@ -64,44 +96,54 @@ class BulkDownloadService(BulkDownloadBase): image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) return image_names_to_paths - def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: + def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_item_id: str) -> str: """ Create a zip file containing the images specified by the given image names or board id. If download with the same bulk_download_id already exists, it will be overwritten. - """ - zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") + :return: The name of the zip file. + """ + zip_file_name = bulk_download_item_id + ".zip" + zip_file_path = self.__bulk_downloads_folder / (zip_file_name) with ZipFile(zip_file_path, "w") as zip_file: for image_name, image_path in image_names_to_paths.items(): zip_file.write(image_path, arcname=image_name) - return str(zip_file_path) + return str(zip_file_name) - def _signal_job_started(self, bulk_download_id: str) -> None: + def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: assert bulk_download_id is not None self.__event_bus.emit_bulk_download_started( bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, ) - def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: + def _signal_job_completed( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> None: """Signal that a bulk download job has completed.""" if self.__event_bus: assert bulk_download_id is not None - assert file_path is not None + assert bulk_download_item_name is not None self.__event_bus.emit_bulk_download_completed( bulk_download_id=bulk_download_id, - file_path=file_path, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, ) - def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: + def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, exception: Exception) -> None: """Signal that a bulk download job has failed.""" if self.__event_bus: assert bulk_download_id is not None assert exception is not None self.__event_bus.emit_bulk_download_failed( bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, error=str(exception), ) + + def stop(self, *args, **kwargs): + pass diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 597a56d944..3cc3ba2f28 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -440,23 +440,36 @@ class EventServiceBase: }, ) - def emit_bulk_download_started(self, bulk_download_id: str) -> None: + def emit_bulk_download_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Emitted when a bulk download starts""" self._emit_bulk_download_event( event_name="bulk_download_started", payload={ "bulk_download_id": bulk_download_id, + "bulk_download_item_id": bulk_download_item_id, }, ) - def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: + def emit_bulk_download_completed( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> None: """Emitted when a bulk download completes""" self._emit_bulk_download_event( - event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path} + event_name="bulk_download_completed", + payload={ + "bulk_download_id": bulk_download_id, + "bulk_download_item_id": bulk_download_item_id, + "bulk_download_item_name": bulk_download_item_name, + }, ) - def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: + def emit_bulk_download_failed(self, bulk_download_id: str, bulk_download_item_id: str, error: str) -> None: """Emitted when a bulk download fails""" self._emit_bulk_download_event( - event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error} + event_name="bulk_download_failed", + payload={ + "bulk_download_id": bulk_download_id, + "bulk_download_item_id": bulk_download_item_id, + "error": error, + }, )