From ff53563152c2db7293371431fb406a2d7b20f424 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 28 Jan 2024 00:38:01 -0500 Subject: [PATCH] narrowing bulk_download stop service scope --- .../bulk_download/bulk_download_default.py | 4 ++-- .../bulk_download/test_bulk_download.py | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index a0abb6743a..87966ad622 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -143,8 +143,8 @@ class BulkDownloadService(BulkDownloadBase): def stop(self, *args, **kwargs): """Stop the bulk download service and delete the files in the bulk download folder.""" - # Get all the files in the bulk downloads folder - files = self.__bulk_downloads_folder.glob("*") + # Get all the files in the bulk downloads folder, only .zip files + files = self.__bulk_downloads_folder.glob("*.zip") # Delete all the files for file in files: diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index bc6eb8d41c..184519866a 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -319,7 +319,7 @@ def execute_handler_test_on_error( assert event_bus.events[1].payload["error"] == error.__str__() -def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): +def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" bulk_download_service = BulkDownloadService(tmp_path) @@ -331,3 +331,21 @@ def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock assert (tmp_path / "bulk_downloads").exists() assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 + +def test_stop(tmp_path: Path): + """Test that the delete method removes the bulk download file.""" + + bulk_download_service = BulkDownloadService(tmp_path) + + mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" + mock_file.write_text("contents") + + mock_dir: Path = tmp_path / "bulk_downloads" / "test" + mock_dir.mkdir(parents=True, exist_ok=True) + + + bulk_download_service.stop() + + assert (tmp_path / "bulk_downloads").exists() + assert mock_dir.exists() + assert len(os.listdir(tmp_path / "bulk_downloads")) == 1