mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
moving the responsibility of cleaning up board names to the service not the route
This commit is contained in:
parent
3c881d5b1a
commit
ba28709f2d
@ -13,7 +13,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
|||||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@ -395,10 +394,7 @@ async def download_images_from_list(
|
|||||||
) -> ImagesDownloaded:
|
) -> ImagesDownloaded:
|
||||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||||
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
|
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
||||||
board_name: str = (
|
|
||||||
"" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
|
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
@ -409,7 +405,7 @@ async def download_images_from_list(
|
|||||||
)
|
)
|
||||||
return ImagesDownloaded(
|
return ImagesDownloaded(
|
||||||
response="Your images are preparing to be downloaded",
|
response="Your images are preparing to be downloaded",
|
||||||
bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip",
|
bulk_download_item_name=bulk_download_item_id + ".zip",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,9 +30,9 @@ class BulkDownloadBase(ABC):
|
|||||||
"""
|
"""
|
||||||
Starts a a bulk download job.
|
Starts a a bulk download job.
|
||||||
|
|
||||||
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
|
|
||||||
:param image_names: A list of image names to include in the zip file.
|
:param image_names: A list of image names to include in the zip file.
|
||||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||||
|
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -45,12 +45,12 @@ class BulkDownloadBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_clean_board_name(self, board_id: str) -> str:
|
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Get the name of the board.
|
Generate an item ID for a bulk download item.
|
||||||
|
|
||||||
:param board_id: The ID of the board.
|
:param board_id: The ID of the board whose name is to be included in the item id.
|
||||||
:return: The name of the board.
|
:return: The generated item ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -61,12 +61,8 @@ class BulkDownloadBase(ABC):
|
|||||||
This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup
|
This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup
|
||||||
operations to remove any remnants or resources associated with the service.
|
operations to remove any remnants or resources associated with the service.
|
||||||
|
|
||||||
Args:
|
:param *args: Variable length argument list.
|
||||||
*args: Variable length argument list.
|
:param **kwargs: Arbitrary keyword arguments.
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -50,19 +50,15 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||||
if bulk_download_item_id is None:
|
bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id
|
||||||
bulk_download_item_id = uuid_string() if board_id is None else board_id
|
|
||||||
|
|
||||||
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
self._signal_job_started(bulk_download_id, bulk_download_item_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
board_name: str = ""
|
|
||||||
image_dtos: list[ImageDTO] = []
|
image_dtos: list[ImageDTO] = []
|
||||||
|
|
||||||
if board_id:
|
if board_id:
|
||||||
board_name = self.get_clean_board_name(board_id)
|
# -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images
|
||||||
|
|
||||||
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
|
||||||
image_dtos = self.__invoker.services.images.get_many(
|
image_dtos = self.__invoker.services.images.get_many(
|
||||||
offset=0,
|
offset=0,
|
||||||
limit=-1,
|
limit=-1,
|
||||||
@ -71,9 +67,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
).items
|
).items
|
||||||
else:
|
else:
|
||||||
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
|
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
|
||||||
bulk_download_item_name: str = self._create_zip_file(
|
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
|
||||||
image_dtos, bulk_download_item_id if board_id is None else board_name
|
|
||||||
)
|
|
||||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||||
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
|
||||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
|
||||||
@ -82,7 +76,10 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_clean_board_name(self, board_id: str) -> str:
|
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||||
|
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
|
||||||
|
|
||||||
|
def _get_clean_board_name(self, board_id: str) -> str:
|
||||||
if board_id == "none":
|
if board_id == "none":
|
||||||
return "Uncategorized"
|
return "Uncategorized"
|
||||||
|
|
||||||
@ -109,7 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
|
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
|
||||||
def _clean_string_to_path_safe(self, s: str) -> str:
|
def _clean_string_to_path_safe(self, s: str) -> str:
|
||||||
"""Clean a string to be path safe."""
|
"""Clean a string to be path safe."""
|
||||||
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip()
|
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
|
||||||
|
|
||||||
def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
|
def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
|
||||||
"""Signal that a bulk download job has started."""
|
"""Signal that a bulk download job has started."""
|
||||||
@ -166,11 +163,11 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
:return: The path to the bulk download file.
|
:return: The path to the bulk download file.
|
||||||
"""
|
"""
|
||||||
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
||||||
if not self.validate_path(path):
|
if not self._is_valid_path(path):
|
||||||
raise BulkDownloadTargetException()
|
raise BulkDownloadTargetException()
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def validate_path(self, path: Union[str, Path]) -> bool:
|
def _is_valid_path(self, path: Union[str, Path]) -> bool:
|
||||||
"""Validates the path given for a bulk download."""
|
"""Validates the path given for a bulk download."""
|
||||||
path = path if isinstance(path, Path) else Path(path)
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
return path.exists()
|
return path.exists()
|
||||||
|
@ -71,17 +71,10 @@ class MockApiDependencies(ApiDependencies):
|
|||||||
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||||
|
|
||||||
def mock_uuid_string():
|
|
||||||
return "test"
|
|
||||||
|
|
||||||
# You have to patch the function within the module it's being imported into. This is strange, but it works.
|
|
||||||
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
|
|
||||||
monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string)
|
|
||||||
|
|
||||||
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
||||||
json_response = response.json()
|
json_response = response.json()
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
assert json_response["bulk_download_item_name"] == "test"
|
assert json_response["bulk_download_item_name"] == "test.zip"
|
||||||
|
|
||||||
|
|
||||||
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||||
@ -101,6 +94,10 @@ def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, m
|
|||||||
|
|
||||||
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||||
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id",
|
||||||
|
lambda arg: "test",
|
||||||
|
)
|
||||||
|
|
||||||
def mock_add_task(*args, **kwargs):
|
def mock_add_task(*args, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
@ -151,6 +151,42 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_id(monkeypatch: Any):
|
||||||
|
"""Test that the generate_id method generates a unique id."""
|
||||||
|
|
||||||
|
bulk_download_service = BulkDownloadService("test")
|
||||||
|
|
||||||
|
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||||
|
|
||||||
|
assert bulk_download_service.generate_item_id(None) == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
||||||
|
"""Test that the generate_id method generates a unique id with a board id."""
|
||||||
|
|
||||||
|
bulk_download_service = BulkDownloadService("test")
|
||||||
|
bulk_download_service.start(mock_invoker)
|
||||||
|
|
||||||
|
def mock_board_get(*args, **kwargs):
|
||||||
|
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
|
||||||
|
|
||||||
|
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
||||||
|
|
||||||
|
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||||
|
|
||||||
|
assert bulk_download_service.generate_item_id("12345") == "test_board_name_test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_id_with_default_board_id(monkeypatch: Any):
|
||||||
|
"""Test that the generate_id method generates a unique id with a board id."""
|
||||||
|
|
||||||
|
bulk_download_service = BulkDownloadService("test")
|
||||||
|
|
||||||
|
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||||
|
|
||||||
|
assert bulk_download_service.generate_item_id("none") == "Uncategorized_test"
|
||||||
|
|
||||||
|
|
||||||
def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||||
"""Test that the handler creates the zip file correctly when given a board id."""
|
"""Test that the handler creates the zip file correctly when given a board id."""
|
||||||
|
|
||||||
@ -159,7 +195,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
|
|||||||
)
|
)
|
||||||
|
|
||||||
def mock_board_get(*args, **kwargs):
|
def mock_board_get(*args, **kwargs):
|
||||||
return BoardRecord(board_id="12345", board_name="test", created_at="None", updated_at="None")
|
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
|
||||||
|
|
||||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
||||||
|
|
||||||
@ -193,14 +229,14 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
|||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([], "none", None)
|
bulk_download_service.handler([], "none", None)
|
||||||
|
|
||||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
|
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||||
|
|
||||||
assert_handler_success(
|
assert_handler_success(
|
||||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_handler_bulk_download__item_id_given(
|
def test_handler_bulk_download_item_id_given(
|
||||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
|
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
|
||||||
):
|
):
|
||||||
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
|
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
|
||||||
@ -350,35 +386,6 @@ def execute_handler_test_on_error(
|
|||||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||||
|
|
||||||
|
|
||||||
def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker):
|
|
||||||
"""Test that the get_board_name function returns the correct board name."""
|
|
||||||
|
|
||||||
expected_board_name = "board1"
|
|
||||||
|
|
||||||
def mock_get(*args, **kwargs):
|
|
||||||
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
|
|
||||||
|
|
||||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
|
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
|
||||||
bulk_download_service.start(mock_invoker)
|
|
||||||
board_name = bulk_download_service.get_clean_board_name("12345")
|
|
||||||
|
|
||||||
assert board_name == expected_board_name
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
|
|
||||||
"""Test that the get_board_name function returns the correct board name."""
|
|
||||||
|
|
||||||
expected_board_name = "Uncategorized"
|
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
|
||||||
bulk_download_service.start(mock_invoker)
|
|
||||||
board_name = bulk_download_service.get_clean_board_name("none")
|
|
||||||
|
|
||||||
assert board_name == expected_board_name
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete(tmp_path: Path):
|
def test_delete(tmp_path: Path):
|
||||||
"""Test that the delete method removes the bulk download file."""
|
"""Test that the delete method removes the bulk download file."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user