mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implementation of bulkdownload background task
This commit is contained in:
parent
f1967c3393
commit
56d2d220a8
@ -15,6 +15,7 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
|
||||
from ..services.board_images.board_images_default import BoardImagesService
|
||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
@ -81,6 +82,7 @@ class ApiDependencies:
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events)
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
@ -110,6 +112,7 @@ class ApiDependencies:
|
||||
board_images=board_images,
|
||||
board_records=board_records,
|
||||
boards=boards,
|
||||
bulk_download=bulk_download,
|
||||
configuration=configuration,
|
||||
events=events,
|
||||
image_files=image_files,
|
||||
|
@ -2,7 +2,7 @@ import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||
@ -372,19 +373,22 @@ async def unstar_images_in_list(
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
|
||||
class ImagesDownloaded(BaseModel):
|
||||
response: Optional[str] = Field(
|
||||
description="If defined, the message to display to the user when images begin downloading"
|
||||
)
|
||||
|
||||
|
||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
|
||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202)
|
||||
async def download_images_from_list(
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||
board_id: Optional[str] = Body(
|
||||
default=None, description="The board from which image should be downloaded from", embed=True
|
||||
),
|
||||
) -> ImagesDownloaded:
|
||||
# return ImagesDownloaded(response="Your images are downloading")
|
||||
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id)
|
||||
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
||||
|
||||
|
@ -22,7 +22,7 @@ class BulkDownloadBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> str:
|
||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
"""
|
||||
Starts a a bulk download job.
|
||||
|
||||
|
114
invokeai/app/services/bulk_download/bulk_download_defauilt.py
Normal file
114
invokeai/app/services/bulk_download/bulk_download_defauilt.py
Normal file
@ -0,0 +1,114 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from zipfile import ZipFile
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .bulk_download_base import BulkDownloadBase
|
||||
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
|
||||
__output_folder: Path
|
||||
__bulk_downloads_folder: Path
|
||||
__event_bus: Optional[EventServiceBase]
|
||||
|
||||
def __init__(self,
|
||||
output_folder: Union[str, Path],
|
||||
event_bus: Optional[EventServiceBase] = None,):
|
||||
"""
|
||||
Initialize the downloader object.
|
||||
|
||||
:param event_bus: Optional EventService object
|
||||
"""
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
self.__event_bus = event_bus
|
||||
|
||||
|
||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
|
||||
param: image_names: A list of image names to include in the zip file.
|
||||
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||
"""
|
||||
bulk_download_id = str(uuid.uuid4())
|
||||
|
||||
self._signal_job_started(bulk_download_id)
|
||||
try:
|
||||
board_name: Union[str, None] = None
|
||||
if board_id:
|
||||
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
if board_id == "none":
|
||||
board_id = "Uncategorized"
|
||||
image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names)
|
||||
file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id)
|
||||
self._signal_job_completed(bulk_download_id, file_path)
|
||||
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
||||
self._signal_job_failed(bulk_download_id, e)
|
||||
except Exception as e:
|
||||
self._signal_job_failed(bulk_download_id, e)
|
||||
|
||||
def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]:
|
||||
"""
|
||||
Create a map of image names to their paths.
|
||||
:param image_names: A list of image names.
|
||||
"""
|
||||
image_names_to_paths: dict[str, str] = {}
|
||||
for image_name in image_names:
|
||||
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
||||
return image_names_to_paths
|
||||
|
||||
|
||||
|
||||
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||
"""
|
||||
|
||||
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
|
||||
|
||||
with ZipFile(zip_file_path, "w") as zip_file:
|
||||
for image_name, image_path in image_names_to_paths.items():
|
||||
zip_file.write(image_path, arcname=image_name)
|
||||
|
||||
return str(zip_file_path)
|
||||
|
||||
|
||||
def _signal_job_started(self, bulk_download_id: str) -> None:
|
||||
"""Signal that a bulk download job has started."""
|
||||
if self.__event_bus:
|
||||
assert bulk_download_id is not None
|
||||
self.__event_bus.emit_bulk_download_started(
|
||||
bulk_download_id=bulk_download_id,
|
||||
)
|
||||
|
||||
|
||||
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
|
||||
"""Signal that a bulk download job has completed."""
|
||||
if self.__event_bus:
|
||||
assert bulk_download_id is not None
|
||||
assert file_path is not None
|
||||
self.__event_bus.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
|
||||
"""Signal that a bulk download job has failed."""
|
||||
if self.__event_bus:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self.__event_bus.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
error=str(exception),
|
||||
)
|
||||
|
||||
|
@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from .board_images.board_images_base import BoardImagesServiceABC
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
from .boards.boards_base import BoardServiceABC
|
||||
from .bulk_download.bulk_download_base import BulkDownloadBase
|
||||
from .config import InvokeAIAppConfig
|
||||
from .download import DownloadQueueServiceBase
|
||||
from .events.events_base import EventServiceBase
|
||||
@ -41,6 +42,7 @@ class InvocationServices:
|
||||
board_image_records: "BoardImageRecordStorageBase",
|
||||
boards: "BoardServiceABC",
|
||||
board_records: "BoardRecordStorageBase",
|
||||
bulk_download: "BulkDownloadBase",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
images: "ImageServiceABC",
|
||||
@ -63,6 +65,7 @@ class InvocationServices:
|
||||
self.board_image_records = board_image_records
|
||||
self.boards = boards
|
||||
self.board_records = board_records
|
||||
self.bulk_download = bulk_download
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.images = images
|
||||
|
Loading…
Reference in New Issue
Block a user