using temp directory for downloads

This commit is contained in:
Stefan Tobler 2024-01-28 18:59:56 -05:00 committed by Brandon Rising
parent b1301e1cbc
commit bb40196a17
4 changed files with 37 additions and 20 deletions

View File

@ -397,7 +397,7 @@ async def download_images_from_list(
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=400, detail="No images or board id specified.")
bulk_download_item_id: str = uuid_string() if board_id is None else board_id bulk_download_item_id: str = uuid_string() if board_id is None else board_id
board_name: str = ( board_name: str = (
"" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name "" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id)
) )
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries

View File

@ -45,7 +45,7 @@ class BulkDownloadBase(ABC):
""" """
@abstractmethod @abstractmethod
def get_board_name(self, board_id: str) -> str: def get_clean_board_name(self, board_id: str) -> str:
""" """
Get the name of the board. Get the name of the board.

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union from typing import Optional, Union
from zipfile import ZipFile from zipfile import ZipFile
@ -19,6 +20,7 @@ from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase): class BulkDownloadService(BulkDownloadBase):
__output_folder: Path __output_folder: Path
__temp_directory: TemporaryDirectory
__bulk_downloads_folder: Path __bulk_downloads_folder: Path
__event_bus: EventServiceBase __event_bus: EventServiceBase
__invoker: Invoker __invoker: Invoker
@ -35,7 +37,8 @@ class BulkDownloadService(BulkDownloadBase):
Initialize the downloader object. Initialize the downloader object.
""" """
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__temp_directory = TemporaryDirectory(dir=self.__output_folder)
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
@ -57,8 +60,7 @@ class BulkDownloadService(BulkDownloadBase):
image_dtos: list[ImageDTO] = [] image_dtos: list[ImageDTO] = []
if board_id: if board_id:
board_name = self.get_board_name(board_id) board_name = self.get_clean_board_name(board_id)
board_name = self._clean_string_to_path_safe(board_name)
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
image_dtos = self.__invoker.services.images.get_many( image_dtos = self.__invoker.services.images.get_many(
@ -80,11 +82,11 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker.services.logger.error("Problem bulk downloading images.") self.__invoker.services.logger.error("Problem bulk downloading images.")
raise e raise e
def get_board_name(self, board_id: str) -> str: def get_clean_board_name(self, board_id: str) -> str:
if board_id == "none": if board_id == "none":
return "Uncategorized" return "Uncategorized"
return self.__invoker.services.board_records.get(board_id).board_name return self._clean_string_to_path_safe(self.__invoker.services.board_records.get(board_id).board_name)
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
""" """
@ -145,11 +147,7 @@ class BulkDownloadService(BulkDownloadBase):
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):
"""Stop the bulk download service and delete the files in the bulk download folder.""" """Stop the bulk download service and delete the files in the bulk download folder."""
# Get all the files in the bulk downloads folder, only .zip files # Get all the files in the bulk downloads folder, only .zip files
files = self.__bulk_downloads_folder.glob("*.zip") self.__temp_directory.cleanup()
# Delete all the files
for file in files:
file.unlink()
def delete(self, bulk_download_item_name: str) -> None: def delete(self, bulk_download_item_name: str) -> None:
""" """

View File

@ -1,5 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any from typing import Any
from zipfile import ZipFile from zipfile import ZipFile
@ -86,9 +87,28 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services) return Invoker(services=mock_services)
@pytest.fixture(autouse=True)
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
class MockTemporaryDirectory(TemporaryDirectory):
def __init__(self):
super().__init__(dir=tmp_path)
self.name = tmp_path
def mock_TemporaryDirectory(*args, **kwargs):
return MockTemporaryDirectory()
monkeypatch.setattr(
"invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory
)
def test_get_path_when_file_exists(tmp_path: Path) -> None: def test_get_path_when_file_exists(tmp_path: Path) -> None:
"""Test get_path when the file exists.""" """Test get_path when the file exists."""
bulk_download_service = BulkDownloadService(tmp_path)
# Create a directory at tmp_path/bulk_downloads # Create a directory at tmp_path/bulk_downloads
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True) test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True)
@ -97,7 +117,6 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
test_file_path: Path = test_bulk_downloads_dir / "test.zip" test_file_path: Path = test_bulk_downloads_dir / "test.zip"
test_file_path.touch() test_file_path.touch()
bulk_download_service = BulkDownloadService(tmp_path)
assert bulk_download_service.get_path("test.zip") == str(test_file_path) assert bulk_download_service.get_path("test.zip") == str(test_file_path)
@ -164,7 +183,6 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
_, expected_image_path, mock_image_contents = prepare_handler_test( _, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker tmp_path, monkeypatch, mock_image_dto, mock_invoker
) )
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
def mock_get_many(*args, **kwargs): def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
@ -175,6 +193,8 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none", None) bulk_download_service.handler([], "none", None)
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
assert_handler_success( assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
) )
@ -188,7 +208,6 @@ def test_handler_bulk_download__item_id_given(
_, expected_image_path, mock_image_contents = prepare_handler_test( _, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker tmp_path, monkeypatch, mock_image_dto, mock_invoker
) )
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
def mock_get_many(*args, **kwargs): def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
@ -199,6 +218,8 @@ def test_handler_bulk_download__item_id_given(
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
assert_handler_success( assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
) )
@ -341,7 +362,7 @@ def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker)
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_board_name("12345") board_name = bulk_download_service.get_clean_board_name("12345")
assert board_name == expected_board_name assert board_name == expected_board_name
@ -353,7 +374,7 @@ def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
board_name = bulk_download_service.get_board_name("none") board_name = bulk_download_service.get_clean_board_name("none")
assert board_name == expected_board_name assert board_name == expected_board_name
@ -385,6 +406,4 @@ def test_stop(tmp_path: Path):
bulk_download_service.stop() bulk_download_service.stop()
assert (tmp_path / "bulk_downloads").exists() assert not (tmp_path / "bulk_downloads").exists()
assert mock_dir.exists()
assert len(os.listdir(tmp_path / "bulk_downloads")) == 1