diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 8a9ea1f3f2..366a5fec5f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -31,11 +31,11 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def get_path(self, bulk_download_item_id: str) -> str: + def get_path(self, bulk_download_item_name: str) -> str: """ Get the path to the bulk download file. - :param bulk_download_item_id: The ID of the bulk download item. + :param bulk_download_item_name: The name of the bulk download item. :return: The path to the bulk download file. """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 561fd173a8..b80b8cc2f5 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -48,15 +48,6 @@ class BulkDownloadService(BulkDownloadBase): raise BulkDownloadTargetException() return path - def get_bulk_download_item_name(self, bulk_download_item_id: str) -> str: - """ - Get the name of the bulk download item. - - :param bulk_download_item_id: The ID of the bulk download item. - :return: The name of the bulk download item. - """ - return bulk_download_item_id + ".zip" - def validate_path(self, path: Union[str, Path]) -> bool: """Validates the path given for a bulk download.""" path = path if isinstance(path, Path) else Path(path) @@ -69,17 +60,27 @@ class BulkDownloadService(BulkDownloadBase): param: image_names: A list of image names to include in the zip file. param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. """ - bulk_download_id = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id = str(uuid.uuid4()) - self._signal_job_started(bulk_download_id, bulk_download_item_id) + + bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID + bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id try: + board_name: str = "" if board_id: image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) if board_id == "none": board_id = "Uncategorized" + board_name = "Uncategorized" + else: + board_name = invoker.services.board_records.get(board_id).board_name + board_name = self._clean_string_to_path_safe(board_name) + + self._signal_job_started(bulk_download_id, bulk_download_item_id) + image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) - bulk_download_item_name: str = self._create_zip_file(image_names_to_paths, bulk_download_item_id) + bulk_download_item_name: str = self._create_zip_file( + image_names_to_paths, bulk_download_item_id if board_id is None else board_name + ) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) @@ -112,6 +113,10 @@ class BulkDownloadService(BulkDownloadBase): return str(zip_file_name) + def _clean_string_to_path_safe(self, s: str) -> str: + """Clean a string to be path safe.""" + return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip() + def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: @@ -146,4 +151,10 @@ class BulkDownloadService(BulkDownloadBase): ) def stop(self, *args, **kwargs): - pass + """Stop the bulk download service and delete the files in the bulk download folder.""" + # Get all the files in the bulk downloads folder + files = self.__bulk_downloads_folder.glob("*") + + # Delete all the files + for file in files: + file.unlink()