linted and styling

This commit is contained in:
Stefan Tobler 2024-01-07 22:17:03 -05:00 committed by psychedelicious
parent fc9a62dbf5
commit 98a01368b8
6 changed files with 37 additions and 40 deletions

View File

@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
@ -373,13 +372,16 @@ async def unstar_images_in_list(
except Exception: except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images") raise HTTPException(status_code=500, detail="Failed to unstar images")
class ImagesDownloaded(BaseModel): class ImagesDownloaded(BaseModel):
response: Optional[str] = Field( response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading" description="If defined, the message to display to the user when images begin downloading"
) )
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202) @images_router.post(
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list( async def download_images_from_list(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
image_names: list[str] = Body(description="The list of names of images to download", embed=True), image_names: list[str] = Body(description="The list of names of images to download", embed=True),
@ -389,6 +391,7 @@ async def download_images_from_list(
) -> ImagesDownloaded: ) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None: if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=400, detail="No images or board id specified.")
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id) background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
)
return ImagesDownloaded(response="Your images are preparing to be downloaded") return ImagesDownloaded(response="Your images are preparing to be downloaded")

View File

@ -18,7 +18,6 @@ class SocketIO:
__sub_bulk_download: str = "subscribe_bulk_download" __sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download" __unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")

View File

@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from abc import ABC, abstractmethod
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
class BulkDownloadBase(ABC):
class BulkDownloadBase(ABC):
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
@ -25,8 +24,8 @@ class BulkDownloadBase(ABC):
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
""" """
Starts a a bulk download job. Starts a a bulk download job.
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. :param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
:param image_names: A list of image names to include in the zip file. :param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
""" """

View File

@ -1,4 +1,3 @@
class BulkDownloadException(Exception): class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails.""" """Exception raised when a bulk download fails."""
@ -6,6 +5,7 @@ class BulkDownloadException(Exception):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
class BulkDownloadTargetException(BulkDownloadException): class BulkDownloadTargetException(BulkDownloadException):
"""Exception raised when a bulk download target is not found.""" """Exception raised when a bulk download target is not found."""
@ -13,9 +13,13 @@ class BulkDownloadTargetException(BulkDownloadException):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
class BulkDownloadParametersException(BulkDownloadException): class BulkDownloadParametersException(BulkDownloadException):
"""Exception raised when a bulk download parameter is invalid.""" """Exception raised when a bulk download parameter is invalid."""
def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"): def __init__(
self,
message="The bulk download parameters are invalid, either an array of image names or a board id must be provided",
):
super().__init__(message) super().__init__(message)
self.message = message self.message = message

View File

@ -1,6 +1,6 @@
import uuid
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import uuid
from zipfile import ZipFile from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
@ -11,15 +11,17 @@ from invokeai.app.services.invoker import Invoker
from .bulk_download_base import BulkDownloadBase from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
class BulkDownloadService(BulkDownloadBase):
__output_folder: Path __output_folder: Path
__bulk_downloads_folder: Path __bulk_downloads_folder: Path
__event_bus: Optional[EventServiceBase] __event_bus: Optional[EventServiceBase]
def __init__(self, def __init__(
output_folder: Union[str, Path], self,
event_bus: Optional[EventServiceBase] = None,): output_folder: Union[str, Path],
event_bus: Optional[EventServiceBase] = None,
):
""" """
Initialize the downloader object. Initialize the downloader object.
@ -30,8 +32,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
self.__event_bus = event_bus self.__event_bus = event_bus
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
""" """
Create a zip file containing the images specified by the given image names or board id. Create a zip file containing the images specified by the given image names or board id.
@ -39,10 +40,8 @@ class BulkDownloadService(BulkDownloadBase):
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
""" """
bulk_download_id = str(uuid.uuid4()) bulk_download_id = str(uuid.uuid4())
self._signal_job_started(bulk_download_id)
try: try:
board_name: Union[str, None] = None
if board_id: if board_id:
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
if board_id == "none": if board_id == "none":
@ -64,24 +63,21 @@ class BulkDownloadService(BulkDownloadBase):
for image_name in image_names: for image_name in image_names:
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
return image_names_to_paths return image_names_to_paths
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
""" """
Create a zip file containing the images specified by the given image names or board id. Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten. If download with the same bulk_download_id already exists, it will be overwritten.
""" """
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
with ZipFile(zip_file_path, "w") as zip_file: with ZipFile(zip_file_path, "w") as zip_file:
for image_name, image_path in image_names_to_paths.items(): for image_name, image_path in image_names_to_paths.items():
zip_file.write(image_path, arcname=image_name) zip_file.write(image_path, arcname=image_name)
return str(zip_file_path) return str(zip_file_path)
def _signal_job_started(self, bulk_download_id: str) -> None: def _signal_job_started(self, bulk_download_id: str) -> None:
"""Signal that a bulk download job has started.""" """Signal that a bulk download job has started."""
if self.__event_bus: if self.__event_bus:
@ -90,7 +86,6 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id, bulk_download_id=bulk_download_id,
) )
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Signal that a bulk download job has completed.""" """Signal that a bulk download job has completed."""
if self.__event_bus: if self.__event_bus:
@ -100,7 +95,7 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id, bulk_download_id=bulk_download_id,
file_path=file_path, file_path=file_path,
) )
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
"""Signal that a bulk download job has failed.""" """Signal that a bulk download job has failed."""
if self.__event_bus: if self.__event_bus:
@ -110,5 +105,3 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id, bulk_download_id=bulk_download_id,
error=str(exception), error=str(exception),
) )

View File

@ -444,20 +444,19 @@ class EventServiceBase:
"""Emitted when a bulk download starts""" """Emitted when a bulk download starts"""
self._emit_bulk_download_event( self._emit_bulk_download_event(
event_name="bulk_download_started", event_name="bulk_download_started",
payload={"bulk_download_id": bulk_download_id, } payload={
"bulk_download_id": bulk_download_id,
},
) )
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Emitted when a bulk download completes""" """Emitted when a bulk download completes"""
self._emit_bulk_download_event( self._emit_bulk_download_event(
event_name="bulk_download_completed", event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
payload={"bulk_download_id": bulk_download_id,
"file_path": file_path}
) )
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
"""Emitted when a bulk download fails""" """Emitted when a bulk download fails"""
self._emit_bulk_download_event( self._emit_bulk_download_event(
event_name="bulk_download_failed", event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
payload={"bulk_download_id": bulk_download_id, "error": error}
) )