From 98a01368b81890d62680ceb720b715edbd0038f9 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 7 Jan 2024 22:17:03 -0500 Subject: [PATCH] linted and styling --- invokeai/app/api/routers/images.py | 11 ++++--- invokeai/app/api/sockets.py | 1 - .../bulk_download/bulk_download_base.py | 9 +++--- .../bulk_download/bulk_download_common.py | 10 ++++-- .../bulk_download/bulk_download_defauilt.py | 31 +++++++------------ invokeai/app/services/events/events_base.py | 15 +++++---- 6 files changed, 37 insertions(+), 40 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 2a8e1e7ec7..e32f7fb9ee 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, ValidationError from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin -from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator @@ -373,13 +372,16 @@ async def unstar_images_in_list( except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") + class ImagesDownloaded(BaseModel): response: Optional[str] = Field( description="If defined, the message to display to the user when images begin downloading" ) -@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202) +@images_router.post( + "/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202 +) async def download_images_from_list( background_tasks: BackgroundTasks, image_names: list[str] = Body(description="The list of names of images to download", embed=True), @@ -389,6 +391,7 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id) + background_tasks.add_task( + ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id + ) return ImagesDownloaded(response="Your images are preparing to be downloaded") - diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index c5d9ace8d2..463545d9bc 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -18,7 +18,6 @@ class SocketIO: __sub_bulk_download: str = "subscribe_bulk_download" __unsub_bulk_download: str = "unsubscribe_bulk_download" - def __init__(self, app: FastAPI): self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index b788020bba..fc45aff280 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -1,13 +1,12 @@ +from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Union -from abc import ABC, abstractmethod - from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker -class BulkDownloadBase(ABC): +class BulkDownloadBase(ABC): @abstractmethod def __init__( self, @@ -25,8 +24,8 @@ class BulkDownloadBase(ABC): def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Starts a a bulk download job. - + :param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. - """ \ No newline at end of file + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_common.py b/invokeai/app/services/bulk_download/bulk_download_common.py index 3ac1f5bba8..23a0589daf 100644 --- a/invokeai/app/services/bulk_download/bulk_download_common.py +++ b/invokeai/app/services/bulk_download/bulk_download_common.py @@ -1,4 +1,3 @@ - class BulkDownloadException(Exception): """Exception raised when a bulk download fails.""" @@ -6,6 +5,7 @@ class BulkDownloadException(Exception): super().__init__(message) self.message = message + class BulkDownloadTargetException(BulkDownloadException): """Exception raised when a bulk download target is not found.""" @@ -13,9 +13,13 @@ class BulkDownloadTargetException(BulkDownloadException): super().__init__(message) self.message = message + class BulkDownloadParametersException(BulkDownloadException): """Exception raised when a bulk download parameter is invalid.""" - def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"): + def __init__( + self, + message="The bulk download parameters are invalid, either an array of image names or a board id must be provided", + ): super().__init__(message) - self.message = message \ No newline at end of file + self.message = message diff --git a/invokeai/app/services/bulk_download/bulk_download_defauilt.py b/invokeai/app/services/bulk_download/bulk_download_defauilt.py index ebeaa4be5a..8321f5069d 100644 --- a/invokeai/app/services/bulk_download/bulk_download_defauilt.py +++ b/invokeai/app/services/bulk_download/bulk_download_defauilt.py @@ -1,6 +1,6 @@ +import uuid from pathlib import Path from typing import Optional, Union -import uuid from zipfile import ZipFile from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException @@ -11,15 +11,17 @@ from invokeai.app.services.invoker import Invoker from .bulk_download_base import BulkDownloadBase -class BulkDownloadService(BulkDownloadBase): +class BulkDownloadService(BulkDownloadBase): __output_folder: Path __bulk_downloads_folder: Path __event_bus: Optional[EventServiceBase] - def __init__(self, - output_folder: Union[str, Path], - event_bus: Optional[EventServiceBase] = None,): + def __init__( + self, + output_folder: Union[str, Path], + event_bus: Optional[EventServiceBase] = None, + ): """ Initialize the downloader object. @@ -30,8 +32,7 @@ class BulkDownloadService(BulkDownloadBase): self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__event_bus = event_bus - - def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -39,10 +40,8 @@ class BulkDownloadService(BulkDownloadBase): param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. """ bulk_download_id = str(uuid.uuid4()) - - self._signal_job_started(bulk_download_id) + try: - board_name: Union[str, None] = None if board_id: image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) if board_id == "none": @@ -64,24 +63,21 @@ class BulkDownloadService(BulkDownloadBase): for image_name in image_names: image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) return image_names_to_paths - - def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: """ - Create a zip file containing the images specified by the given image names or board id. + Create a zip file containing the images specified by the given image names or board id. If download with the same bulk_download_id already exists, it will be overwritten. """ zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") - + with ZipFile(zip_file_path, "w") as zip_file: for image_name, image_path in image_names_to_paths.items(): zip_file.write(image_path, arcname=image_name) return str(zip_file_path) - def _signal_job_started(self, bulk_download_id: str) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: @@ -90,7 +86,6 @@ class BulkDownloadService(BulkDownloadBase): bulk_download_id=bulk_download_id, ) - def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: """Signal that a bulk download job has completed.""" if self.__event_bus: @@ -100,7 +95,7 @@ class BulkDownloadService(BulkDownloadBase): bulk_download_id=bulk_download_id, file_path=file_path, ) - + def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: """Signal that a bulk download job has failed.""" if self.__event_bus: @@ -110,5 +105,3 @@ class BulkDownloadService(BulkDownloadBase): bulk_download_id=bulk_download_id, error=str(exception), ) - - \ No newline at end of file diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 0a0668b274..597a56d944 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -444,20 +444,19 @@ class EventServiceBase: """Emitted when a bulk download starts""" self._emit_bulk_download_event( event_name="bulk_download_started", - payload={"bulk_download_id": bulk_download_id, } + payload={ + "bulk_download_id": bulk_download_id, + }, ) - + def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: """Emitted when a bulk download completes""" self._emit_bulk_download_event( - event_name="bulk_download_completed", - payload={"bulk_download_id": bulk_download_id, - "file_path": file_path} + event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path} ) - + def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: """Emitted when a bulk download fails""" self._emit_bulk_download_event( - event_name="bulk_download_failed", - payload={"bulk_download_id": bulk_download_id, "error": error} + event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error} )