mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
linted and styling
This commit is contained in:
parent
fc9a62dbf5
commit
98a01368b8
@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||
@ -373,13 +372,16 @@ async def unstar_images_in_list(
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
|
||||
class ImagesDownloaded(BaseModel):
|
||||
response: Optional[str] = Field(
|
||||
description="If defined, the message to display to the user when images begin downloading"
|
||||
)
|
||||
|
||||
|
||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202)
|
||||
@images_router.post(
|
||||
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
|
||||
)
|
||||
async def download_images_from_list(
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||
@ -389,6 +391,7 @@ async def download_images_from_list(
|
||||
) -> ImagesDownloaded:
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id)
|
||||
background_tasks.add_task(
|
||||
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
|
||||
)
|
||||
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
||||
|
||||
|
@ -18,7 +18,6 @@ class SocketIO:
|
||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
|
@ -1,13 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
class BulkDownloadBase(ABC):
|
||||
|
||||
class BulkDownloadBase(ABC):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
@ -25,8 +24,8 @@ class BulkDownloadBase(ABC):
|
||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
"""
|
||||
Starts a a bulk download job.
|
||||
|
||||
|
||||
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
|
||||
:param image_names: A list of image names to include in the zip file.
|
||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||
"""
|
||||
"""
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
class BulkDownloadException(Exception):
|
||||
"""Exception raised when a bulk download fails."""
|
||||
|
||||
@ -6,6 +5,7 @@ class BulkDownloadException(Exception):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BulkDownloadTargetException(BulkDownloadException):
|
||||
"""Exception raised when a bulk download target is not found."""
|
||||
|
||||
@ -13,9 +13,13 @@ class BulkDownloadTargetException(BulkDownloadException):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BulkDownloadParametersException(BulkDownloadException):
|
||||
"""Exception raised when a bulk download parameter is invalid."""
|
||||
|
||||
def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"):
|
||||
def __init__(
|
||||
self,
|
||||
message="The bulk download parameters are invalid, either an array of image names or a board id must be provided",
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.message = message
|
||||
|
@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from zipfile import ZipFile
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||
@ -11,15 +11,17 @@ from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .bulk_download_base import BulkDownloadBase
|
||||
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
__output_folder: Path
|
||||
__bulk_downloads_folder: Path
|
||||
__event_bus: Optional[EventServiceBase]
|
||||
|
||||
def __init__(self,
|
||||
output_folder: Union[str, Path],
|
||||
event_bus: Optional[EventServiceBase] = None,):
|
||||
def __init__(
|
||||
self,
|
||||
output_folder: Union[str, Path],
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the downloader object.
|
||||
|
||||
@ -30,8 +32,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
self.__event_bus = event_bus
|
||||
|
||||
|
||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
|
||||
@ -39,10 +40,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||
"""
|
||||
bulk_download_id = str(uuid.uuid4())
|
||||
|
||||
self._signal_job_started(bulk_download_id)
|
||||
|
||||
try:
|
||||
board_name: Union[str, None] = None
|
||||
if board_id:
|
||||
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
if board_id == "none":
|
||||
@ -64,24 +63,21 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
for image_name in image_names:
|
||||
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
||||
return image_names_to_paths
|
||||
|
||||
|
||||
|
||||
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||
"""
|
||||
|
||||
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
|
||||
|
||||
|
||||
with ZipFile(zip_file_path, "w") as zip_file:
|
||||
for image_name, image_path in image_names_to_paths.items():
|
||||
zip_file.write(image_path, arcname=image_name)
|
||||
|
||||
return str(zip_file_path)
|
||||
|
||||
|
||||
def _signal_job_started(self, bulk_download_id: str) -> None:
|
||||
"""Signal that a bulk download job has started."""
|
||||
if self.__event_bus:
|
||||
@ -90,7 +86,6 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
bulk_download_id=bulk_download_id,
|
||||
)
|
||||
|
||||
|
||||
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
|
||||
"""Signal that a bulk download job has completed."""
|
||||
if self.__event_bus:
|
||||
@ -100,7 +95,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
bulk_download_id=bulk_download_id,
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
|
||||
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
|
||||
"""Signal that a bulk download job has failed."""
|
||||
if self.__event_bus:
|
||||
@ -110,5 +105,3 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
bulk_download_id=bulk_download_id,
|
||||
error=str(exception),
|
||||
)
|
||||
|
||||
|
@ -444,20 +444,19 @@ class EventServiceBase:
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={"bulk_download_id": bulk_download_id, }
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={"bulk_download_id": bulk_download_id,
|
||||
"file_path": file_path}
|
||||
event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
|
||||
)
|
||||
|
||||
|
||||
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={"bulk_download_id": bulk_download_id, "error": error}
|
||||
event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user