diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 984fd8e267..95407291ec 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -82,7 +82,7 @@ class ApiDependencies: board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) - bulk_download = BulkDownloadService(output_folder=f"{output_folder}") + bulk_download = BulkDownloadService() image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 5199652ad4..d889e2ed0e 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, Union +from typing import Optional from invokeai.app.services.invoker import Invoker @@ -18,11 +17,9 @@ class BulkDownloadBase(ABC): """ @abstractmethod - def __init__(self, output_folder: Union[str, Path]): + def __init__(self): """ Create BulkDownloadBase object. - - :param output_folder: The path to the output folder where the bulk download files can be temporarily stored. """ @abstractmethod diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 406bd7d997..4f5bfb087f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -19,7 +19,6 @@ from .bulk_download_base import BulkDownloadBase class BulkDownloadService(BulkDownloadBase): - __output_folder: Path __temp_directory: TemporaryDirectory __bulk_downloads_folder: Path __event_bus: EventServiceBase @@ -29,15 +28,11 @@ class BulkDownloadService(BulkDownloadBase): self.__invoker = invoker self.__event_bus = invoker.services.events - def __init__( - self, - output_folder: Union[str, Path], - ): + def __init__(self): """ Initialize the downloader object. """ - self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__temp_directory = TemporaryDirectory(dir=self.__output_folder) + self.__temp_directory = TemporaryDirectory() self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index e8521bf132..67297a116f 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -31,7 +31,7 @@ def mock_services(tmp_path: Path) -> InvocationServices: board_images=None, # type: ignore board_records=SqliteBoardRecordStorage(db=db), boards=None, # type: ignore - bulk_download=BulkDownloadService(tmp_path), + bulk_download=BulkDownloadService(), configuration=None, # type: ignore events=None, # type: ignore graph_execution_manager=None, # type: ignore diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index b7480091d9..3e8b7fd2eb 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -113,7 +113,7 @@ def mock_temporary_directory(monkeypatch: Any, tmp_path: Path): def test_get_path_when_file_exists(tmp_path: Path) -> None: """Test get_path when the file exists.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() # Create a directory at tmp_path/bulk_downloads test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" @@ -129,7 +129,7 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None: def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: """Test get_path when the file does not exist.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() with pytest.raises(BulkDownloadTargetException): bulk_download_service.get_path("test") @@ -137,7 +137,7 @@ def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None: """Test that the bulk_downloads directory is created at start.""" - BulkDownloadService(tmp_path) + BulkDownloadService() assert (tmp_path / "bulk_downloads").exists() @@ -148,7 +148,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I tmp_path, monkeypatch, mock_image_dto, mock_invoker ) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, None) @@ -160,7 +160,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I def test_generate_id(monkeypatch: Any): """Test that the generate_id method generates a unique id.""" - bulk_download_service = BulkDownloadService("test") + bulk_download_service = BulkDownloadService() monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") @@ -170,7 +170,7 @@ def test_generate_id(monkeypatch: Any): def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker): """Test that the generate_id method generates a unique id with a board id.""" - bulk_download_service = BulkDownloadService("test") + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) def mock_board_get(*args, **kwargs): @@ -186,7 +186,7 @@ def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker): def test_generate_id_with_default_board_id(monkeypatch: Any): """Test that the generate_id method generates a unique id with a board id.""" - bulk_download_service = BulkDownloadService("test") + bulk_download_service = BulkDownloadService() monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") @@ -210,7 +210,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "test", None) @@ -231,7 +231,7 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "none", None) @@ -256,7 +256,7 @@ def test_handler_bulk_download_item_id_given( monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") @@ -380,7 +380,7 @@ def test_handler_on_generic_exception( def execute_handler_test_on_error( tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception ): - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, None) @@ -395,7 +395,7 @@ def execute_handler_test_on_error( def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" mock_file.write_text("contents") @@ -409,7 +409,7 @@ def test_delete(tmp_path: Path): def test_stop(tmp_path: Path): """Test that the stop method removes the bulk download file and not any directories.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" mock_file.write_text("contents")