mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
linted and styling
This commit is contained in:
parent
fc9a62dbf5
commit
98a01368b8
@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, ValidationError
|
|||||||
|
|
||||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
|
||||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||||
@ -373,13 +372,16 @@ async def unstar_images_in_list(
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||||
|
|
||||||
|
|
||||||
class ImagesDownloaded(BaseModel):
|
class ImagesDownloaded(BaseModel):
|
||||||
response: Optional[str] = Field(
|
response: Optional[str] = Field(
|
||||||
description="If defined, the message to display to the user when images begin downloading"
|
description="If defined, the message to display to the user when images begin downloading"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202)
|
@images_router.post(
|
||||||
|
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
|
||||||
|
)
|
||||||
async def download_images_from_list(
|
async def download_images_from_list(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||||
@ -389,6 +391,7 @@ async def download_images_from_list(
|
|||||||
) -> ImagesDownloaded:
|
) -> ImagesDownloaded:
|
||||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id)
|
background_tasks.add_task(
|
||||||
|
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
|
||||||
|
)
|
||||||
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
||||||
|
|
||||||
|
@ -18,7 +18,6 @@ class SocketIO:
|
|||||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, app: FastAPI):
|
def __init__(self, app: FastAPI):
|
||||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
class BulkDownloadBase(ABC):
|
|
||||||
|
|
||||||
|
class BulkDownloadBase(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -25,8 +24,8 @@ class BulkDownloadBase(ABC):
|
|||||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Starts a a bulk download job.
|
Starts a a bulk download job.
|
||||||
|
|
||||||
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
|
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
|
||||||
:param image_names: A list of image names to include in the zip file.
|
:param image_names: A list of image names to include in the zip file.
|
||||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
class BulkDownloadException(Exception):
|
class BulkDownloadException(Exception):
|
||||||
"""Exception raised when a bulk download fails."""
|
"""Exception raised when a bulk download fails."""
|
||||||
|
|
||||||
@ -6,6 +5,7 @@ class BulkDownloadException(Exception):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
class BulkDownloadTargetException(BulkDownloadException):
|
class BulkDownloadTargetException(BulkDownloadException):
|
||||||
"""Exception raised when a bulk download target is not found."""
|
"""Exception raised when a bulk download target is not found."""
|
||||||
|
|
||||||
@ -13,9 +13,13 @@ class BulkDownloadTargetException(BulkDownloadException):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
class BulkDownloadParametersException(BulkDownloadException):
|
class BulkDownloadParametersException(BulkDownloadException):
|
||||||
"""Exception raised when a bulk download parameter is invalid."""
|
"""Exception raised when a bulk download parameter is invalid."""
|
||||||
|
|
||||||
def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
message="The bulk download parameters are invalid, either an array of image names or a board id must be provided",
|
||||||
|
):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import uuid
|
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||||
@ -11,15 +11,17 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
|
|
||||||
from .bulk_download_base import BulkDownloadBase
|
from .bulk_download_base import BulkDownloadBase
|
||||||
|
|
||||||
class BulkDownloadService(BulkDownloadBase):
|
|
||||||
|
|
||||||
|
class BulkDownloadService(BulkDownloadBase):
|
||||||
__output_folder: Path
|
__output_folder: Path
|
||||||
__bulk_downloads_folder: Path
|
__bulk_downloads_folder: Path
|
||||||
__event_bus: Optional[EventServiceBase]
|
__event_bus: Optional[EventServiceBase]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
output_folder: Union[str, Path],
|
self,
|
||||||
event_bus: Optional[EventServiceBase] = None,):
|
output_folder: Union[str, Path],
|
||||||
|
event_bus: Optional[EventServiceBase] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the downloader object.
|
Initialize the downloader object.
|
||||||
|
|
||||||
@ -30,8 +32,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||||
self.__event_bus = event_bus
|
self.__event_bus = event_bus
|
||||||
|
|
||||||
|
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
||||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
|
|
||||||
@ -39,10 +40,8 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||||
"""
|
"""
|
||||||
bulk_download_id = str(uuid.uuid4())
|
bulk_download_id = str(uuid.uuid4())
|
||||||
|
|
||||||
self._signal_job_started(bulk_download_id)
|
|
||||||
try:
|
try:
|
||||||
board_name: Union[str, None] = None
|
|
||||||
if board_id:
|
if board_id:
|
||||||
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
if board_id == "none":
|
if board_id == "none":
|
||||||
@ -64,24 +63,21 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
||||||
return image_names_to_paths
|
return image_names_to_paths
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
|
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
|
zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip")
|
||||||
|
|
||||||
with ZipFile(zip_file_path, "w") as zip_file:
|
with ZipFile(zip_file_path, "w") as zip_file:
|
||||||
for image_name, image_path in image_names_to_paths.items():
|
for image_name, image_path in image_names_to_paths.items():
|
||||||
zip_file.write(image_path, arcname=image_name)
|
zip_file.write(image_path, arcname=image_name)
|
||||||
|
|
||||||
return str(zip_file_path)
|
return str(zip_file_path)
|
||||||
|
|
||||||
|
|
||||||
def _signal_job_started(self, bulk_download_id: str) -> None:
|
def _signal_job_started(self, bulk_download_id: str) -> None:
|
||||||
"""Signal that a bulk download job has started."""
|
"""Signal that a bulk download job has started."""
|
||||||
if self.__event_bus:
|
if self.__event_bus:
|
||||||
@ -90,7 +86,6 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
bulk_download_id=bulk_download_id,
|
bulk_download_id=bulk_download_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
|
def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None:
|
||||||
"""Signal that a bulk download job has completed."""
|
"""Signal that a bulk download job has completed."""
|
||||||
if self.__event_bus:
|
if self.__event_bus:
|
||||||
@ -100,7 +95,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
bulk_download_id=bulk_download_id,
|
bulk_download_id=bulk_download_id,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
|
def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None:
|
||||||
"""Signal that a bulk download job has failed."""
|
"""Signal that a bulk download job has failed."""
|
||||||
if self.__event_bus:
|
if self.__event_bus:
|
||||||
@ -110,5 +105,3 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
bulk_download_id=bulk_download_id,
|
bulk_download_id=bulk_download_id,
|
||||||
error=str(exception),
|
error=str(exception),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -444,20 +444,19 @@ class EventServiceBase:
|
|||||||
"""Emitted when a bulk download starts"""
|
"""Emitted when a bulk download starts"""
|
||||||
self._emit_bulk_download_event(
|
self._emit_bulk_download_event(
|
||||||
event_name="bulk_download_started",
|
event_name="bulk_download_started",
|
||||||
payload={"bulk_download_id": bulk_download_id, }
|
payload={
|
||||||
|
"bulk_download_id": bulk_download_id,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
|
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
|
||||||
"""Emitted when a bulk download completes"""
|
"""Emitted when a bulk download completes"""
|
||||||
self._emit_bulk_download_event(
|
self._emit_bulk_download_event(
|
||||||
event_name="bulk_download_completed",
|
event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path}
|
||||||
payload={"bulk_download_id": bulk_download_id,
|
|
||||||
"file_path": file_path}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
|
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
|
||||||
"""Emitted when a bulk download fails"""
|
"""Emitted when a bulk download fails"""
|
||||||
self._emit_bulk_download_event(
|
self._emit_bulk_download_event(
|
||||||
event_name="bulk_download_failed",
|
event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error}
|
||||||
payload={"bulk_download_id": bulk_download_id, "error": error}
|
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user