linted and styling

This commit is contained in:
Stefan Tobler 2024-01-07 22:17:03 -05:00 committed by psychedelicious
parent fc9a62dbf5
commit 98a01368b8
6 changed files with 37 additions and 40 deletions

View File

@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
@ -373,13 +372,16 @@ async def unstar_images_in_list(
except Exception: except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images") raise HTTPException(status_code=500, detail="Failed to unstar images")
class ImagesDownloaded(BaseModel): class ImagesDownloaded(BaseModel):
response: Optional[str] = Field( response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading" description="If defined, the message to display to the user when images begin downloading"
) )
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202) @images_router.post(
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list( async def download_images_from_list(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
image_names: list[str] = Body(description="The list of names of images to download", embed=True), image_names: list[str] = Body(description="The list of names of images to download", embed=True),
@ -389,6 +391,7 @@ async def download_images_from_list(
) -> ImagesDownloaded: ) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None: if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=400, detail="No images or board id specified.")
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id) background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
)
return ImagesDownloaded(response="Your images are preparing to be downloaded") return ImagesDownloaded(response="Your images are preparing to be downloaded")

View File

@ -18,7 +18,6 @@ class SocketIO:
__sub_bulk_download: str = "subscribe_bulk_download" __sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download" __unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")

View File

@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from abc import ABC, abstractmethod
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
class BulkDownloadBase(ABC):
class BulkDownloadBase(ABC):
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,

View File

@ -1,4 +1,3 @@
class BulkDownloadException(Exception): class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails.""" """Exception raised when a bulk download fails."""
@ -6,6 +5,7 @@ class BulkDownloadException(Exception):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
class BulkDownloadTargetException(BulkDownloadException): class BulkDownloadTargetException(BulkDownloadException):
"""Exception raised when a bulk download target is not found.""" """Exception raised when a bulk download target is not found."""
@ -13,9 +13,13 @@ class BulkDownloadTargetException(BulkDownloadException):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
class BulkDownloadParametersException(BulkDownloadException): class BulkDownloadParametersException(BulkDownloadException):
"""Exception raised when a bulk download parameter is invalid.""" """Exception raised when a bulk download parameter is invalid."""
def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"): def __init__(
self,
message="The bulk download parameters are invalid, either an array of image names or a board id must be provided",
):
super().__init__(message) super().__init__(message)
self.message = message self.message = message

View File

@ -1,6 +1,6 @@
import uuid
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import uuid
from zipfile import ZipFile from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
@ -11,15 +11,17 @@ from invokeai.app.services.invoker import Invoker
from .bulk_download_base import BulkDownloadBase from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
class BulkDownloadService(BulkDownloadBase):
__output_folder: Path __output_folder: Path
__bulk_downloads_folder: Path __bulk_downloads_folder: Path
__event_bus: Optional[EventServiceBase] __event_bus: Optional[EventServiceBase]
def __init__(self, def __init__(
output_folder: Union[str, Path], self,
event_bus: Optional[EventServiceBase] = None,): output_folder: Union[str, Path],
event_bus: Optional[EventServiceBase] = None,
):
""" """
Initialize the downloader object. Initialize the downloader object.
@ -30,8 +32,7 @@ class BulkDownloadService(BulkDownloadBase):
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
self.__event_bus = event_bus self.__event_bus = event_bus
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
""" """
Create a zip file containing the images specified by the given image names or board id. Create a zip file containing the images specified by the given image names or board id.
@ -40,9 +41,7 @@ class BulkDownloadService(BulkDownloadBase):
""" """
bulk_download_id = str(uuid.uuid4()) bulk_download_id = str(uuid.uuid4())
self._signal_job_started(bulk_download_id)
try: try:
board_name: Union[str, None] = None
if board_id: if board_id:
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
if board_id == "none": if board_id == "none":
@ -65,8 +64,6 @@ class BulkDownloadService(BulkDownloadBase):
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
return image_names_to_paths return image_names_to_paths
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
""" """
Create a zip file containing the images specified by the given image names or board id. Create a zip file containing the images specified by the given image names or board id.
@ -81,7 +78,6 @@ class BulkDownloadService(BulkDownloadBase):
return str(zip_file_path) return str(zip_file_path)
def _signal_job_started(self, bulk_download_id: str) -> None: def _signal_job_started(self, bulk_download_id: str) -> None:
"""Signal that a bulk download job has started.""" """Signal that a bulk download job has started."""
if self.__event_bus: if self.__event_bus:
@ -90,7 +86,6 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id, bulk_download_id=bulk_download_id,
) )
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Signal that a bulk download job has completed.""" """Signal that a bulk download job has completed."""
if self.__event_bus: if self.__event_bus:
@ -110,5 +105,3 @@ class BulkDownloadService(BulkDownloadBase):
bulk_download_id=bulk_download_id, bulk_download_id=bulk_download_id,
error=str(exception), error=str(exception),
) )

View File

@ -444,20 +444,19 @@ class EventServiceBase:
"""Emitted when a bulk download starts""" """Emitted when a bulk download starts"""
self._emit_bulk_download_event( self._emit_bulk_download_event(
event_name="bulk_download_started", event_name="bulk_download_started",
payload={"bulk_download_id": bulk_download_id, } payload={
"bulk_download_id": bulk_download_id,
},
) )
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Emitted when a bulk download completes""" """Emitted when a bulk download completes"""
self._emit_bulk_download_event( self._emit_bulk_download_event(
event_name="bulk_download_completed", event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
payload={"bulk_download_id": bulk_download_id,
"file_path": file_path}
) )
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
"""Emitted when a bulk download fails""" """Emitted when a bulk download fails"""
self._emit_bulk_download_event( self._emit_bulk_download_event(
event_name="bulk_download_failed", event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
payload={"bulk_download_id": bulk_download_id, "error": error}
) )