"""Tests for the bulk-download endpoints exposed by the images router."""

from pathlib import Path
from typing import Any

import pytest
from fastapi import BackgroundTasks
from fastapi.testclient import TestClient

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database

client = TestClient(app)


@pytest.fixture
def mock_services(tmp_path: Path) -> InvocationServices:
    """Build an ``InvocationServices`` with only the services these tests touch.

    Real instances are supplied for ``board_records``, ``bulk_download``,
    ``images`` and ``logger``; every other service is passed as ``None``.
    The database is created with ``use_memory_db=True``.
    """
    configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
    logger = InvokeAILogger.get_logger()
    db = create_mock_sqlite_database(configuration, logger)

    return InvocationServices(
        board_image_records=None,  # type: ignore
        board_images=None,  # type: ignore
        board_records=SqliteBoardRecordStorage(db=db),
        boards=None,  # type: ignore
        bulk_download=BulkDownloadService(tmp_path),
        configuration=None,  # type: ignore
        events=None,  # type: ignore
        graph_execution_manager=None,  # type: ignore
        image_files=None,  # type: ignore
        image_records=None,  # type: ignore
        images=ImageService(),
        invocation_cache=None,  # type: ignore
        latents=None,  # type: ignore
        logger=logger,
        model_manager=None,  # type: ignore
        model_records=None,  # type: ignore
        download_queue=None,  # type: ignore
        model_install=None,  # type: ignore
        names=None,  # type: ignore
        performance_statistics=None,  # type: ignore
        processor=None,  # type: ignore
        queue=None,  # type: ignore
        session_processor=None,  # type: ignore
        session_queue=None,  # type: ignore
        urls=None,  # type: ignore
        workflow_records=None,  # type: ignore
    )


@pytest.fixture
def mock_invoker(mock_services: InvocationServices) -> Invoker:
    """Return an ``Invoker`` wrapping the ``mock_services`` fixture."""
    return Invoker(services=mock_services)


class MockApiDependencies(ApiDependencies):
    """Stand-in for ``ApiDependencies`` that exposes a caller-supplied invoker."""

    invoker: Invoker

    def __init__(self, invoker: Invoker) -> None:
        self.invoker = invoker


def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
    """POSTing a list of image names returns 202 and the patched item name."""
    prepare_download_images_test(monkeypatch, mock_invoker)

    def mock_uuid_string() -> str:
        return "test"

    # You have to patch the function within the module it's being imported into. This is strange, but it works.
    # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
    monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string)

    response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
    json_response = response.json()

    assert response.status_code == 202
    assert json_response["bulk_download_item_name"] == "test"


def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
    """POSTing only a board id returns 202 and an item name of ``<board name>.zip``."""
    expected_board_name = "test"

    def mock_get(*args: Any, **kwargs: Any) -> BoardRecord:
        return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")

    monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
    prepare_download_images_test(monkeypatch, mock_invoker)

    response = client.post("/api/v1/images/download", json={"board_id": "test"})
    json_response = response.json()

    assert response.status_code == 202
    # Derive the expectation from the mocked board name so the two cannot drift apart.
    assert json_response["bulk_download_item_name"] == f"{expected_board_name}.zip"


def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
    """Apply the patches shared by the download tests.

    Replaces the ``ApiDependencies`` name inside the images router module with a
    ``MockApiDependencies`` wrapping ``mock_invoker``, and turns
    ``BackgroundTasks.add_task`` into a no-op so scheduled tasks never run.
    """
    monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))

    def mock_add_task(*args: Any, **kwargs: Any) -> None:
        return None

    monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)


def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: Invoker) -> None:
    """POSTing an empty image list without a board id is rejected with 400."""
    prepare_download_images_test(monkeypatch, mock_invoker)

    response = client.post("/api/v1/images/download", json={"image_names": []})

    assert response.status_code == 400


def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None:
    """GETting an existing bulk-download item returns 200 and the file contents."""
    mock_file: Path = tmp_path / "test.zip"
    mock_file.write_text("contents")

    monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
    # Same ApiDependencies / add_task patches as the POST tests.
    prepare_download_images_test(monkeypatch, mock_invoker)

    response = client.get("/api/v1/images/download/test.zip")

    assert response.status_code == 200
    assert response.content == b"contents"


def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None:
    """GETting a bulk-download item that was never created returns 404."""
    prepare_download_images_test(monkeypatch, mock_invoker)

    response = client.get("/api/v1/images/download/test.zip")

    assert response.status_code == 404


def test_get_bulk_download_image_image_deleted_after_response(
    monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path
) -> None:
    """After a successful GET, the bulk-download file no longer exists on disk."""
    mock_file: Path = tmp_path / "test.zip"
    mock_file.write_text("contents")

    monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
    # NOTE: BackgroundTasks.add_task is deliberately left unpatched here (so
    # prepare_download_images_test is not used) - presumably the router schedules
    # the file cleanup as a background task, which must actually run for this test.
    monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))

    client.get("/api/v1/images/download/test.zip")

    assert not (tmp_path / "test.zip").exists()