diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index aaa08a2498..984fd8e267 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -82,7 +82,7 @@ class ApiDependencies: board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) - bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events) + bulk_download = BulkDownloadService(output_folder=f"{output_folder}") image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 880345fe98..7a4aa0661c 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Union -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker @@ -19,16 +18,11 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def __init__( - self, - output_folder: Union[str, Path], - event_bus: Optional["EventServiceBase"] = None, - ): + def __init__(self, output_folder: Union[str, Path]): """ Create BulkDownloadBase object. :param output_folder: The path to the output folder where the bulk download files can be temporarily stored. - :param event_bus: InvokeAI event bus for reporting events to. """ @abstractmethod diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index ffc26dfa54..a9ea12bfd6 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -1,4 +1,3 @@ -import uuid from pathlib import Path from typing import Optional, Union from zipfile import ZipFile @@ -13,6 +12,7 @@ from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invoker import Invoker +from invokeai.app.util.misc import uuid_string from .bulk_download_base import BulkDownloadBase @@ -20,26 +20,23 @@ from .bulk_download_base import BulkDownloadBase class BulkDownloadService(BulkDownloadBase): __output_folder: Path __bulk_downloads_folder: Path - __event_bus: Optional[EventServiceBase] + __event_bus: EventServiceBase __invoker: Invoker def start(self, invoker: Invoker) -> None: self.__invoker = invoker + self.__event_bus = invoker.services.events def __init__( self, output_folder: Union[str, Path], - event_bus: Optional[EventServiceBase] = None, ): """ Initialize the downloader object. - - :param event_bus: Optional EventService object """ self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - self.__event_bus = event_bus def get_path(self, bulk_download_item_name: str) -> str: """ @@ -67,7 +64,7 @@ class BulkDownloadService(BulkDownloadBase): """ bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id + bulk_download_item_id: str = uuid_string() if board_id is None else board_id self._signal_job_started(bulk_download_id, bulk_download_item_id) @@ -76,10 +73,7 @@ class BulkDownloadService(BulkDownloadBase): image_dtos: list[ImageDTO] = [] if board_id: - if board_id == "none": - board_name = "Uncategorized" - else: - board_name = self.__invoker.services.board_records.get(board_id).board_name + board_name = self._get_board_name(board_id) board_name = self._clean_string_to_path_safe(board_name) # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images @@ -102,6 +96,12 @@ class BulkDownloadService(BulkDownloadBase): self.__invoker.services.logger.error("Problem bulk downloading images.") raise e + def _get_board_name(self, board_id: str) -> str: + if board_id == "none": + return "Uncategorized" + + return self.__invoker.services.board_records.get(board_id).board_name + def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: """ Create a zip file containing the images specified by the given image names or board id. @@ -115,8 +115,8 @@ class BulkDownloadService(BulkDownloadBase): with ZipFile(zip_file_path, "w") as zip_file: for image_dto in image_dtos: image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name - image_path = self.__invoker.services.images.get_path(image_dto.image_name) - zip_file.write(image_path, arcname=image_zip_path) + image_disk_path = self.__invoker.services.images.get_path(image_dto.image_name) + zip_file.write(image_disk_path, arcname=image_zip_path) return str(zip_file_name)