mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
removing dependency on an output folder, embrace python temp folder for bulk download
This commit is contained in:
parent
0ab9fe6987
commit
037cac8154
@ -82,7 +82,7 @@ class ApiDependencies:
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
bulk_download = BulkDownloadService(output_folder=f"{output_folder}")
|
||||
bulk_download = BulkDownloadService()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
|
@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
@ -18,11 +17,9 @@ class BulkDownloadBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
def __init__(self):
|
||||
"""
|
||||
Create BulkDownloadBase object.
|
||||
|
||||
:param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -19,7 +19,6 @@ from .bulk_download_base import BulkDownloadBase
|
||||
|
||||
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
__output_folder: Path
|
||||
__temp_directory: TemporaryDirectory
|
||||
__bulk_downloads_folder: Path
|
||||
__event_bus: EventServiceBase
|
||||
@ -29,15 +28,11 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self.__invoker = invoker
|
||||
self.__event_bus = invoker.services.events
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_folder: Union[str, Path],
|
||||
):
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the downloader object.
|
||||
"""
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__temp_directory = TemporaryDirectory(dir=self.__output_folder)
|
||||
self.__temp_directory = TemporaryDirectory()
|
||||
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
|
||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def mock_services(tmp_path: Path) -> InvocationServices:
|
||||
board_images=None, # type: ignore
|
||||
board_records=SqliteBoardRecordStorage(db=db),
|
||||
boards=None, # type: ignore
|
||||
bulk_download=BulkDownloadService(tmp_path),
|
||||
bulk_download=BulkDownloadService(),
|
||||
configuration=None, # type: ignore
|
||||
events=None, # type: ignore
|
||||
graph_execution_manager=None, # type: ignore
|
||||
|
@ -113,7 +113,7 @@ def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
|
||||
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||
"""Test get_path when the file exists."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
# Create a directory at tmp_path/bulk_downloads
|
||||
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
||||
@ -129,7 +129,7 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
|
||||
"""Test get_path when the file does not exist."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
with pytest.raises(BulkDownloadTargetException):
|
||||
bulk_download_service.get_path("test")
|
||||
|
||||
@ -137,7 +137,7 @@ def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
|
||||
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
|
||||
"""Test that the bulk_downloads directory is created at start."""
|
||||
|
||||
BulkDownloadService(tmp_path)
|
||||
BulkDownloadService()
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
|
||||
|
||||
@ -148,7 +148,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||
|
||||
@ -160,7 +160,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
|
||||
def test_generate_id(monkeypatch: Any):
|
||||
"""Test that the generate_id method generates a unique id."""
|
||||
|
||||
bulk_download_service = BulkDownloadService("test")
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||
|
||||
@ -170,7 +170,7 @@ def test_generate_id(monkeypatch: Any):
|
||||
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
||||
"""Test that the generate_id method generates a unique id with a board id."""
|
||||
|
||||
bulk_download_service = BulkDownloadService("test")
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
|
||||
def mock_board_get(*args, **kwargs):
|
||||
@ -186,7 +186,7 @@ def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
||||
def test_generate_id_with_default_board_id(monkeypatch: Any):
|
||||
"""Test that the generate_id method generates a unique id with a board id."""
|
||||
|
||||
bulk_download_service = BulkDownloadService("test")
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||
|
||||
@ -210,7 +210,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "test", None)
|
||||
|
||||
@ -231,7 +231,7 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "none", None)
|
||||
|
||||
@ -256,7 +256,7 @@ def test_handler_bulk_download_item_id_given(
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
||||
|
||||
@ -380,7 +380,7 @@ def test_handler_on_generic_exception(
|
||||
def execute_handler_test_on_error(
|
||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
|
||||
):
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||
|
||||
@ -395,7 +395,7 @@ def execute_handler_test_on_error(
|
||||
def test_delete(tmp_path: Path):
|
||||
"""Test that the delete method removes the bulk download file."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
@ -409,7 +409,7 @@ def test_delete(tmp_path: Path):
|
||||
def test_stop(tmp_path: Path):
|
||||
"""Test that the stop method removes the bulk download file and not any directories."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
Loading…
Reference in New Issue
Block a user