From a8d7cf4e97b15a371abe5b2441160dd50659a0a3 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sat, 17 Feb 2024 23:53:38 -0500 Subject: [PATCH] refactoring handlers to do null check --- invokeai/app/api/routers/images.py | 5 ++- .../bulk_download/bulk_download_base.py | 4 ++- .../bulk_download/bulk_download_default.py | 31 +++++++++++++------ 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index d1c64648de..c3504b104d 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,6 +1,6 @@ import io import traceback -from typing import Optional, cast +from typing import Optional from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse @@ -396,10 +396,9 @@ async def download_images_from_list( raise HTTPException(status_code=400, detail="No images or board id specified.") bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) - # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries background_tasks.add_task( ApiDependencies.invoker.services.bulk_download.handler, - cast(list[str], image_names), + image_names, board_id, bulk_download_item_id, ) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index d889e2ed0e..80a2ddfb25 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -23,7 +23,9 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: + def handler( + self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + ) -> None: """ Starts a a bulk download job. diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 4f5bfb087f..72bb5a5d52 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -7,6 +7,7 @@ from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.bulk_download.bulk_download_common import ( DEFAULT_BULK_DOWNLOAD_ID, BulkDownloadException, + BulkDownloadParametersException, BulkDownloadTargetException, ) from invokeai.app.services.events.events_base import EventServiceBase @@ -36,7 +37,9 @@ class BulkDownloadService(BulkDownloadBase): self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: + def handler( + self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + ) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -53,15 +56,12 @@ class BulkDownloadService(BulkDownloadBase): image_dtos: list[ImageDTO] = [] if board_id: - # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images - image_dtos = self.__invoker.services.images.get_many( - offset=0, - limit=-1, - board_id=board_id, - is_intermediate=False, - ).items + image_dtos = self._board_handler(board_id) + elif image_names: + image_dtos = self._image_handler(image_names) else: - image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] + raise BulkDownloadParametersException() + bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: @@ -71,6 +71,19 @@ class BulkDownloadService(BulkDownloadBase): self.__invoker.services.logger.error("Problem bulk downloading images.") raise e + def _image_handler(self, image_names: list[str]) -> list[ImageDTO]: + return [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] + + def _board_handler(self, board_id: str) -> list[ImageDTO]: + # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images + image_dtos = self.__invoker.services.images.get_many( + offset=0, + limit=-1, + board_id=board_id, + is_intermediate=False, + ).items + return image_dtos + def generate_item_id(self, board_id: Optional[str]) -> str: return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()