removing dependency on an output folder, embrace python temp folder for bulk download

This commit is contained in:
Stefan Tobler 2024-02-17 00:29:05 -05:00 committed by psychedelicious
parent 0ab9fe6987
commit 037cac8154
5 changed files with 19 additions and 27 deletions

View File

@ -82,7 +82,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db) board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService() boards = BoardService()
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService(output_folder=f"{output_folder}") bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db) image_records = SqliteImageRecordStorage(db=db)
images = ImageService() images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from typing import Optional
from typing import Optional, Union
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
@ -18,11 +17,9 @@ class BulkDownloadBase(ABC):
""" """
@abstractmethod @abstractmethod
def __init__(self, output_folder: Union[str, Path]): def __init__(self):
""" """
Create BulkDownloadBase object. Create BulkDownloadBase object.
:param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
""" """
@abstractmethod @abstractmethod

View File

@ -19,7 +19,6 @@ from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase): class BulkDownloadService(BulkDownloadBase):
__output_folder: Path
__temp_directory: TemporaryDirectory __temp_directory: TemporaryDirectory
__bulk_downloads_folder: Path __bulk_downloads_folder: Path
__event_bus: EventServiceBase __event_bus: EventServiceBase
@ -29,15 +28,11 @@ class BulkDownloadService(BulkDownloadBase):
self.__invoker = invoker self.__invoker = invoker
self.__event_bus = invoker.services.events self.__event_bus = invoker.services.events
def __init__( def __init__(self):
self,
output_folder: Union[str, Path],
):
""" """
Initialize the downloader object. Initialize the downloader object.
""" """
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__temp_directory = TemporaryDirectory()
self.__temp_directory = TemporaryDirectory(dir=self.__output_folder)
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)

View File

@ -31,7 +31,7 @@ def mock_services(tmp_path: Path) -> InvocationServices:
board_images=None, # type: ignore board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db), board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore boards=None, # type: ignore
bulk_download=BulkDownloadService(tmp_path), bulk_download=BulkDownloadService(),
configuration=None, # type: ignore configuration=None, # type: ignore
events=None, # type: ignore events=None, # type: ignore
graph_execution_manager=None, # type: ignore graph_execution_manager=None, # type: ignore

View File

@ -113,7 +113,7 @@ def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
def test_get_path_when_file_exists(tmp_path: Path) -> None: def test_get_path_when_file_exists(tmp_path: Path) -> None:
"""Test get_path when the file exists.""" """Test get_path when the file exists."""
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
# Create a directory at tmp_path/bulk_downloads # Create a directory at tmp_path/bulk_downloads
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
@ -129,7 +129,7 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
"""Test get_path when the file does not exist.""" """Test get_path when the file does not exist."""
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
with pytest.raises(BulkDownloadTargetException): with pytest.raises(BulkDownloadTargetException):
bulk_download_service.get_path("test") bulk_download_service.get_path("test")
@ -137,7 +137,7 @@ def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None: def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
"""Test that the bulk_downloads directory is created at start.""" """Test that the bulk_downloads directory is created at start."""
BulkDownloadService(tmp_path) BulkDownloadService()
assert (tmp_path / "bulk_downloads").exists() assert (tmp_path / "bulk_downloads").exists()
@ -148,7 +148,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
tmp_path, monkeypatch, mock_image_dto, mock_invoker tmp_path, monkeypatch, mock_image_dto, mock_invoker
) )
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None) bulk_download_service.handler([mock_image_dto.image_name], None, None)
@ -160,7 +160,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
def test_generate_id(monkeypatch: Any): def test_generate_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id.""" """Test that the generate_id method generates a unique id."""
bulk_download_service = BulkDownloadService("test") bulk_download_service = BulkDownloadService()
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
@ -170,7 +170,7 @@ def test_generate_id(monkeypatch: Any):
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker): def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
"""Test that the generate_id method generates a unique id with a board id.""" """Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService("test") bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
def mock_board_get(*args, **kwargs): def mock_board_get(*args, **kwargs):
@ -186,7 +186,7 @@ def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
def test_generate_id_with_default_board_id(monkeypatch: Any): def test_generate_id_with_default_board_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id with a board id.""" """Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService("test") bulk_download_service = BulkDownloadService()
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
@ -210,7 +210,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "test", None) bulk_download_service.handler([], "test", None)
@ -231,7 +231,7 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none", None) bulk_download_service.handler([], "none", None)
@ -256,7 +256,7 @@ def test_handler_bulk_download_item_id_given(
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
@ -380,7 +380,7 @@ def test_handler_on_generic_exception(
def execute_handler_test_on_error( def execute_handler_test_on_error(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
): ):
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None) bulk_download_service.handler([mock_image_dto.image_name], None, None)
@ -395,7 +395,7 @@ def execute_handler_test_on_error(
def test_delete(tmp_path: Path): def test_delete(tmp_path: Path):
"""Test that the delete method removes the bulk download file.""" """Test that the delete method removes the bulk download file."""
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents") mock_file.write_text("contents")
@ -409,7 +409,7 @@ def test_delete(tmp_path: Path):
def test_stop(tmp_path: Path): def test_stop(tmp_path: Path):
"""Test that the stop method removes the bulk download file and not any directories.""" """Test that the stop method removes the bulk download file and not any directories."""
bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service = BulkDownloadService()
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents") mock_file.write_text("contents")