mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
reworking some of the logic to use a default room, adding endpoint to download file on complete
This commit is contained in:
parent
98a01368b8
commit
aa132fb9e3
@ -15,7 +15,7 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
|
|||||||
from ..services.board_images.board_images_default import BoardImagesService
|
from ..services.board_images.board_images_default import BoardImagesService
|
||||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||||
from ..services.boards.boards_default import BoardService
|
from ..services.boards.boards_default import BoardService
|
||||||
from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService
|
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||||
from ..services.config import InvokeAIAppConfig
|
from ..services.config import InvokeAIAppConfig
|
||||||
from ..services.download import DownloadQueueService
|
from ..services.download import DownloadQueueService
|
||||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||||
|
@ -395,3 +395,36 @@ async def download_images_from_list(
|
|||||||
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
|
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
|
||||||
)
|
)
|
||||||
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.api_route(
|
||||||
|
"/download/{bulk_download_item_name}",
|
||||||
|
methods=["GET"],
|
||||||
|
operation_id="get_bulk_download_item",
|
||||||
|
response_class=Response,
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "Return the complete bulk download item",
|
||||||
|
"content": {"application/zip": {}},
|
||||||
|
},
|
||||||
|
404: {"description": "Image not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_bulk_download_item(
|
||||||
|
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
|
||||||
|
) -> FileResponse:
|
||||||
|
"""Gets a bulk download zip file"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||||
|
|
||||||
|
response = FileResponse(
|
||||||
|
path,
|
||||||
|
media_type="application/zip",
|
||||||
|
filename=bulk_download_item_name,
|
||||||
|
content_disposition_type="inline",
|
||||||
|
)
|
||||||
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
|
return response
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
@ -29,3 +29,28 @@ class BulkDownloadBase(ABC):
|
|||||||
:param image_names: A list of image names to include in the zip file.
|
:param image_names: A list of image names to include in the zip file.
|
||||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, bulk_download_item_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the path to the bulk download file.
|
||||||
|
|
||||||
|
:param bulk_download_item_id: The ID of the bulk download item.
|
||||||
|
:return: The path to the bulk download file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def stop(self, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Stops the BulkDownloadService and cleans up all the remnants.
|
||||||
|
|
||||||
|
This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup
|
||||||
|
operations to remove any remnants or resources associated with the service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
DEFAULT_BULK_DOWNLOAD_ID = "default"
|
||||||
|
|
||||||
|
|
||||||
class BulkDownloadException(Exception):
|
class BulkDownloadException(Exception):
|
||||||
"""Exception raised when a bulk download fails."""
|
"""Exception raised when a bulk download fails."""
|
||||||
|
|
||||||
|
@ -4,7 +4,11 @@ from typing import Optional, Union
|
|||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||||
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException
|
from invokeai.app.services.bulk_download.bulk_download_common import (
|
||||||
|
DEFAULT_BULK_DOWNLOAD_ID,
|
||||||
|
BulkDownloadException,
|
||||||
|
BulkDownloadTargetException,
|
||||||
|
)
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
@ -32,6 +36,32 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||||
self.__event_bus = event_bus
|
self.__event_bus = event_bus
|
||||||
|
|
||||||
|
def get_path(self, bulk_download_item_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the path to the bulk download file.
|
||||||
|
|
||||||
|
:param bulk_download_item_name: The name of the bulk download item.
|
||||||
|
:return: The path to the bulk download file.
|
||||||
|
"""
|
||||||
|
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
||||||
|
if not self.validate_path(path):
|
||||||
|
raise BulkDownloadTargetException()
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_bulk_download_item_name(self, bulk_download_item_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the name of the bulk download item.
|
||||||
|
|
||||||
|
:param bulk_download_item_id: The ID of the bulk download item.
|
||||||
|
:return: The name of the bulk download item.
|
||||||
|
"""
|
||||||
|
return bulk_download_item_id + ".zip"
|
||||||
|
|
||||||
|
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||||
|
"""Validates the path given for a bulk download."""
|
||||||
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
|
return path.exists()
|
||||||
|
|
||||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
@ -39,7 +69,9 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
param: image_names: A list of image names to include in the zip file.
|
param: image_names: A list of image names to include in the zip file.
|
||||||
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||||
"""
|
"""
|
||||||
bulk_download_id = str(uuid.uuid4())
|
bulk_download_id = DEFAULT_BULK_DOWNLOAD_ID
|
||||||
|
bulk_download_item_id = str(uuid.uuid4())
|
||||||
|
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if board_id:
|
if board_id:
|
||||||
@ -47,12 +79,12 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
if board_id == "none":
|
if board_id == "none":
|
||||||
board_id = "Uncategorized"
|
board_id = "Uncategorized"
|
||||||
image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names)
|
image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names)
|
||||||
file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id)
|
bulk_download_item_name: str = self._create_zip_file(image_names_to_paths, bulk_download_item_id)
|
||||||
self._signal_job_completed(bulk_download_id, file_path)
|
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||||
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
||||||
self._signal_job_failed(bulk_download_id, e)
|
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._signal_job_failed(bulk_download_id, e)
|
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
||||||
|
|
||||||
def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]:
|
def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
@ -64,44 +96,54 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
||||||
return image_names_to_paths
|
return image_names_to_paths
|
||||||
|
|
||||||
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
|
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_item_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||||
"""
|
|
||||||
|
|
||||||
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
|
:return: The name of the zip file.
|
||||||
|
"""
|
||||||
|
zip_file_name = bulk_download_item_id + ".zip"
|
||||||
|
zip_file_path = self.__bulk_downloads_folder / (zip_file_name)
|
||||||
|
|
||||||
with ZipFile(zip_file_path, "w") as zip_file:
|
with ZipFile(zip_file_path, "w") as zip_file:
|
||||||
for image_name, image_path in image_names_to_paths.items():
|
for image_name, image_path in image_names_to_paths.items():
|
||||||
zip_file.write(image_path, arcname=image_name)
|
zip_file.write(image_path, arcname=image_name)
|
||||||
|
|
||||||
return str(zip_file_path)
|
return str(zip_file_name)
|
||||||
|
|
||||||
def _signal_job_started(self, bulk_download_id: str) -> None:
|
def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
|
||||||
"""Signal that a bulk download job has started."""
|
"""Signal that a bulk download job has started."""
|
||||||
if self.__event_bus:
|
if self.__event_bus:
|
||||||
assert bulk_download_id is not None
|
assert bulk_download_id is not None
|
||||||
self.__event_bus.emit_bulk_download_started(
|
self.__event_bus.emit_bulk_download_started(
|
||||||
bulk_download_id=bulk_download_id,
|
bulk_download_id=bulk_download_id,
|
||||||
|
bulk_download_item_id=bulk_download_item_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
|
def _signal_job_completed(
|
||||||
|
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||||
|
) -> None:
|
||||||
"""Signal that a bulk download job has completed."""
|
"""Signal that a bulk download job has completed."""
|
||||||
if self.__event_bus:
|
if self.__event_bus:
|
||||||
assert bulk_download_id is not None
|
assert bulk_download_id is not None
|
||||||
assert file_path is not None
|
assert bulk_download_item_name is not None
|
||||||
self.__event_bus.emit_bulk_download_completed(
|
self.__event_bus.emit_bulk_download_completed(
|
||||||
bulk_download_id=bulk_download_id,
|
bulk_download_id=bulk_download_id,
|
||||||
file_path=file_path,
|
bulk_download_item_id=bulk_download_item_id,
|
||||||
|
bulk_download_item_name=bulk_download_item_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
|
def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, exception: Exception) -> None:
|
||||||
"""Signal that a bulk download job has failed."""
|
"""Signal that a bulk download job has failed."""
|
||||||
if self.__event_bus:
|
if self.__event_bus:
|
||||||
assert bulk_download_id is not None
|
assert bulk_download_id is not None
|
||||||
assert exception is not None
|
assert exception is not None
|
||||||
self.__event_bus.emit_bulk_download_failed(
|
self.__event_bus.emit_bulk_download_failed(
|
||||||
bulk_download_id=bulk_download_id,
|
bulk_download_id=bulk_download_id,
|
||||||
|
bulk_download_item_id=bulk_download_item_id,
|
||||||
error=str(exception),
|
error=str(exception),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def stop(self, *args, **kwargs):
|
||||||
|
pass
|
@ -440,23 +440,36 @@ class EventServiceBase:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_bulk_download_started(self, bulk_download_id: str) -> None:
|
def emit_bulk_download_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
|
||||||
"""Emitted when a bulk download starts"""
|
"""Emitted when a bulk download starts"""
|
||||||
self._emit_bulk_download_event(
|
self._emit_bulk_download_event(
|
||||||
event_name="bulk_download_started",
|
event_name="bulk_download_started",
|
||||||
payload={
|
payload={
|
||||||
"bulk_download_id": bulk_download_id,
|
"bulk_download_id": bulk_download_id,
|
||||||
|
"bulk_download_item_id": bulk_download_item_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
|
def emit_bulk_download_completed(
|
||||||
|
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||||
|
) -> None:
|
||||||
"""Emitted when a bulk download completes"""
|
"""Emitted when a bulk download completes"""
|
||||||
self._emit_bulk_download_event(
|
self._emit_bulk_download_event(
|
||||||
event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
|
event_name="bulk_download_completed",
|
||||||
|
payload={
|
||||||
|
"bulk_download_id": bulk_download_id,
|
||||||
|
"bulk_download_item_id": bulk_download_item_id,
|
||||||
|
"bulk_download_item_name": bulk_download_item_name,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
|
def emit_bulk_download_failed(self, bulk_download_id: str, bulk_download_item_id: str, error: str) -> None:
|
||||||
"""Emitted when a bulk download fails"""
|
"""Emitted when a bulk download fails"""
|
||||||
self._emit_bulk_download_event(
|
self._emit_bulk_download_event(
|
||||||
event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
|
event_name="bulk_download_failed",
|
||||||
|
payload={
|
||||||
|
"bulk_download_id": bulk_download_id,
|
||||||
|
"bulk_download_item_id": bulk_download_item_id,
|
||||||
|
"error": error,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user