linted and styling

This commit is contained in:
Stefan Tobler 2024-01-07 22:17:03 -05:00 committed by psychedelicious
parent fc9a62dbf5
commit 98a01368b8
6 changed files with 37 additions and 40 deletions

View File

@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
@ -373,13 +372,16 @@ async def unstar_images_in_list(
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
class ImagesDownloaded(BaseModel):
response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading"
)
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202)
@images_router.post(
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list(
background_tasks: BackgroundTasks,
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
@ -389,6 +391,7 @@ async def download_images_from_list(
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id)
background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
)
return ImagesDownloaded(response="Your images are preparing to be downloaded")

View File

@ -18,7 +18,6 @@ class SocketIO:
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")

View File

@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from abc import ABC, abstractmethod
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
class BulkDownloadBase(ABC):
class BulkDownloadBase(ABC):
@abstractmethod
def __init__(
self,
@ -25,8 +24,8 @@ class BulkDownloadBase(ABC):
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
"""
Starts a a bulk download job.
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
"""
"""

View File

@ -1,4 +1,3 @@
class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails."""
@ -6,6 +5,7 @@ class BulkDownloadException(Exception):
super().__init__(message)
self.message = message
class BulkDownloadTargetException(BulkDownloadException):
"""Exception raised when a bulk download target is not found."""
@ -13,9 +13,13 @@ class BulkDownloadTargetException(BulkDownloadException):
super().__init__(message)
self.message = message
class BulkDownloadParametersException(BulkDownloadException):
"""Exception raised when a bulk download parameter is invalid."""
def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"):
def __init__(
self,
message="The bulk download parameters are invalid, either an array of image names or a board id must be provided",
):
super().__init__(message)
self.message = message
self.message = message

View File

@ -1,6 +1,6 @@
import uuid
from pathlib import Path
from typing import Optional, Union
import uuid
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
@ -11,15 +11,17 @@ from invokeai.app.services.invoker import Invoker
from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
class BulkDownloadService(BulkDownloadBase):
__output_folder: Path
__bulk_downloads_folder: Path
__event_bus: Optional[EventServiceBase]
def __init__(self,
output_folder: Union[str, Path],
event_bus: Optional[EventServiceBase] = None,):
def __init__(
self,
output_folder: Union[str, Path],
event_bus: Optional[EventServiceBase] = None,
):
"""
Initialize the downloader object.
@ -30,8 +32,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
self.__event_bus = event_bus
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
@ -39,10 +40,8 @@ class BulkDownloadService(BulkDownloadBase):
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
"""
bulk_download_id = str(uuid.uuid4())
self._signal_job_started(bulk_download_id)
try:
board_name: Union[str, None] = None
if board_id:
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
if board_id == "none":
@ -64,24 +63,21 @@ class BulkDownloadService(BulkDownloadBase):
for image_name in image_names:
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
return image_names_to_paths
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten.
"""
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
with ZipFile(zip_file_path, "w") as zip_file:
for image_name, image_path in image_names_to_paths.items():
zip_file.write(image_path, arcname=image_name)
return str(zip_file_path)
def _signal_job_started(self, bulk_download_id: str) -> None:
"""Signal that a bulk download job has started."""
if self.__event_bus:
@ -90,7 +86,6 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id,
)
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Signal that a bulk download job has completed."""
if self.__event_bus:
@ -100,7 +95,7 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id,
file_path=file_path,
)
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
"""Signal that a bulk download job has failed."""
if self.__event_bus:
@ -110,5 +105,3 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id,
error=str(exception),
)

View File

@ -444,20 +444,19 @@ class EventServiceBase:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={"bulk_download_id": bulk_download_id, }
payload={
"bulk_download_id": bulk_download_id,
},
)
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={"bulk_download_id": bulk_download_id,
"file_path": file_path}
event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
)
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={"bulk_download_id": bulk_download_id, "error": error}
event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
)