From ba28709f2d2216aeb34e40561bdd05324b70ad02 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Fri, 16 Feb 2024 15:50:48 -0500 Subject: [PATCH] moving the responsibility of cleaning up board names to the service not the route --- invokeai/app/api/routers/images.py | 8 +-- .../bulk_download/bulk_download_base.py | 18 ++--- .../bulk_download/bulk_download_default.py | 23 +++--- tests/app/routers/test_images.py | 13 ++-- .../bulk_download/test_bulk_download.py | 71 ++++++++++--------- 5 files changed, 63 insertions(+), 70 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 69a76e4062..d1c64648de 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -13,7 +13,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator -from invokeai.app.util.misc import uuid_string from ..dependencies import ApiDependencies @@ -395,10 +394,7 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - bulk_download_item_id: str = uuid_string() if board_id is None else board_id - board_name: str = ( - "" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id) - ) + bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries background_tasks.add_task( @@ -409,7 +405,7 @@ async def download_images_from_list( ) return ImagesDownloaded( response="Your images are preparing to be downloaded", - bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip", + bulk_download_item_name=bulk_download_item_id + ".zip", ) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 89b2e73772..5199652ad4 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -30,9 +30,9 @@ class BulkDownloadBase(ABC): """ Starts a a bulk download job. - :param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. + :param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated. """ @abstractmethod @@ -45,12 +45,12 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def get_clean_board_name(self, board_id: str) -> str: + def generate_item_id(self, board_id: Optional[str]) -> str: """ - Get the name of the board. + Generate an item ID for a bulk download item. - :param board_id: The ID of the board. - :return: The name of the board. + :param board_id: The ID of the board whose name is to be included in the item id. + :return: The generated item ID. """ @abstractmethod @@ -61,12 +61,8 @@ class BulkDownloadBase(ABC): This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup operations to remove any remnants or resources associated with the service. - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - None + :param *args: Variable length argument list. + :param **kwargs: Arbitrary keyword arguments. """ @abstractmethod diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index fe76a12333..406bd7d997 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -50,19 +50,15 @@ class BulkDownloadService(BulkDownloadBase): """ bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - if bulk_download_item_id is None: - bulk_download_item_id = uuid_string() if board_id is None else board_id + bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id self._signal_job_started(bulk_download_id, bulk_download_item_id) try: - board_name: str = "" image_dtos: list[ImageDTO] = [] if board_id: - board_name = self.get_clean_board_name(board_id) - - # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images + # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images image_dtos = self.__invoker.services.images.get_many( offset=0, limit=-1, @@ -71,9 +67,7 @@ class BulkDownloadService(BulkDownloadBase): ).items else: image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] - bulk_download_item_name: str = self._create_zip_file( - image_dtos, bulk_download_item_id if board_id is None else board_name - ) + bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) @@ -82,7 +76,10 @@ class BulkDownloadService(BulkDownloadBase): self.__invoker.services.logger.error("Problem bulk downloading images.") raise e - def get_clean_board_name(self, board_id: str) -> str: + def generate_item_id(self, board_id: Optional[str]) -> str: + return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string() + + def _get_clean_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" @@ -109,7 +106,7 @@ class BulkDownloadService(BulkDownloadBase): # from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string def _clean_string_to_path_safe(self, s: str) -> str: """Clean a string to be path safe.""" - return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip() + return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip() def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Signal that a bulk download job has started.""" @@ -166,11 +163,11 @@ class BulkDownloadService(BulkDownloadBase): :return: The path to the bulk download file. """ path = str(self.__bulk_downloads_folder / bulk_download_item_name) - if not self.validate_path(path): + if not self._is_valid_path(path): raise BulkDownloadTargetException() return path - def validate_path(self, path: Union[str, Path]) -> bool: + def _is_valid_path(self, path: Union[str, Path]) -> bool: """Validates the path given for a bulk download.""" path = path if isinstance(path, Path) else Path(path) return path.exists() diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index a709daf24e..e8521bf132 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -71,17 +71,10 @@ class MockApiDependencies(ApiDependencies): def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: prepare_download_images_test(monkeypatch, mock_invoker) - def mock_uuid_string(): - return "test" - - # You have to patch the function within the module it's being imported into. This is strange, but it works. - # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/ - monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string) - response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) json_response = response.json() assert response.status_code == 202 - assert json_response["bulk_download_item_name"] == "test" + assert json_response["bulk_download_item_name"] == "test.zip" def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: @@ -101,6 +94,10 @@ def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, m def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr( + "invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id", + lambda arg: "test", + ) def mock_add_task(*args, **kwargs): return None diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 924385f7e1..d70510cd91 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -151,6 +151,42 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I ) +def test_generate_id(monkeypatch: Any): + """Test that the generate_id method generates a unique id.""" + + bulk_download_service = BulkDownloadService("test") + + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") + + assert bulk_download_service.generate_item_id(None) == "test" + + +def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker): + """Test that the generate_id method generates a unique id with a board id.""" + + bulk_download_service = BulkDownloadService("test") + bulk_download_service.start(mock_invoker) + + def mock_board_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) + + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") + + assert bulk_download_service.generate_item_id("12345") == "test_board_name_test" + + +def test_generate_id_with_default_board_id(monkeypatch: Any): + """Test that the generate_id method generates a unique id with a board id.""" + + bulk_download_service = BulkDownloadService("test") + + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") + + assert bulk_download_service.generate_item_id("none") == "Uncategorized_test" + + def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): """Test that the handler creates the zip file correctly when given a board id.""" @@ -159,7 +195,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag ) def mock_board_get(*args, **kwargs): - return BoardRecord(board_id="12345", board_name="test", created_at="None", updated_at="None") + return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None") monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) @@ -193,14 +229,14 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "none", None) - expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip" assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events ) -def test_handler_bulk_download__item_id_given( +def test_handler_bulk_download_item_id_given( tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker ): """Test that the handler creates the zip file correctly when given a pregenerated bulk download item id.""" @@ -350,35 +386,6 @@ def execute_handler_test_on_error( assert event_bus.events[1].payload["error"] == error.__str__() -def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker): - """Test that the get_board_name function returns the correct board name.""" - - expected_board_name = "board1" - - def mock_get(*args, **kwargs): - return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None") - - monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get) - - bulk_download_service = BulkDownloadService(tmp_path) - bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_clean_board_name("12345") - - assert board_name == expected_board_name - - -def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker): - """Test that the get_board_name function returns the correct board name.""" - - expected_board_name = "Uncategorized" - - bulk_download_service = BulkDownloadService(tmp_path) - bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_clean_board_name("none") - - assert board_name == expected_board_name - - def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file."""