mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
using temp directory for downloads
This commit is contained in:
parent
b1301e1cbc
commit
bb40196a17
@ -397,7 +397,7 @@ async def download_images_from_list(
|
|||||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||||
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
|
bulk_download_item_id: str = uuid_string() if board_id is None else board_id
|
||||||
board_name: str = (
|
board_name: str = (
|
||||||
"" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name
|
"" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
|
# Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries
|
||||||
|
@ -45,7 +45,7 @@ class BulkDownloadBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_board_name(self, board_id: str) -> str:
|
def get_clean_board_name(self, board_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get the name of the board.
|
Get the name of the board.
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ from .bulk_download_base import BulkDownloadBase
|
|||||||
|
|
||||||
class BulkDownloadService(BulkDownloadBase):
|
class BulkDownloadService(BulkDownloadBase):
|
||||||
__output_folder: Path
|
__output_folder: Path
|
||||||
|
__temp_directory: TemporaryDirectory
|
||||||
__bulk_downloads_folder: Path
|
__bulk_downloads_folder: Path
|
||||||
__event_bus: EventServiceBase
|
__event_bus: EventServiceBase
|
||||||
__invoker: Invoker
|
__invoker: Invoker
|
||||||
@ -35,7 +37,8 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
Initialize the downloader object.
|
Initialize the downloader object.
|
||||||
"""
|
"""
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
self.__temp_directory = TemporaryDirectory(dir=self.__output_folder)
|
||||||
|
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
|
||||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
|
def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None:
|
||||||
@ -57,8 +60,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
image_dtos: list[ImageDTO] = []
|
image_dtos: list[ImageDTO] = []
|
||||||
|
|
||||||
if board_id:
|
if board_id:
|
||||||
board_name = self.get_board_name(board_id)
|
board_name = self.get_clean_board_name(board_id)
|
||||||
board_name = self._clean_string_to_path_safe(board_name)
|
|
||||||
|
|
||||||
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
# -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images
|
||||||
image_dtos = self.__invoker.services.images.get_many(
|
image_dtos = self.__invoker.services.images.get_many(
|
||||||
@ -80,11 +82,11 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
self.__invoker.services.logger.error("Problem bulk downloading images.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_board_name(self, board_id: str) -> str:
|
def get_clean_board_name(self, board_id: str) -> str:
|
||||||
if board_id == "none":
|
if board_id == "none":
|
||||||
return "Uncategorized"
|
return "Uncategorized"
|
||||||
|
|
||||||
return self.__invoker.services.board_records.get(board_id).board_name
|
return self._clean_string_to_path_safe(self.__invoker.services.board_records.get(board_id).board_name)
|
||||||
|
|
||||||
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
|
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
@ -145,11 +147,7 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
def stop(self, *args, **kwargs):
|
def stop(self, *args, **kwargs):
|
||||||
"""Stop the bulk download service and delete the files in the bulk download folder."""
|
"""Stop the bulk download service and delete the files in the bulk download folder."""
|
||||||
# Get all the files in the bulk downloads folder, only .zip files
|
# Get all the files in the bulk downloads folder, only .zip files
|
||||||
files = self.__bulk_downloads_folder.glob("*.zip")
|
self.__temp_directory.cleanup()
|
||||||
|
|
||||||
# Delete all the files
|
|
||||||
for file in files:
|
|
||||||
file.unlink()
|
|
||||||
|
|
||||||
def delete(self, bulk_download_item_name: str) -> None:
|
def delete(self, bulk_download_item_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
@ -86,9 +87,28 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
|||||||
return Invoker(services=mock_services)
|
return Invoker(services=mock_services)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
|
||||||
|
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
|
||||||
|
|
||||||
|
class MockTemporaryDirectory(TemporaryDirectory):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(dir=tmp_path)
|
||||||
|
self.name = tmp_path
|
||||||
|
|
||||||
|
def mock_TemporaryDirectory(*args, **kwargs):
|
||||||
|
return MockTemporaryDirectory()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||||
"""Test get_path when the file exists."""
|
"""Test get_path when the file exists."""
|
||||||
|
|
||||||
|
bulk_download_service = BulkDownloadService(tmp_path)
|
||||||
|
|
||||||
# Create a directory at tmp_path/bulk_downloads
|
# Create a directory at tmp_path/bulk_downloads
|
||||||
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
||||||
test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True)
|
test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@ -97,7 +117,6 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
|||||||
test_file_path: Path = test_bulk_downloads_dir / "test.zip"
|
test_file_path: Path = test_bulk_downloads_dir / "test.zip"
|
||||||
test_file_path.touch()
|
test_file_path.touch()
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
|
||||||
assert bulk_download_service.get_path("test.zip") == str(test_file_path)
|
assert bulk_download_service.get_path("test.zip") == str(test_file_path)
|
||||||
|
|
||||||
|
|
||||||
@ -164,7 +183,6 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
|||||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||||
)
|
)
|
||||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
|
|
||||||
|
|
||||||
def mock_get_many(*args, **kwargs):
|
def mock_get_many(*args, **kwargs):
|
||||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||||
@ -175,6 +193,8 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
|||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([], "none", None)
|
bulk_download_service.handler([], "none", None)
|
||||||
|
|
||||||
|
expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip"
|
||||||
|
|
||||||
assert_handler_success(
|
assert_handler_success(
|
||||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||||
)
|
)
|
||||||
@ -188,7 +208,6 @@ def test_handler_bulk_download__item_id_given(
|
|||||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||||
)
|
)
|
||||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
|
|
||||||
|
|
||||||
def mock_get_many(*args, **kwargs):
|
def mock_get_many(*args, **kwargs):
|
||||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||||
@ -199,6 +218,8 @@ def test_handler_bulk_download__item_id_given(
|
|||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
||||||
|
|
||||||
|
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
|
||||||
|
|
||||||
assert_handler_success(
|
assert_handler_success(
|
||||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||||
)
|
)
|
||||||
@ -341,7 +362,7 @@ def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker)
|
|||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService(tmp_path)
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
board_name = bulk_download_service.get_board_name("12345")
|
board_name = bulk_download_service.get_clean_board_name("12345")
|
||||||
|
|
||||||
assert board_name == expected_board_name
|
assert board_name == expected_board_name
|
||||||
|
|
||||||
@ -353,7 +374,7 @@ def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker):
|
|||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService(tmp_path)
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
board_name = bulk_download_service.get_board_name("none")
|
board_name = bulk_download_service.get_clean_board_name("none")
|
||||||
|
|
||||||
assert board_name == expected_board_name
|
assert board_name == expected_board_name
|
||||||
|
|
||||||
@ -385,6 +406,4 @@ def test_stop(tmp_path: Path):
|
|||||||
|
|
||||||
bulk_download_service.stop()
|
bulk_download_service.stop()
|
||||||
|
|
||||||
assert (tmp_path / "bulk_downloads").exists()
|
assert not (tmp_path / "bulk_downloads").exists()
|
||||||
assert mock_dir.exists()
|
|
||||||
assert len(os.listdir(tmp_path / "bulk_downloads")) == 1
|
|
||||||
|
Loading…
Reference in New Issue
Block a user