mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactoring bulkdownload to consider image category
This commit is contained in:
parent
1e00b9760a
commit
c2b12f8849
@ -391,9 +391,7 @@ async def download_images_from_list(
|
|||||||
) -> ImagesDownloaded:
|
) -> ImagesDownloaded:
|
||||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id)
|
||||||
ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id
|
|
||||||
)
|
|
||||||
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
return ImagesDownloaded(response="Your images are preparing to be downloaded")
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,6 +7,17 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
|
|
||||||
|
|
||||||
class BulkDownloadBase(ABC):
|
class BulkDownloadBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
"""
|
||||||
|
Starts the BulkDownloadService.
|
||||||
|
|
||||||
|
This method is responsible for starting the BulkDownloadService and performing any necessary initialization
|
||||||
|
operations to prepare the service for use.
|
||||||
|
|
||||||
|
param: invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -21,7 +32,7 @@ class BulkDownloadBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Starts a a bulk download job.
|
Starts a a bulk download job.
|
||||||
|
|
||||||
|
@ -10,7 +10,8 @@ from invokeai.app.services.bulk_download.bulk_download_common import (
|
|||||||
BulkDownloadTargetException,
|
BulkDownloadTargetException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordNotFoundException
|
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
||||||
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
from .bulk_download_base import BulkDownloadBase
|
from .bulk_download_base import BulkDownloadBase
|
||||||
@ -20,6 +21,10 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
__output_folder: Path
|
__output_folder: Path
|
||||||
__bulk_downloads_folder: Path
|
__bulk_downloads_folder: Path
|
||||||
__event_bus: Optional[EventServiceBase]
|
__event_bus: Optional[EventServiceBase]
|
||||||
|
__invoker: Invoker
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -53,7 +58,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
path = path if isinstance(path, Path) else Path(path)
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
return path.exists()
|
return path.exists()
|
||||||
|
|
||||||
def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None:
|
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
|
|
||||||
@ -64,50 +69,40 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||||
bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id
|
bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id
|
||||||
|
|
||||||
|
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
board_name: str = ""
|
board_name: str = ""
|
||||||
|
image_dtos: list[ImageDTO] = []
|
||||||
|
|
||||||
if board_id:
|
if board_id:
|
||||||
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
|
||||||
image_names = [
|
|
||||||
img.image_name
|
|
||||||
for img in invoker.services.images.get_many(
|
|
||||||
offset=0,
|
|
||||||
limit=-1,
|
|
||||||
board_id=board_id,
|
|
||||||
is_intermediate=False,
|
|
||||||
categories=[ImageCategory.GENERAL],
|
|
||||||
).items
|
|
||||||
]
|
|
||||||
if board_id == "none":
|
if board_id == "none":
|
||||||
board_id = "Uncategorized"
|
|
||||||
board_name = "Uncategorized"
|
board_name = "Uncategorized"
|
||||||
else:
|
else:
|
||||||
board_name = invoker.services.board_records.get(board_id).board_name
|
board_name = self.__invoker.services.board_records.get(board_id).board_name
|
||||||
board_name = self._clean_string_to_path_safe(board_name)
|
board_name = self._clean_string_to_path_safe(board_name)
|
||||||
|
|
||||||
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
||||||
|
image_dtos = self.__invoker.services.images.get_many(
|
||||||
image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names)
|
offset=0,
|
||||||
|
limit=-1,
|
||||||
|
board_id=board_id,
|
||||||
|
is_intermediate=False,
|
||||||
|
).items
|
||||||
|
else:
|
||||||
|
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
|
||||||
bulk_download_item_name: str = self._create_zip_file(
|
bulk_download_item_name: str = self._create_zip_file(
|
||||||
image_names_to_paths, bulk_download_item_id if board_id is None else board_name
|
image_dtos, bulk_download_item_id if board_id is None else board_name
|
||||||
)
|
)
|
||||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||||
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
||||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
||||||
|
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
||||||
|
raise e
|
||||||
|
|
||||||
def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]:
|
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
|
||||||
"""
|
|
||||||
Create a map of image names to their paths.
|
|
||||||
:param image_names: A list of image names.
|
|
||||||
"""
|
|
||||||
image_names_to_paths: dict[str, str] = {}
|
|
||||||
for image_name in image_names:
|
|
||||||
image_names_to_paths[image_name] = invoker.services.images.get_path(image_name)
|
|
||||||
return image_names_to_paths
|
|
||||||
|
|
||||||
def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_item_id: str) -> str:
|
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||||
@ -118,11 +113,14 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
zip_file_path = self.__bulk_downloads_folder / (zip_file_name)
|
zip_file_path = self.__bulk_downloads_folder / (zip_file_name)
|
||||||
|
|
||||||
with ZipFile(zip_file_path, "w") as zip_file:
|
with ZipFile(zip_file_path, "w") as zip_file:
|
||||||
for image_name, image_path in image_names_to_paths.items():
|
for image_dto in image_dtos:
|
||||||
zip_file.write(image_path, arcname=image_name)
|
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
|
||||||
|
image_path = self.__invoker.services.images.get_path(image_dto.image_name)
|
||||||
|
zip_file.write(image_path, arcname=image_zip_path)
|
||||||
|
|
||||||
return str(zip_file_name)
|
return str(zip_file_name)
|
||||||
|
|
||||||
|
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
|
||||||
def _clean_string_to_path_safe(self, s: str) -> str:
|
def _clean_string_to_path_safe(self, s: str) -> str:
|
||||||
"""Clean a string to be path safe."""
|
"""Clean a string to be path safe."""
|
||||||
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip()
|
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip()
|
||||||
|
Loading…
Reference in New Issue
Block a user