diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index c12556aed6..69a76e4062 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -397,7 +397,7 @@ async def download_images_from_list( raise HTTPException(status_code=400, detail="No images or board id specified.") bulk_download_item_id: str = uuid_string() if board_id is None else board_id board_name: str = ( - "" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name + "" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id) ) # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index d6b0e62211..89b2e73772 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -45,7 +45,7 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def get_board_name(self, board_id: str) -> str: + def get_clean_board_name(self, board_id: str) -> str: """ Get the name of the board. diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index be70dea2c1..fe76a12333 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -1,4 +1,5 @@ from pathlib import Path +from tempfile import TemporaryDirectory from typing import Optional, Union from zipfile import ZipFile @@ -19,6 +20,7 @@ from .bulk_download_base import BulkDownloadBase class BulkDownloadService(BulkDownloadBase): __output_folder: Path + __temp_directory: TemporaryDirectory __bulk_downloads_folder: Path __event_bus: EventServiceBase __invoker: Invoker @@ -35,7 +37,8 @@ class BulkDownloadService(BulkDownloadBase): Initialize the downloader object. """ self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" + self.__temp_directory = TemporaryDirectory(dir=self.__output_folder) + self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: @@ -57,8 +60,7 @@ class BulkDownloadService(BulkDownloadBase): image_dtos: list[ImageDTO] = [] if board_id: - board_name = self.get_board_name(board_id) - board_name = self._clean_string_to_path_safe(board_name) + board_name = self.get_clean_board_name(board_id) # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images image_dtos = self.__invoker.services.images.get_many( @@ -80,11 +82,11 @@ class BulkDownloadService(BulkDownloadBase): self.__invoker.services.logger.error("Problem bulk downloading images.") raise e - def get_board_name(self, board_id: str) -> str: + def get_clean_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" - return self.__invoker.services.board_records.get(board_id).board_name + return self._clean_string_to_path_safe(self.__invoker.services.board_records.get(board_id).board_name) def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: """ @@ -145,11 +147,7 @@ class BulkDownloadService(BulkDownloadBase): def stop(self, *args, **kwargs): """Stop the bulk download service and delete the files in the bulk download folder.""" # Get all the files in the bulk downloads folder, only .zip files - files = self.__bulk_downloads_folder.glob("*.zip") - - # Delete all the files - for file in files: - file.unlink() + self.__temp_directory.cleanup() def delete(self, bulk_download_item_name: str) -> None: """ diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 7909c44214..3cd2123232 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from tempfile import TemporaryDirectory from typing import Any from zipfile import ZipFile @@ -86,9 +87,28 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker: return Invoker(services=mock_services) +@pytest.fixture(autouse=True) +def mock_temporary_directory(monkeypatch: Any, tmp_path: Path): + """Mock the TemporaryDirectory class so that it uses the tmp_path fixture.""" + + class MockTemporaryDirectory(TemporaryDirectory): + def __init__(self): + super().__init__(dir=tmp_path) + self.name = tmp_path + + def mock_TemporaryDirectory(*args, **kwargs): + return MockTemporaryDirectory() + + monkeypatch.setattr( + "invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory + ) + + def test_get_path_when_file_exists(tmp_path: Path) -> None: """Test get_path when the file exists.""" + bulk_download_service = BulkDownloadService(tmp_path) + # Create a directory at tmp_path/bulk_downloads test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True) @@ -97,7 +117,6 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None: test_file_path: Path = test_bulk_downloads_dir / "test.zip" test_file_path.touch() - bulk_download_service = BulkDownloadService(tmp_path) assert bulk_download_service.get_path("test.zip") == str(test_file_path) @@ -164,7 +183,6 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d _, expected_image_path, mock_image_contents = prepare_handler_test( tmp_path, monkeypatch, mock_image_dto, mock_invoker ) - expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" def mock_get_many(*args, **kwargs): return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) @@ -175,6 +193,8 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "none", None) + expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" + assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events ) @@ -188,7 +208,6 @@ def test_handler_bulk_download__item_id_given( _, expected_image_path, mock_image_contents = prepare_handler_test( tmp_path, monkeypatch, mock_image_dto, mock_invoker ) - expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip" def mock_get_many(*args, **kwargs): return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) @@ -199,6 +218,8 @@ def test_handler_bulk_download__item_id_given( bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip" + assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events ) @@ -341,7 +362,7 @@ def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_board_name("12345") + board_name = bulk_download_service.get_clean_board_name("12345") assert board_name == expected_board_name @@ -353,7 +374,7 @@ def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_board_name("none") + board_name = bulk_download_service.get_clean_board_name("none") assert board_name == expected_board_name @@ -385,6 +406,4 @@ def test_stop(tmp_path: Path): bulk_download_service.stop() - assert (tmp_path / "bulk_downloads").exists() - assert mock_dir.exists() - assert len(os.listdir(tmp_path / "bulk_downloads")) == 1 + assert not (tmp_path / "bulk_downloads").exists()