From 4dfa1e3d03e7031e7f8c7d578b989a577ff22334 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 15 Jan 2024 15:58:43 -0500 Subject: [PATCH] cleaning up bulk download zip after the response is complete --- invokeai/app/api/routers/images.py | 3 +- .../bulk_download/bulk_download_base.py | 8 ++++ .../bulk_download/bulk_download_default.py | 43 +++++++++++-------- .../bulk_download/test_bulk_download.py | 16 ++++++- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 236961fa9e..d11c89c749 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -409,10 +409,10 @@ async def download_images_from_list( }, ) async def get_bulk_download_item( + background_tasks: BackgroundTasks, bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), ) -> FileResponse: """Gets a bulk download zip file""" - try: path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) @@ -423,6 +423,7 @@ async def get_bulk_download_item( content_disposition_type="inline", ) response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" + background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name) return response except Exception: raise HTTPException(status_code=404) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 7a4aa0661c..a1071f254a 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -59,3 +59,11 @@ class BulkDownloadBase(ABC): Returns: None """ + + @abstractmethod + def delete(self, bulk_download_item_name: str) -> None: + """ + Delete the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index a9ea12bfd6..a0abb6743a 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -38,23 +38,6 @@ class BulkDownloadService(BulkDownloadBase): self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - def get_path(self, bulk_download_item_name: str) -> str: - """ - Get the path to the bulk download file. - - :param bulk_download_item_name: The name of the bulk download item. - :return: The path to the bulk download file. - """ - path = str(self.__bulk_downloads_folder / bulk_download_item_name) - if not self.validate_path(path): - raise BulkDownloadTargetException() - return path - - def validate_path(self, path: Union[str, Path]) -> bool: - """Validates the path given for a bulk download.""" - path = path if isinstance(path, Path) else Path(path) - return path.exists() - def handler(self, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -166,3 +149,29 @@ class BulkDownloadService(BulkDownloadBase): # Delete all the files for file in files: file.unlink() + + def delete(self, bulk_download_item_name: str) -> None: + """ + Delete the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + """ + path = self.get_path(bulk_download_item_name) + Path(path).unlink() + + def get_path(self, bulk_download_item_name: str) -> str: + """ + Get the path to the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The path to the bulk download file. + """ + path = str(self.__bulk_downloads_folder / bulk_download_item_name) + if not self.validate_path(path): + raise BulkDownloadTargetException() + return path + + def validate_path(self, path: Union[str, Path]) -> bool: + """Validates the path given for a bulk download.""" + path = path if isinstance(path, Path) else Path(path) + return path.exists() diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 4f476c21be..4c9dc42612 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -293,7 +293,7 @@ def test_handler_on_generic_exception( monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) event_bus: DummyEventService = mock_invoker.services.events @@ -317,3 +317,17 @@ def execute_handler_test_on_error( assert event_bus.events[0].event_name == "bulk_download_started" assert event_bus.events[1].event_name == "bulk_download_failed" assert event_bus.events[1].payload["error"] == error.__str__() + + +def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the delete method removes the bulk download file.""" + + bulk_download_service = BulkDownloadService(tmp_path) + + mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" + mock_file.write_text("contents") + + bulk_download_service.delete("test.zip") + + assert (tmp_path / "bulk_downloads").exists() + assert len(os.listdir(tmp_path / "bulk_downloads")) == 0