refactoring handlers to do null check

This commit is contained in:
Stefan Tobler 2024-02-17 23:53:38 -05:00 committed by psychedelicious
parent 037cac8154
commit a8d7cf4e97
3 changed files with 27 additions and 13 deletions

View File

@ -1,6 +1,6 @@
import io import io
import traceback import traceback
from typing import Optional, cast from typing import Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@ -396,10 +396,9 @@ async def download_images_from_list(
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=400, detail="No images or board id specified.")
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
background_tasks.add_task( background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker.services.bulk_download.handler,
cast(list[str], image_names), image_names,
board_id, board_id,
bulk_download_item_id, bulk_download_item_id,
) )

View File

@ -23,7 +23,9 @@ class BulkDownloadBase(ABC):
""" """
@abstractmethod @abstractmethod
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
""" """
Starts a a bulk download job. Starts a a bulk download job.

View File

@ -7,6 +7,7 @@ from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.bulk_download.bulk_download_common import ( from invokeai.app.services.bulk_download.bulk_download_common import (
DEFAULT_BULK_DOWNLOAD_ID, DEFAULT_BULK_DOWNLOAD_ID,
BulkDownloadException, BulkDownloadException,
BulkDownloadParametersException,
BulkDownloadTargetException, BulkDownloadTargetException,
) )
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
@ -36,7 +37,9 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
""" """
Create a zip file containing the images specified by the given image names or board id. Create a zip file containing the images specified by the given image names or board id.
@ -53,15 +56,12 @@ class BulkDownloadService(BulkDownloadBase):
image_dtos: list[ImageDTO] = [] image_dtos: list[ImageDTO] = []
if board_id: if board_id:
# -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images image_dtos = self._board_handler(board_id)
image_dtos = self.__invoker.services.images.get_many( elif image_names:
offset=0, image_dtos = self._image_handler(image_names)
limit=-1,
board_id=board_id,
is_intermediate=False,
).items
else: else:
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
@ -71,6 +71,19 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.") self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e raise e
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
return [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
def _board_handler(self, board_id: str) -> list[ImageDTO]:
# -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images
image_dtos = self.__invoker.services.images.get_many(
offset=0,
limit=-1,
board_id=board_id,
is_intermediate=False,
).items
return image_dtos
def generate_item_id(self, board_id: Optional[str]) -> str: def generate_item_id(self, board_id: Optional[str]) -> str:
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string() return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()