mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
removing dependency on an output folder, embrace python temp folder for bulk download
This commit is contained in:
parent
0ab9fe6987
commit
037cac8154
@ -82,7 +82,7 @@ class ApiDependencies:
|
|||||||
board_records = SqliteBoardRecordStorage(db=db)
|
board_records = SqliteBoardRecordStorage(db=db)
|
||||||
boards = BoardService()
|
boards = BoardService()
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
bulk_download = BulkDownloadService(output_folder=f"{output_folder}")
|
bulk_download = BulkDownloadService()
|
||||||
image_records = SqliteImageRecordStorage(db=db)
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
images = ImageService()
|
images = ImageService()
|
||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from typing import Optional
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
@ -18,11 +17,9 @@ class BulkDownloadBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
Create BulkDownloadBase object.
|
Create BulkDownloadBase object.
|
||||||
|
|
||||||
:param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -19,7 +19,6 @@ from .bulk_download_base import BulkDownloadBase
|
|||||||
|
|
||||||
|
|
||||||
class BulkDownloadService(BulkDownloadBase):
|
class BulkDownloadService(BulkDownloadBase):
|
||||||
__output_folder: Path
|
|
||||||
__temp_directory: TemporaryDirectory
|
__temp_directory: TemporaryDirectory
|
||||||
__bulk_downloads_folder: Path
|
__bulk_downloads_folder: Path
|
||||||
__event_bus: EventServiceBase
|
__event_bus: EventServiceBase
|
||||||
@ -29,15 +28,11 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__invoker = invoker
|
self.__invoker = invoker
|
||||||
self.__event_bus = invoker.services.events
|
self.__event_bus = invoker.services.events
|
||||||
|
|
||||||
def __init__(
|
def __init__(self):
|
||||||
self,
|
|
||||||
output_folder: Union[str, Path],
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Initialize the downloader object.
|
Initialize the downloader object.
|
||||||
"""
|
"""
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__temp_directory = TemporaryDirectory()
|
||||||
self.__temp_directory = TemporaryDirectory(dir=self.__output_folder)
|
|
||||||
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
|
self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads"
|
||||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def mock_services(tmp_path: Path) -> InvocationServices:
|
|||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
board_records=SqliteBoardRecordStorage(db=db),
|
board_records=SqliteBoardRecordStorage(db=db),
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
bulk_download=BulkDownloadService(tmp_path),
|
bulk_download=BulkDownloadService(),
|
||||||
configuration=None, # type: ignore
|
configuration=None, # type: ignore
|
||||||
events=None, # type: ignore
|
events=None, # type: ignore
|
||||||
graph_execution_manager=None, # type: ignore
|
graph_execution_manager=None, # type: ignore
|
||||||
|
@ -113,7 +113,7 @@ def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
|
|||||||
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||||
"""Test get_path when the file exists."""
|
"""Test get_path when the file exists."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
|
|
||||||
# Create a directory at tmp_path/bulk_downloads
|
# Create a directory at tmp_path/bulk_downloads
|
||||||
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
||||||
@ -129,7 +129,7 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
|||||||
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
|
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
|
||||||
"""Test get_path when the file does not exist."""
|
"""Test get_path when the file does not exist."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
with pytest.raises(BulkDownloadTargetException):
|
with pytest.raises(BulkDownloadTargetException):
|
||||||
bulk_download_service.get_path("test")
|
bulk_download_service.get_path("test")
|
||||||
|
|
||||||
@ -137,7 +137,7 @@ def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
|
|||||||
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
|
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
|
||||||
"""Test that the bulk_downloads directory is created at start."""
|
"""Test that the bulk_downloads directory is created at start."""
|
||||||
|
|
||||||
BulkDownloadService(tmp_path)
|
BulkDownloadService()
|
||||||
assert (tmp_path / "bulk_downloads").exists()
|
assert (tmp_path / "bulk_downloads").exists()
|
||||||
|
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
|
|||||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||||
)
|
)
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I
|
|||||||
def test_generate_id(monkeypatch: Any):
|
def test_generate_id(monkeypatch: Any):
|
||||||
"""Test that the generate_id method generates a unique id."""
|
"""Test that the generate_id method generates a unique id."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService("test")
|
bulk_download_service = BulkDownloadService()
|
||||||
|
|
||||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ def test_generate_id(monkeypatch: Any):
|
|||||||
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
||||||
"""Test that the generate_id method generates a unique id with a board id."""
|
"""Test that the generate_id method generates a unique id with a board id."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService("test")
|
bulk_download_service = BulkDownloadService()
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
|
|
||||||
def mock_board_get(*args, **kwargs):
|
def mock_board_get(*args, **kwargs):
|
||||||
@ -186,7 +186,7 @@ def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
|||||||
def test_generate_id_with_default_board_id(monkeypatch: Any):
|
def test_generate_id_with_default_board_id(monkeypatch: Any):
|
||||||
"""Test that the generate_id method generates a unique id with a board id."""
|
"""Test that the generate_id method generates a unique id with a board id."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService("test")
|
bulk_download_service = BulkDownloadService()
|
||||||
|
|
||||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||||
|
|
||||||
@ -210,7 +210,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
|
|||||||
|
|
||||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([], "test", None)
|
bulk_download_service.handler([], "test", None)
|
||||||
|
|
||||||
@ -231,7 +231,7 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d
|
|||||||
|
|
||||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([], "none", None)
|
bulk_download_service.handler([], "none", None)
|
||||||
|
|
||||||
@ -256,7 +256,7 @@ def test_handler_bulk_download_item_id_given(
|
|||||||
|
|
||||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
||||||
|
|
||||||
@ -380,7 +380,7 @@ def test_handler_on_generic_exception(
|
|||||||
def execute_handler_test_on_error(
|
def execute_handler_test_on_error(
|
||||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
|
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
|
||||||
):
|
):
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
bulk_download_service.start(mock_invoker)
|
bulk_download_service.start(mock_invoker)
|
||||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||||
|
|
||||||
@ -395,7 +395,7 @@ def execute_handler_test_on_error(
|
|||||||
def test_delete(tmp_path: Path):
|
def test_delete(tmp_path: Path):
|
||||||
"""Test that the delete method removes the bulk download file."""
|
"""Test that the delete method removes the bulk download file."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
|
|
||||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||||
mock_file.write_text("contents")
|
mock_file.write_text("contents")
|
||||||
@ -409,7 +409,7 @@ def test_delete(tmp_path: Path):
|
|||||||
def test_stop(tmp_path: Path):
|
def test_stop(tmp_path: Path):
|
||||||
"""Test that the stop method removes the bulk download file and not any directories."""
|
"""Test that the stop method removes the bulk download file and not any directories."""
|
||||||
|
|
||||||
bulk_download_service = BulkDownloadService(tmp_path)
|
bulk_download_service = BulkDownloadService()
|
||||||
|
|
||||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||||
mock_file.write_text("contents")
|
mock_file.write_text("contents")
|
||||||
|
Loading…
Reference in New Issue
Block a user