mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactoring handlers to do null check
This commit is contained in:
parent
ec129662a6
commit
6468b044d8
@ -1,6 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, cast
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
@ -396,10 +396,9 @@ async def download_images_from_list(
|
|||||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||||
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
||||||
|
|
||||||
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
|
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
ApiDependencies.invoker.services.bulk_download.handler,
|
ApiDependencies.invoker.services.bulk_download.handler,
|
||||||
cast(list[str], image_names),
|
image_names,
|
||||||
board_id,
|
board_id,
|
||||||
bulk_download_item_id,
|
bulk_download_item_id,
|
||||||
)
|
)
|
||||||
|
@ -23,7 +23,9 @@ class BulkDownloadBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
|
def handler(
|
||||||
|
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Starts a a bulk download job.
|
Starts a a bulk download job.
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from invokeai.app.services.board_records.board_records_common import BoardRecord
|
|||||||
from invokeai.app.services.bulk_download.bulk_download_common import (
|
from invokeai.app.services.bulk_download.bulk_download_common import (
|
||||||
DEFAULT_BULK_DOWNLOAD_ID,
|
DEFAULT_BULK_DOWNLOAD_ID,
|
||||||
BulkDownloadException,
|
BulkDownloadException,
|
||||||
|
BulkDownloadParametersException,
|
||||||
BulkDownloadTargetException,
|
BulkDownloadTargetException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
@ -36,7 +37,9 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
|
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
|
||||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
|
def handler(
|
||||||
|
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
|
|
||||||
@ -53,15 +56,12 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
image_dtos: list[ImageDTO] = []
|
image_dtos: list[ImageDTO] = []
|
||||||
|
|
||||||
if board_id:
|
if board_id:
|
||||||
# -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images
|
image_dtos = self._board_handler(board_id)
|
||||||
image_dtos = self.__invoker.services.images.get_many(
|
elif image_names:
|
||||||
offset=0,
|
image_dtos = self._image_handler(image_names)
|
||||||
limit=-1,
|
|
||||||
board_id=board_id,
|
|
||||||
is_intermediate=False,
|
|
||||||
).items
|
|
||||||
else:
|
else:
|
||||||
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
|
raise BulkDownloadParametersException()
|
||||||
|
|
||||||
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
|
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
|
||||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||||
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
||||||
@ -71,6 +71,19 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
|
||||||
|
return [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
|
||||||
|
|
||||||
|
def _board_handler(self, board_id: str) -> list[ImageDTO]:
|
||||||
|
# -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images
|
||||||
|
image_dtos = self.__invoker.services.images.get_many(
|
||||||
|
offset=0,
|
||||||
|
limit=-1,
|
||||||
|
board_id=board_id,
|
||||||
|
is_intermediate=False,
|
||||||
|
).items
|
||||||
|
return image_dtos
|
||||||
|
|
||||||
def generate_item_id(self, board_id: Optional[str]) -> str:
|
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||||
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
|
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user