reworking some of the logic to use a default room, adding endpoint to download file on complete

This commit is contained in:
Stefan Tobler 2024-01-13 23:35:33 -05:00 committed by psychedelicious
parent 98a01368b8
commit aa132fb9e3
6 changed files with 137 additions and 21 deletions

View File

@ -15,7 +15,7 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage

View File

@ -395,3 +395,36 @@ async def download_images_from_list(
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
)
return ImagesDownloaded(response="Your images are preparing to be downloaded")
@images_router.api_route(
"/download/{bulk_download_item_name}",
methods=["GET"],
operation_id="get_bulk_download_item",
response_class=Response,
responses={
200: {
"description": "Return the complete bulk download item",
"content": {"application/zip": {}},
},
404: {"description": "Image not found"},
},
)
async def get_bulk_download_item(
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
try:
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
path,
media_type="application/zip",
filename=bulk_download_item_name,
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)

View File

@ -29,3 +29,28 @@ class BulkDownloadBase(ABC):
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
"""
@abstractmethod
def get_path(self, bulk_download_item_id: str) -> str:
"""
Get the path to the bulk download file.
:param bulk_download_item_id: The ID of the bulk download item.
:return: The path to the bulk download file.
"""
@abstractmethod
def stop(self, *args, **kwargs) -> None:
"""
Stops the BulkDownloadService and cleans up all the remnants.
This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup
operations to remove any remnants or resources associated with the service.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
None
"""

View File

@ -1,3 +1,6 @@
DEFAULT_BULK_DOWNLOAD_ID = "default"
class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails."""

View File

@ -4,7 +4,11 @@ from typing import Optional, Union
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException
from invokeai.app.services.bulk_download.bulk_download_common import (
DEFAULT_BULK_DOWNLOAD_ID,
BulkDownloadException,
BulkDownloadTargetException,
)
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
from invokeai.app.services.invoker import Invoker
@ -32,6 +36,32 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
self.__event_bus = event_bus
def get_path(self, bulk_download_item_name: str) -> str:
"""
Get the path to the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
:return: The path to the bulk download file.
"""
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
if not self.validate_path(path):
raise BulkDownloadTargetException()
return path
def get_bulk_download_item_name(self, bulk_download_item_id: str) -> str:
"""
Get the name of the bulk download item.
:param bulk_download_item_id: The ID of the bulk download item.
:return: The name of the bulk download item.
"""
return bulk_download_item_id + ".zip"
def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
@ -39,7 +69,9 @@ class BulkDownloadService(BulkDownloadBase):
param: image_names: A list of image names to include in the zip file.
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
"""
bulk_download_id = str(uuid.uuid4())
bulk_download_id = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = str(uuid.uuid4())
self._signal_job_started(bulk_download_id, bulk_download_item_id)
try:
if board_id:
@ -47,12 +79,12 @@ class BulkDownloadService(BulkDownloadBase):
if board_id == "none":
board_id = "Uncategorized"
image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names)
file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id)
self._signal_job_completed(bulk_download_id, file_path)
bulk_download_item_name: str = self._create_zip_file(image_names_to_paths, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
self._signal_job_failed(bulk_download_id, e)
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
except Exception as e:
self._signal_job_failed(bulk_download_id, e)
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]:
"""
@ -64,44 +96,54 @@ class BulkDownloadService(BulkDownloadBase):
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
return image_names_to_paths
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_item_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten.
"""
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
:return: The name of the zip file.
"""
zip_file_name = bulk_download_item_id + ".zip"
zip_file_path = self.__bulk_downloads_folder / (zip_file_name)
with ZipFile(zip_file_path, "w") as zip_file:
for image_name, image_path in image_names_to_paths.items():
zip_file.write(image_path, arcname=image_name)
return str(zip_file_path)
return str(zip_file_name)
def _signal_job_started(self, bulk_download_id: str) -> None:
def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
"""Signal that a bulk download job has started."""
if self.__event_bus:
assert bulk_download_id is not None
self.__event_bus.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
)
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
def _signal_job_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has completed."""
if self.__event_bus:
assert bulk_download_id is not None
assert file_path is not None
assert bulk_download_item_name is not None
self.__event_bus.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
file_path=file_path,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, exception: Exception) -> None:
"""Signal that a bulk download job has failed."""
if self.__event_bus:
assert bulk_download_id is not None
assert exception is not None
self.__event_bus.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
error=str(exception),
)
def stop(self, *args, **kwargs):
pass

View File

@ -440,23 +440,36 @@ class EventServiceBase:
},
)
def emit_bulk_download_started(self, bulk_download_id: str) -> None:
def emit_bulk_download_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
},
)
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
def emit_bulk_download_failed(self, bulk_download_id: str, bulk_download_item_id: str, error: str) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"error": error,
},
)