removing dependency on an output folder, embrace python temp folder for bulk download

This commit is contained in:
Stefan Tobler 2024-02-17 00:29:05 -05:00 committed by psychedelicious
parent 0ab9fe6987
commit 037cac8154
5 changed files with 19 additions and 27 deletions

View File

@ -82,7 +82,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService(output_folder=f"{output_folder}")
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Optional
from invokeai.app.services.invoker import Invoker
@ -18,11 +17,9 @@ class BulkDownloadBase(ABC):
"""
@abstractmethod
def __init__(self, output_folder: Union[str, Path]):
def __init__(self):
"""
Create BulkDownloadBase object.
:param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
"""
@abstractmethod

View File

@ -19,7 +19,6 @@ from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
__output_folder: Path
__temp_directory: TemporaryDirectory
__bulk_downloads_folder: Path
__event_bus: EventServiceBase
@ -29,15 +28,11 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker = invoker
self.__event_bus = invoker.services.events
def __init__(
self,
output_folder: Union[str, Path],
):
def __init__(self):
"""
Initialize the downloader object.
"""
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__temp_directory = TemporaryDirectory(dir=self.__output_folder)
self.__temp_directory = TemporaryDirectory()
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)

View File

@ -31,7 +31,7 @@ def mock_services(tmp_path: Path) -> InvocationServices:
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(tmp_path),
bulk_download=BulkDownloadService(),
configuration=None, # type: ignore
events=None, # type: ignore
graph_execution_manager=None, # type: ignore

View File

@ -113,7 +113,7 @@ def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
def test_get_path_when_file_exists(tmp_path: Path) -> None:
"""Test get_path when the file exists."""
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
# Create a directory at tmp_path/bulk_downloads
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
@ -129,7 +129,7 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
"""Test get_path when the file does not exist."""
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
with pytest.raises(BulkDownloadTargetException):
bulk_download_service.get_path("test")
@ -137,7 +137,7 @@ def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
"""Test that the bulk_downloads directory is created at start."""
BulkDownloadService(tmp_path)
BulkDownloadService()
assert (tmp_path / "bulk_downloads").exists()
@ -148,7 +148,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
@ -160,7 +160,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
def test_generate_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id."""
bulk_download_service = BulkDownloadService("test")
bulk_download_service = BulkDownloadService()
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
@ -170,7 +170,7 @@ def test_generate_id(monkeypatch: Any):
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
"""Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService("test")
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
def mock_board_get(*args, **kwargs):
@ -186,7 +186,7 @@ def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
def test_generate_id_with_default_board_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService("test")
bulk_download_service = BulkDownloadService()
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
@ -210,7 +210,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "test", None)
@ -231,7 +231,7 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none", None)
@ -256,7 +256,7 @@ def test_handler_bulk_download_item_id_given(
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
@ -380,7 +380,7 @@ def test_handler_on_generic_exception(
def execute_handler_test_on_error(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
):
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
@ -395,7 +395,7 @@ def execute_handler_test_on_error(
def test_delete(tmp_path: Path):
"""Test that the delete method removes the bulk download file."""
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents")
@ -409,7 +409,7 @@ def test_delete(tmp_path: Path):
def test_stop(tmp_path: Path):
"""Test that the stop method removes the bulk download file and not any directories."""
bulk_download_service = BulkDownloadService(tmp_path)
bulk_download_service = BulkDownloadService()
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents")