moving the responsibility of cleaning up board names to the service not the route

This commit is contained in:
Stefan Tobler
2024-02-16 15:50:48 -05:00
committed by psychedelicious
parent 124075ae7a
commit 8033589629
5 changed files with 63 additions and 70 deletions

View File

@ -13,7 +13,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from invokeai.app.util.misc import uuid_string
from ..dependencies import ApiDependencies
@ -395,10 +394,7 @@ async def download_images_from_list(
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
board_name: str = (
"" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id)
)
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
background_tasks.add_task(
@ -409,7 +405,7 @@ async def download_images_from_list(
)
return ImagesDownloaded(
response="Your images are preparing to be downloaded",
bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip",
bulk_download_item_name=bulk_download_item_id + ".zip",
)

View File

@ -30,9 +30,9 @@ class BulkDownloadBase(ABC):
"""
Starts a a bulk download job.
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
"""
@abstractmethod
@ -45,12 +45,12 @@ class BulkDownloadBase(ABC):
"""
@abstractmethod
def get_clean_board_name(self, board_id: str) -> str:
def generate_item_id(self, board_id: Optional[str]) -> str:
"""
Get the name of the board.
Generate an item ID for a bulk download item.
:param board_id: The ID of the board.
:return: The name of the board.
:param board_id: The ID of the board whose name is to be included in the item id.
:return: The generated item ID.
"""
@abstractmethod
@ -61,12 +61,8 @@ class BulkDownloadBase(ABC):
This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup
operations to remove any remnants or resources associated with the service.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
None
:param *args: Variable length argument list.
:param **kwargs: Arbitrary keyword arguments.
"""
@abstractmethod

View File

@ -50,19 +50,15 @@ class BulkDownloadService(BulkDownloadBase):
"""
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
if bulk_download_item_id is None:
bulk_download_item_id = uuid_string() if board_id is None else board_id
bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id
self._signal_job_started(bulk_download_id, bulk_download_item_id)
try:
board_name: str = ""
image_dtos: list[ImageDTO] = []
if board_id:
board_name = self.get_clean_board_name(board_id)
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
# -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images
image_dtos = self.__invoker.services.images.get_many(
offset=0,
limit=-1,
@ -71,9 +67,7 @@ class BulkDownloadService(BulkDownloadBase):
).items
else:
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
bulk_download_item_name: str = self._create_zip_file(
image_dtos, bulk_download_item_id if board_id is None else board_name
)
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
@ -82,7 +76,10 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e
def get_clean_board_name(self, board_id: str) -> str:
def generate_item_id(self, board_id: Optional[str]) -> str:
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
def _get_clean_board_name(self, board_id: str) -> str:
if board_id == "none":
return "Uncategorized"
@ -109,7 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
def _clean_string_to_path_safe(self, s: str) -> str:
"""Clean a string to be path safe."""
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip()
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
"""Signal that a bulk download job has started."""
@ -166,11 +163,11 @@ class BulkDownloadService(BulkDownloadBase):
:return: The path to the bulk download file.
"""
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
if not self.validate_path(path):
if not self._is_valid_path(path):
raise BulkDownloadTargetException()
return path
def validate_path(self, path: Union[str, Path]) -> bool:
def _is_valid_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()