refactoring bulk_download to be better managed

This commit is contained in:
Stefan Tobler 2024-01-15 12:59:45 -05:00 committed by psychedelicious
parent db812133e7
commit 7d91426d8f
3 changed files with 15 additions and 21 deletions

View File

@ -82,7 +82,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events)
bulk_download = BulkDownloadService(output_folder=f"{output_folder}")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
@ -19,16 +18,11 @@ class BulkDownloadBase(ABC):
"""
@abstractmethod
def __init__(
self,
output_folder: Union[str, Path],
event_bus: Optional["EventServiceBase"] = None,
):
def __init__(self, output_folder: Union[str, Path]):
"""
Create BulkDownloadBase object.
:param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
:param event_bus: InvokeAI event bus for reporting events to.
"""
@abstractmethod

View File

@ -1,4 +1,3 @@
import uuid
from pathlib import Path
from typing import Optional, Union
from zipfile import ZipFile
@ -13,6 +12,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from .bulk_download_base import BulkDownloadBase
@ -20,26 +20,23 @@ from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
__output_folder: Path
__bulk_downloads_folder: Path
__event_bus: Optional[EventServiceBase]
__event_bus: EventServiceBase
__invoker: Invoker
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self.__event_bus = invoker.services.events
def __init__(
self,
output_folder: Union[str, Path],
event_bus: Optional[EventServiceBase] = None,
):
"""
Initialize the downloader object.
:param event_bus: Optional EventService object
"""
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
self.__event_bus = event_bus
def get_path(self, bulk_download_item_name: str) -> str:
"""
@ -67,7 +64,7 @@ class BulkDownloadService(BulkDownloadBase):
"""
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
self._signal_job_started(bulk_download_id, bulk_download_item_id)
@ -76,10 +73,7 @@ class BulkDownloadService(BulkDownloadBase):
image_dtos: list[ImageDTO] = []
if board_id:
if board_id == "none":
board_name = "Uncategorized"
else:
board_name = self.__invoker.services.board_records.get(board_id).board_name
board_name = self._get_board_name(board_id)
board_name = self._clean_string_to_path_safe(board_name)
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
@ -102,6 +96,12 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e
def _get_board_name(self, board_id: str) -> str:
if board_id == "none":
return "Uncategorized"
return self.__invoker.services.board_records.get(board_id).board_name
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
@ -115,8 +115,8 @@ class BulkDownloadService(BulkDownloadBase):
with ZipFile(zip_file_path, "w") as zip_file:
for image_dto in image_dtos:
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
image_path = self.__invoker.services.images.get_path(image_dto.image_name)
zip_file.write(image_path, arcname=image_zip_path)
image_disk_path = self.__invoker.services.images.get_path(image_dto.image_name)
zip_file.write(image_disk_path, arcname=image_zip_path)
return str(zip_file_name)