mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactoring bulk_download to be better managed
This commit is contained in:
parent
db812133e7
commit
7d91426d8f
@ -82,7 +82,7 @@ class ApiDependencies:
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events)
|
||||
bulk_download = BulkDownloadService(output_folder=f"{output_folder}")
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
|
@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
@ -19,16 +18,11 @@ class BulkDownloadBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
output_folder: Union[str, Path],
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
"""
|
||||
Create BulkDownloadBase object.
|
||||
|
||||
:param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
|
||||
:param event_bus: InvokeAI event bus for reporting events to.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -1,4 +1,3 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from zipfile import ZipFile
|
||||
@ -13,6 +12,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from .bulk_download_base import BulkDownloadBase
|
||||
|
||||
@ -20,26 +20,23 @@ from .bulk_download_base import BulkDownloadBase
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
__output_folder: Path
|
||||
__bulk_downloads_folder: Path
|
||||
__event_bus: Optional[EventServiceBase]
|
||||
__event_bus: EventServiceBase
|
||||
__invoker: Invoker
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self.__event_bus = invoker.services.events
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_folder: Union[str, Path],
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the downloader object.
|
||||
|
||||
:param event_bus: Optional EventService object
|
||||
"""
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
self.__event_bus = event_bus
|
||||
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
"""
|
||||
@ -67,7 +64,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
"""
|
||||
|
||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||
bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id
|
||||
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
|
||||
|
||||
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
||||
|
||||
@ -76,10 +73,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
image_dtos: list[ImageDTO] = []
|
||||
|
||||
if board_id:
|
||||
if board_id == "none":
|
||||
board_name = "Uncategorized"
|
||||
else:
|
||||
board_name = self.__invoker.services.board_records.get(board_id).board_name
|
||||
board_name = self._get_board_name(board_id)
|
||||
board_name = self._clean_string_to_path_safe(board_name)
|
||||
|
||||
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
||||
@ -102,6 +96,12 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
||||
raise e
|
||||
|
||||
def _get_board_name(self, board_id: str) -> str:
|
||||
if board_id == "none":
|
||||
return "Uncategorized"
|
||||
|
||||
return self.__invoker.services.board_records.get(board_id).board_name
|
||||
|
||||
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
@ -115,8 +115,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
with ZipFile(zip_file_path, "w") as zip_file:
|
||||
for image_dto in image_dtos:
|
||||
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
|
||||
image_path = self.__invoker.services.images.get_path(image_dto.image_name)
|
||||
zip_file.write(image_path, arcname=image_zip_path)
|
||||
image_disk_path = self.__invoker.services.images.get_path(image_dto.image_name)
|
||||
zip_file.write(image_disk_path, arcname=image_zip_path)
|
||||
|
||||
return str(zip_file_name)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user