moving the responsibility of cleaning up board names to the service not the route

This commit is contained in:
Stefan Tobler 2024-02-16 15:50:48 -05:00 committed by Brandon Rising
parent 3c881d5b1a
commit ba28709f2d
5 changed files with 63 additions and 70 deletions

View File

@ -13,7 +13,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from invokeai.app.util.misc import uuid_string
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -395,10 +394,7 @@ async def download_images_from_list(
) -> ImagesDownloaded: ) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None: if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=400, detail="No images or board id specified.")
bulk_download_item_id: str = uuid_string() if board_id is None else board_id bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
board_name: str = (
"" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id)
)
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
background_tasks.add_task( background_tasks.add_task(
@ -409,7 +405,7 @@ async def download_images_from_list(
) )
return ImagesDownloaded( return ImagesDownloaded(
response="Your images are preparing to be downloaded", response="Your images are preparing to be downloaded",
bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip", bulk_download_item_name=bulk_download_item_id + ".zip",
) )

View File

@ -30,9 +30,9 @@ class BulkDownloadBase(ABC):
""" """
Starts a a bulk download job. Starts a a bulk download job.
:param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
:param image_names: A list of image names to include in the zip file. :param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
""" """
@abstractmethod @abstractmethod
@ -45,12 +45,12 @@ class BulkDownloadBase(ABC):
""" """
@abstractmethod @abstractmethod
def get_clean_board_name(self, board_id: str) -> str: def generate_item_id(self, board_id: Optional[str]) -> str:
""" """
Get the name of the board. Generate an item ID for a bulk download item.
:param board_id: The ID of the board. :param board_id: The ID of the board whose name is to be included in the item id.
:return: The name of the board. :return: The generated item ID.
""" """
@abstractmethod @abstractmethod
@ -61,12 +61,8 @@ class BulkDownloadBase(ABC):
This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup
operations to remove any remnants or resources associated with the service. operations to remove any remnants or resources associated with the service.
Args: :param *args: Variable length argument list.
*args: Variable length argument list. :param **kwargs: Arbitrary keyword arguments.
**kwargs: Arbitrary keyword arguments.
Returns:
None
""" """
@abstractmethod @abstractmethod

View File

@ -50,19 +50,15 @@ class BulkDownloadService(BulkDownloadBase):
""" """
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
if bulk_download_item_id is None: bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id
bulk_download_item_id = uuid_string() if board_id is None else board_id
self._signal_job_started(bulk_download_id, bulk_download_item_id) self._signal_job_started(bulk_download_id, bulk_download_item_id)
try: try:
board_name: str = ""
image_dtos: list[ImageDTO] = [] image_dtos: list[ImageDTO] = []
if board_id: if board_id:
board_name = self.get_clean_board_name(board_id) # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
image_dtos = self.__invoker.services.images.get_many( image_dtos = self.__invoker.services.images.get_many(
offset=0, offset=0,
limit=-1, limit=-1,
@ -71,9 +67,7 @@ class BulkDownloadService(BulkDownloadBase):
).items ).items
else: else:
image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names]
bulk_download_item_name: str = self._create_zip_file( bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
image_dtos, bulk_download_item_id if board_id is None else board_name
)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) self._signal_job_failed(bulk_download_id, bulk_download_item_id, e)
@ -82,7 +76,10 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.") self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e raise e
def get_clean_board_name(self, board_id: str) -> str: def generate_item_id(self, board_id: Optional[str]) -> str:
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
def _get_clean_board_name(self, board_id: str) -> str:
if board_id == "none": if board_id == "none":
return "Uncategorized" return "Uncategorized"
@ -109,7 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string # from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
def _clean_string_to_path_safe(self, s: str) -> str: def _clean_string_to_path_safe(self, s: str) -> str:
"""Clean a string to be path safe.""" """Clean a string to be path safe."""
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip() return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None:
"""Signal that a bulk download job has started.""" """Signal that a bulk download job has started."""
@ -166,11 +163,11 @@ class BulkDownloadService(BulkDownloadBase):
:return: The path to the bulk download file. :return: The path to the bulk download file.
""" """
path = str(self.__bulk_downloads_folder / bulk_download_item_name) path = str(self.__bulk_downloads_folder / bulk_download_item_name)
if not self.validate_path(path): if not self._is_valid_path(path):
raise BulkDownloadTargetException() raise BulkDownloadTargetException()
return path return path
def validate_path(self, path: Union[str, Path]) -> bool: def _is_valid_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download.""" """Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path) path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()

View File

@ -71,17 +71,10 @@ class MockApiDependencies(ApiDependencies):
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
prepare_download_images_test(monkeypatch, mock_invoker) prepare_download_images_test(monkeypatch, mock_invoker)
def mock_uuid_string():
return "test"
# You have to patch the function within the module it's being imported into. This is strange, but it works.
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
json_response = response.json() json_response = response.json()
assert response.status_code == 202 assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test" assert json_response["bulk_download_item_name"] == "test.zip"
def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
@ -101,6 +94,10 @@ def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, m
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
monkeypatch.setattr(
"invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id",
lambda arg: "test",
)
def mock_add_task(*args, **kwargs): def mock_add_task(*args, **kwargs):
return None return None

View File

@ -151,6 +151,42 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
) )
def test_generate_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id."""
bulk_download_service = BulkDownloadService("test")
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
assert bulk_download_service.generate_item_id(None) == "test"
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
"""Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService("test")
bulk_download_service.start(mock_invoker)
def mock_board_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
assert bulk_download_service.generate_item_id("12345") == "test_board_name_test"
def test_generate_id_with_default_board_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService("test")
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
assert bulk_download_service.generate_item_id("none") == "Uncategorized_test"
def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler creates the zip file correctly when given a board id.""" """Test that the handler creates the zip file correctly when given a board id."""
@ -159,7 +195,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
) )
def mock_board_get(*args, **kwargs): def mock_board_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name="test", created_at="None", updated_at="None") return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
@ -193,14 +229,14 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none", None) bulk_download_service.handler([], "none", None)
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip"
assert_handler_success( assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
) )
def test_handler_bulk_download__item_id_given( def test_handler_bulk_download_item_id_given(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
): ):
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id.""" """Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
@ -350,35 +386,6 @@ def execute_handler_test_on_error(
assert event_bus.events[1].payload["error"] == error.__str__() assert event_bus.events[1].payload["error"] == error.__str__()
def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker):
"""Test that the get_board_name function returns the correct board name."""
expected_board_name = "board1"
def mock_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_clean_board_name("12345")
assert board_name == expected_board_name
def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
"""Test that the get_board_name function returns the correct board name."""
expected_board_name = "Uncategorized"
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_clean_board_name("none")
assert board_name == expected_board_name
def test_delete(tmp_path: Path): def test_delete(tmp_path: Path):
"""Test that the delete method removes the bulk download file.""" """Test that the delete method removes the bulk download file."""