mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
120 lines
4.2 KiB
Python
120 lines
4.2 KiB
Python
import os
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from fastapi import BackgroundTasks
|
|
from fastapi.testclient import TestClient
|
|
|
|
from invokeai.app.api.dependencies import ApiDependencies
|
|
from invokeai.app.api_app import app
|
|
from invokeai.app.services.board_records.board_records_common import BoardRecord
|
|
from invokeai.app.services.invoker import Invoker
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="module")
|
|
def client(invokeai_root_dir: Path) -> TestClient:
|
|
os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix()
|
|
return TestClient(app)
|
|
|
|
|
|
class MockApiDependencies(ApiDependencies):
|
|
invoker: Invoker
|
|
|
|
def __init__(self, invoker) -> None:
|
|
self.invoker = invoker
|
|
|
|
|
|
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
|
prepare_download_images_test(monkeypatch, mock_invoker)
|
|
|
|
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
|
json_response = response.json()
|
|
assert response.status_code == 202
|
|
assert json_response["bulk_download_item_name"] == "test.zip"
|
|
|
|
|
|
def test_download_images_from_board_id_empty_image_name_list(
|
|
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
|
|
) -> None:
|
|
expected_board_name = "test"
|
|
|
|
def mock_get(*args, **kwargs):
|
|
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
|
|
|
|
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
|
|
prepare_download_images_test(monkeypatch, mock_invoker)
|
|
|
|
response = client.post("/api/v1/images/download", json={"board_id": "test"})
|
|
json_response = response.json()
|
|
assert response.status_code == 202
|
|
assert json_response["bulk_download_item_name"] == "test.zip"
|
|
|
|
|
|
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
|
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
|
monkeypatch.setattr(
|
|
"invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id",
|
|
lambda arg: "test",
|
|
)
|
|
|
|
def mock_add_task(*args, **kwargs):
|
|
return None
|
|
|
|
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
|
|
|
|
|
def test_download_images_with_empty_image_list_and_no_board_id(
|
|
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
|
|
) -> None:
|
|
prepare_download_images_test(monkeypatch, mock_invoker)
|
|
|
|
response = client.post("/api/v1/images/download", json={"image_names": []})
|
|
|
|
assert response.status_code == 400
|
|
|
|
|
|
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
|
mock_file: Path = tmp_path / "test.zip"
|
|
mock_file.write_text("contents")
|
|
|
|
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
|
|
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
|
|
|
def mock_add_task(*args, **kwargs):
|
|
return None
|
|
|
|
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
|
|
|
response = client.get("/api/v1/images/download/test.zip")
|
|
|
|
assert response.status_code == 200
|
|
assert response.content == b"contents"
|
|
|
|
|
|
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
|
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
|
|
|
def mock_add_task(*args, **kwargs):
|
|
return None
|
|
|
|
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
|
|
|
response = client.get("/api/v1/images/download/test.zip")
|
|
|
|
assert response.status_code == 404
|
|
|
|
|
|
def test_get_bulk_download_image_image_deleted_after_response(
|
|
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient
|
|
) -> None:
|
|
mock_file: Path = tmp_path / "test.zip"
|
|
mock_file.write_text("contents")
|
|
|
|
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
|
|
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
|
|
|
client.get("/api/v1/images/download/test.zip")
|
|
|
|
assert not (tmp_path / "test.zip").exists()
|