implementation of bulkdownload background task

This commit is contained in:
Stefan Tobler 2024-01-07 21:29:42 -05:00 committed by psychedelicious
parent 4d8bec1605
commit fc9a62dbf5
5 changed files with 130 additions and 6 deletions

View File

@ -15,6 +15,7 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
@ -81,6 +82,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events)
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
@ -110,6 +112,7 @@ class ApiDependencies:
board_images=board_images,
board_records=board_records,
boards=boards,
bulk_download=bulk_download,
configuration=configuration,
events=events,
image_files=image_files,

View File

@ -2,7 +2,7 @@ import io
import traceback
from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
@ -372,19 +373,22 @@ async def unstar_images_in_list(
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
class ImagesDownloaded(BaseModel):
response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading"
)
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202)
async def download_images_from_list(
background_tasks: BackgroundTasks,
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
board_id: Optional[str] = Body(
default=None, description="The board from which image should be downloaded from", embed=True
),
) -> ImagesDownloaded:
# return ImagesDownloaded(response="Your images are downloading")
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id)
return ImagesDownloaded(response="Your images are preparing to be downloaded")

View File

@ -22,7 +22,7 @@ class BulkDownloadBase(ABC):
"""
@abstractmethod
def start(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> str:
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
"""
Starts a a bulk download job.

View File

@ -0,0 +1,114 @@
from pathlib import Path
from typing import Optional, Union
import uuid
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
from invokeai.app.services.invoker import Invoker
from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
__output_folder: Path
__bulk_downloads_folder: Path
__event_bus: Optional[EventServiceBase]
def __init__(self,
output_folder: Union[str, Path],
event_bus: Optional[EventServiceBase] = None,):
"""
Initialize the downloader object.
:param event_bus: Optional EventService object
"""
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
self.__event_bus = event_bus
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
param: image_names: A list of image names to include in the zip file.
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
"""
bulk_download_id = str(uuid.uuid4())
self._signal_job_started(bulk_download_id)
try:
board_name: Union[str, None] = None
if board_id:
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
if board_id == "none":
board_id = "Uncategorized"
image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names)
file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id)
self._signal_job_completed(bulk_download_id, file_path)
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
self._signal_job_failed(bulk_download_id, e)
except Exception as e:
self._signal_job_failed(bulk_download_id, e)
def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]:
"""
Create a map of image names to their paths.
:param image_names: A list of image names.
"""
image_names_to_paths: dict[str, str] = {}
for image_name in image_names:
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
return image_names_to_paths
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten.
"""
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
with ZipFile(zip_file_path, "w") as zip_file:
for image_name, image_path in image_names_to_paths.items():
zip_file.write(image_path, arcname=image_name)
return str(zip_file_path)
def _signal_job_started(self, bulk_download_id: str) -> None:
"""Signal that a bulk download job has started."""
if self.__event_bus:
assert bulk_download_id is not None
self.__event_bus.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
)
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Signal that a bulk download job has completed."""
if self.__event_bus:
assert bulk_download_id is not None
assert file_path is not None
self.__event_bus.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
file_path=file_path,
)
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
"""Signal that a bulk download job has failed."""
if self.__event_bus:
assert bulk_download_id is not None
assert exception is not None
self.__event_bus.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
error=str(exception),
)

View File

@ -16,6 +16,7 @@ if TYPE_CHECKING:
from .board_images.board_images_base import BoardImagesServiceABC
from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC
from .bulk_download.bulk_download_base import BulkDownloadBase
from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase
@ -41,6 +42,7 @@ class InvocationServices:
board_image_records: "BoardImageRecordStorageBase",
boards: "BoardServiceABC",
board_records: "BoardRecordStorageBase",
bulk_download: "BulkDownloadBase",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
images: "ImageServiceABC",
@ -63,6 +65,7 @@ class InvocationServices:
self.board_image_records = board_image_records
self.boards = boards
self.board_records = board_records
self.bulk_download = bulk_download
self.configuration = configuration
self.events = events
self.images = images