Merge branch 'ryan/remove-attention-map-saving' into ryan/regional-conditioning

This commit is contained in:
Ryan Dick
2024-03-01 11:03:04 -05:00
740 changed files with 24428 additions and 31726 deletions

View File

@ -1,174 +0,0 @@
import logging
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
ErrorInvocation,
PromptTestInvocation,
TestEventService,
TextToImageTestInvocation,
create_edge,
wait_until,
)
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g.add_node(TextToImageTestInvocation(id="2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def graph_with_subgraph():
sub_g = Graph()
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
sub_g.add_node(TextToImageTestInvocation(id="2"))
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
g = Graph()
g.add_node(GraphInvocation(id="1", graph=sub_g))
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=None, # type: ignore
boards=None, # type: ignore
configuration=configuration,
events=TestEventService(),
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
latents=None, # type: ignore
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),
queue=MemoryInvocationQueue(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
def test_can_create_graph_state(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
)
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
invoke_all=True,
)
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete()
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1"))
mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
invoke_all=True,
)
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.has_error()
assert g.is_complete()
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))

View File

@ -0,0 +1,119 @@
import os
from pathlib import Path
from typing import Any
import pytest
from fastapi import BackgroundTasks
from fastapi.testclient import TestClient
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.invoker import Invoker
@pytest.fixture(autouse=True, scope="module")
def client(invokeai_root_dir: Path) -> TestClient:
os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix()
return TestClient(app)
class MockApiDependencies(ApiDependencies):
invoker: Invoker
def __init__(self, invoker) -> None:
self.invoker = invoker
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
json_response = response.json()
assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test.zip"
def test_download_images_from_board_id_empty_image_name_list(
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
) -> None:
expected_board_name = "test"
def mock_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"board_id": "test"})
json_response = response.json()
assert response.status_code == 202
assert json_response["bulk_download_item_name"] == "test.zip"
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
monkeypatch.setattr(
"invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id",
lambda arg: "test",
)
def mock_add_task(*args, **kwargs):
return None
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
def test_download_images_with_empty_image_list_and_no_board_id(
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
) -> None:
prepare_download_images_test(monkeypatch, mock_invoker)
response = client.post("/api/v1/images/download", json={"image_names": []})
assert response.status_code == 400
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
mock_file: Path = tmp_path / "test.zip"
mock_file.write_text("contents")
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
def mock_add_task(*args, **kwargs):
return None
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
response = client.get("/api/v1/images/download/test.zip")
assert response.status_code == 200
assert response.content == b"contents"
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
def mock_add_task(*args, **kwargs):
return None
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
response = client.get("/api/v1/images/download/test.zip")
assert response.status_code == 404
def test_get_bulk_download_image_image_deleted_after_response(
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient
) -> None:
mock_file: Path = tmp_path / "test.zip"
mock_file.write_text("contents")
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
client.get("/api/v1/images/download/test.zip")
assert not (tmp_path / "test.zip").exists()

View File

@ -0,0 +1,379 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from zipfile import ZipFile
import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordNotFoundException,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from tests.test_nodes import TestEventService
@pytest.fixture
def mock_image_dto() -> ImageDTO:
"""Create a mock ImageDTO."""
return ImageDTO(
image_name="mock_image.png",
board_id="12345",
image_url="None",
width=100,
height=100,
thumbnail_url="None",
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
created_at="None",
updated_at="None",
starred=False,
has_workflow=False,
is_intermediate=False,
)
@pytest.fixture(autouse=True)
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
class MockTemporaryDirectory(TemporaryDirectory):
def __init__(self):
super().__init__(dir=tmp_path)
self.name = tmp_path
def mock_TemporaryDirectory(*args, **kwargs):
return MockTemporaryDirectory()
monkeypatch.setattr(
"invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory
)
def test_get_path_when_file_exists(tmp_path: Path) -> None:
"""Test get_path when the file exists."""
bulk_download_service = BulkDownloadService()
# Create a directory at tmp_path/bulk_downloads
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True)
# Create a file at tmp_path/bulk_downloads/test.zip
test_file_path: Path = test_bulk_downloads_dir / "test.zip"
test_file_path.touch()
assert bulk_download_service.get_path("test.zip") == str(test_file_path)
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
"""Test get_path when the file does not exist."""
bulk_download_service = BulkDownloadService()
with pytest.raises(BulkDownloadTargetException):
bulk_download_service.get_path("test")
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
"""Test that the bulk_downloads directory is created at start."""
BulkDownloadService()
assert (tmp_path / "bulk_downloads").exists()
def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler creates the zip file correctly when given a list of image names."""
expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
)
def test_generate_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id."""
bulk_download_service = BulkDownloadService()
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
assert bulk_download_service.generate_item_id(None) == "test"
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
"""Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
def mock_board_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
assert bulk_download_service.generate_item_id("12345") == "test_board_name_test"
def test_generate_id_with_default_board_id(monkeypatch: Any):
"""Test that the generate_id method generates a unique id with a board id."""
bulk_download_service = BulkDownloadService()
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
assert bulk_download_service.generate_item_id("none") == "Uncategorized_test"
def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler creates the zip file correctly when given a board id."""
expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
def mock_board_get(*args, **kwargs):
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "test", None)
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
)
def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler creates the zip file correctly when given a board id."""
_, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([], "none", None)
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip"
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
)
def test_handler_bulk_download_item_id_given(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
):
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
_, expected_image_path, mock_image_contents = prepare_handler_test(
tmp_path, monkeypatch, mock_image_dto, mock_invoker
)
def mock_get_many(*args, **kwargs):
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
assert_handler_success(
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
)
def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Prepare the test for the handler tests."""
def mock_uuid_string():
return "test"
# You have to patch the function within the module it's being imported into. This is strange, but it works.
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", mock_uuid_string)
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip"
expected_image_path: Path = (
tmp_path / "bulk_downloads" / mock_image_dto.image_category.value / mock_image_dto.image_name
)
# Mock the get_dto method so that when the image dto needs to be retrieved it is returned
def mock_get_dto(*args, **kwargs):
return mock_image_dto
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
# This is used when preparing all images for a given board
def mock_get_all_board_image_names_for_board(*args, **kwargs):
return [mock_image_dto.image_name]
monkeypatch.setattr(
mock_invoker.services.board_image_records,
"get_all_board_image_names_for_board",
mock_get_all_board_image_names_for_board,
)
# Create a mock image file so that the contents of the zip file are not empty
mock_image_path: Path = tmp_path / mock_image_dto.image_name
mock_image_contents: str = "Totally an image"
mock_image_path.write_text(mock_image_contents)
def mock_get_path(*args, **kwargs):
return str(mock_image_path)
monkeypatch.setattr(mock_invoker.services.images, "get_path", mock_get_path)
return expected_zip_path, expected_image_path, mock_image_contents
def assert_handler_success(
expected_zip_path: Path,
expected_image_path: Path,
mock_image_contents: str,
tmp_path: Path,
event_bus: TestEventService,
):
"""Assert that the handler was successful."""
# Check that the zip file was created
assert expected_zip_path.exists()
assert expected_zip_path.is_file()
assert expected_zip_path.stat().st_size > 0
# Check that the zip contents are expected
with ZipFile(expected_zip_path, "r") as zip_file:
zip_file.extractall(tmp_path / "bulk_downloads")
assert expected_image_path.exists()
assert expected_image_path.is_file()
assert expected_image_path.stat().st_size > 0
assert expected_image_path.read_text() == mock_image_contents
# Check that the correct events were emitted
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_completed"
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler emits an error event when the image is not found."""
exception: Exception = ImageRecordNotFoundException("Image not found")
def mock_get_dto(*args, **kwargs):
raise exception
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
def test_handler_on_board_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
"""Test that the handler emits an error event when the image is not found."""
exception: Exception = BoardRecordNotFoundException("Image not found")
def mock_get_board_name(*args, **kwargs):
raise exception
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
def test_handler_on_generic_exception(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
):
"""Test that the handler emits an error event when the image is not found."""
exception: Exception = Exception("Generic exception")
def mock_get_board_name(*args, **kwargs):
raise exception
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
with pytest.raises(Exception): # noqa: B017
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == exception.__str__()
def execute_handler_test_on_error(
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
):
bulk_download_service = BulkDownloadService()
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == error.__str__()
def test_delete(tmp_path: Path):
"""Test that the delete method removes the bulk download file."""
bulk_download_service = BulkDownloadService()
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents")
bulk_download_service.delete("test.zip")
assert (tmp_path / "bulk_downloads").exists()
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
def test_stop(tmp_path: Path):
"""Test that the stop method removes the bulk download file and not any directories."""
bulk_download_service = BulkDownloadService()
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
mock_file.write_text("contents")
mock_dir: Path = tmp_path / "bulk_downloads" / "test"
mock_dir.mkdir(parents=True, exist_ok=True)
bulk_download_service.stop()
assert not (tmp_path / "bulk_downloads").exists()

View File

@ -1,20 +1,19 @@
"""Test the queued download facility"""
import re
import time
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False
TestAdapter.__test__ = False # type: ignore
@pytest.fixture
@ -52,28 +51,6 @@ def session() -> Session:
return sess
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
# A dummy event service for testing event issuing
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set()
@ -125,7 +102,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
@ -190,8 +167,9 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue.stop()
@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
@ -205,6 +183,9 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,

View File

@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
)
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system
@ -31,6 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
assert key is not None
assert key != "<NOKEY>"
assert len(key) == 32
@ -58,12 +59,13 @@ def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase,
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash", "key": "xyzzy"}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash"
assert model_record.key == "xyzzy"
def test_install(
@ -129,6 +131,7 @@ def test_background_install(
model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path == destination
assert model_record.key != "<NOKEY>"
assert Path(mm2_app_config.models_dir / model_record.path).exists()
# see if metadata was properly passed through
@ -196,6 +199,7 @@ def test_delete_register(
store.get_model(key)
@pytest.mark.timeout(timeout=20, method="thread")
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
@ -221,6 +225,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
@ -256,4 +261,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
assert job.error_type == "HTTPError"
assert job.error
assert "NOT FOUND" in job.error
assert "Traceback" in job.error
assert job.error_traceback.startswith("Traceback")

View File

@ -8,6 +8,7 @@ from typing import Any
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordOrderBy,
@ -25,7 +26,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.metadata import BaseMetadata
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@ -36,7 +37,7 @@ def store(
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def example_config() -> TextualInversionConfig:

View File

@ -1,7 +1,7 @@
import pytest
import torch
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.util.test_utils import install_and_load_model

View File

@ -0,0 +1,30 @@
"""
Test model loading
"""
from pathlib import Path
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
store = mm2_model_manager.store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file)
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
assert loaded_model.config.key == loaded_model_2.config.key
loaded_model_3 = mm2_model_manager.load_model_by_attr(
model_name=loaded_model.config.name,
model_type=loaded_model.config.type,
base_model=loaded_model.config.base,
)
assert loaded_model.config.key == loaded_model_3.config.key

View File

@ -2,27 +2,32 @@
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
@ -85,15 +90,77 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
app_config = InvokeAIAppConfig(
root=mm2_root_dir,
models_dir=mm2_root_dir / "models",
log_level="info",
)
return app_config
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
def stop_queue() -> None:
download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram_cache_size,
max_vram_cache_size=mm2_app_config.vram_cache_size,
)
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
return ModelLoadService(
app_config=mm2_app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
)
@pytest.fixture
def mm2_installer(
mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
installer = ModelInstallService(
app_config=mm2_app_config,
record_store=store,
download_queue=mm2_download_queue,
event_bus=events,
session=mm2_session,
)
installer.start()
def stop_installer() -> None:
installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
request.addfinalizer(stop_installer)
return installer
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db)
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
# add five simple config records to the database
raw1 = {
"path": "/tmp/foo1",
@ -152,15 +219,16 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
db = mm2_record_store._db # to ensure we are sharing the same database
return ModelMetadataStore(db)
def mm2_model_manager(
mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
) -> ModelManagerServiceBase:
return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
"""This fixtures defines a series of mock URLs for testing download and installation."""
sess = TestSession()
sess: Session = TestSession()
sess.mount(
"https://test.com/missing_model.safetensors",
TestAdapter(
@ -240,26 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
),
)
return sess
@pytest.fixture
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db)
metadata_store = ModelMetadataStore(db)
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
installer = ModelInstallService(
app_config=mm2_app_config,
record_store=store,
download_queue=download_queue,
metadata_store=metadata_store,
event_bus=events,
session=mm2_session,
)
installer.start()
return installer

View File

@ -1,6 +1,7 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path
@ -8,6 +9,7 @@ import pytest
from pydantic.networks import HttpUrl
from requests.sessions import Session
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
@ -15,14 +17,13 @@ from invokeai.backend.model_manager.metadata import (
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
ModelMetadataStore,
UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
tags = {"text-to-image", "diffusers"}
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
@ -40,7 +41,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
assert mm2_metadata_store.list_tags() == tags
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
@ -57,7 +58,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
assert input_metadata == output_metadata
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
@ -133,7 +134,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}

View File

@ -1,6 +1,6 @@
import pytest
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2
def test_libc_util_mallinfo2():

View File

@ -5,8 +5,8 @@
import pytest
import torch
from invokeai.backend.model_management.lora import ModelPatcher
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
@pytest.mark.parametrize(

View File

@ -1,7 +1,7 @@
import pytest
from invokeai.backend.model_management.libc_util import Struct_mallinfo2
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
def test_memory_snapshot_capture():
@ -26,6 +26,7 @@ snapshots = [
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
print(msg)
expected_lines = 0
if snapshot_1 is not None and snapshot_2 is not None:

View File

@ -1,7 +1,7 @@
import pytest
import torch
from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init
@pytest.mark.parametrize(

View File

@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]:
"text_encoder/model.onnx",
"text_encoder_2/config.json",
"text_encoder_2/model.onnx",
"text_encoder_2/model.onnx_data",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]:
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/model.onnx",
"unet/model.onnx_data",
"vae_decoder/config.json",
"vae_decoder/model.onnx",
"vae_encoder/config.json",

View File

@ -1,6 +1,7 @@
"""
Test interaction of logging with configuration system.
"""
import io
import logging
import re

View File

@ -4,4 +4,67 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
import logging
import shutil
from pathlib import Path
import pytest
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
board_image_records=SqliteBoardImageRecordStorage(db=db),
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(),
configuration=configuration,
events=TestEventService(),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
@pytest.fixture(scope="module")
def invokeai_root_dir(tmp_path_factory) -> Path:
root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root"
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
shutil.copytree(root_template, temp_dir)
return temp_dir

View File

@ -1,27 +1,18 @@
import logging
from typing import Optional
from unittest.mock import Mock
import pytest
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
PromptCollectionTestInvocation,
PromptTestInvocation,
TestEventService,
TextToImageTestInvocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import (
CollectInvocation,
Graph,
@ -29,11 +20,11 @@ from invokeai.app.services.shared.graph import (
IterateInvocation,
)
from .test_invoker import create_edge
from .test_nodes import create_edge
@pytest.fixture
def simple_graph():
def simple_graph() -> Graph:
g = Graph()
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g.add_node(TextToImageTestInvocation(id="2"))
@ -41,69 +32,23 @@ def simple_graph():
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=None, # type: ignore
boards=None, # type: ignore
configuration=configuration,
events=TestEventService(),
graph_execution_manager=graph_execution_manager,
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
latents=None, # type: ignore
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),
queue=MemoryInvocationQueue(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
n = g.next()
if n is None:
return (None, None)
print(f"invoking {n.id}: {type(n)}")
o = n.invoke(
InvocationContext(
queue_batch_id="1",
queue_item_id=1,
queue_id=DEFAULT_QUEUE_ID,
services=services,
graph_execution_state_id="1",
workflow=None,
)
)
o = n.invoke(Mock(InvocationContext))
g.complete(n.id, o)
return (n, o)
def test_graph_state_executes_in_order(simple_graph, mock_services):
def test_graph_state_executes_in_order(simple_graph: Graph):
g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n1 = invoke_next(g)
n2 = invoke_next(g)
n3 = g.next()
assert g.prepared_source_mapping[n1[0].id] == "1"
@ -113,18 +58,18 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
assert n2[0].prompt == n1[0].prompt
def test_graph_is_complete(simple_graph, mock_services):
def test_graph_is_complete(simple_graph: Graph):
g = GraphExecutionState(graph=simple_graph)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g)
_ = invoke_next(g)
_ = g.next()
assert g.is_complete()
def test_graph_is_not_complete(simple_graph, mock_services):
def test_graph_is_not_complete(simple_graph: Graph):
g = GraphExecutionState(graph=simple_graph)
_ = invoke_next(g, mock_services)
_ = invoke_next(g)
_ = g.next()
assert not g.is_complete()
@ -133,7 +78,7 @@ def test_graph_is_not_complete(simple_graph, mock_services):
# TODO: test completion with iterators/subgraphs
def test_graph_state_expands_iterator(mock_services):
def test_graph_state_expands_iterator():
graph = Graph()
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
graph.add_node(IterateInvocation(id="1"))
@ -145,7 +90,7 @@ def test_graph_state_expands_iterator(mock_services):
g = GraphExecutionState(graph=graph)
while not g.is_complete():
invoke_next(g, mock_services)
invoke_next(g)
prepared_add_nodes = g.source_prepared_mapping["3"]
results = {g.results[n].value for n in prepared_add_nodes}
@ -153,7 +98,7 @@ def test_graph_state_expands_iterator(mock_services):
assert results == expected
def test_graph_state_collects(mock_services):
def test_graph_state_collects():
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
@ -165,19 +110,19 @@ def test_graph_state_collects(mock_services):
graph.add_edge(create_edge("3", "prompt", "4", "item"))
g = GraphExecutionState(graph=graph)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
n6 = invoke_next(g, mock_services)
_ = invoke_next(g)
_ = invoke_next(g)
_ = invoke_next(g)
_ = invoke_next(g)
_ = invoke_next(g)
n6 = invoke_next(g)
assert isinstance(n6[0], CollectInvocation)
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
def test_graph_state_prepares_eagerly(mock_services):
def test_graph_state_prepares_eagerly():
"""Tests that all prepareable nodes are prepared"""
graph = Graph()
@ -206,7 +151,7 @@ def test_graph_state_prepares_eagerly(mock_services):
assert "prompt_iterated" not in g.source_prepared_mapping
def test_graph_executes_depth_first(mock_services):
def test_graph_executes_depth_first():
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
graph = Graph()
@ -220,14 +165,14 @@ def test_graph_executes_depth_first(mock_services):
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
g = GraphExecutionState(graph=graph)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g)
_ = invoke_next(g)
_ = invoke_next(g)
_ = invoke_next(g)
# Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results.
def get_completed_count(g, id):
def get_completed_count(g: GraphExecutionState, id: str):
ids = list(g.source_prepared_mapping[id])
completed_ids = [i for i in g.executed if i in ids]
return len(completed_ids)
@ -236,17 +181,17 @@ def test_graph_executes_depth_first(mock_services):
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 0
_ = invoke_next(g, mock_services)
_ = invoke_next(g)
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 1
_ = invoke_next(g, mock_services)
_ = invoke_next(g)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 1
_ = invoke_next(g, mock_services)
_ = invoke_next(g)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 2

View File

@ -1,47 +0,0 @@
from pathlib import Path
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override
def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
assert not is_override

View File

@ -2,8 +2,8 @@ from pathlib import Path
import pytest
from invokeai.backend import BaseModelType
from invokeai.backend.model_management.model_probe import VaeFolderProbe
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
from invokeai.backend.model_manager.probe import VaeFolderProbe
@pytest.mark.parametrize(
@ -20,3 +20,11 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
probe = VaeFolderProbe(sd1_vae_path)
base_type = probe.get_base_type()
assert base_type == expected_type
repo_variant = probe.get_repo_variant()
assert repo_variant == ModelRepoVariant.DEFAULT
def test_repo_variant(datadir: Path):
probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
repo_variant = probe.get_repo_variant()
assert repo_variant == ModelRepoVariant.FP16

View File

@ -0,0 +1,37 @@
{
"_class_name": "AutoencoderTiny",
"_diffusers_version": "0.20.0.dev0",
"act_fn": "relu",
"decoder_block_out_channels": [
64,
64,
64,
64
],
"encoder_block_out_channels": [
64,
64,
64,
64
],
"force_upcast": false,
"in_channels": 3,
"latent_channels": 4,
"latent_magnitude": 3,
"latent_shift": 0.5,
"num_decoder_blocks": [
3,
3,
3,
1
],
"num_encoder_blocks": [
1,
3,
3,
3
],
"out_channels": 3,
"scaling_factor": 1.0,
"upsampling_scaling_factor": 2
}

View File

@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import (
FloatCollectionInvocation,
FloatInvocation,
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
StringInvocation,
)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.default_graphs import create_text_to_image
from invokeai.app.services.shared.graph import (
CollectInvocation,
Edge,
EdgeConnection,
Graph,
GraphInvocation,
InvalidEdgeError,
IterateInvocation,
NodeAlreadyInGraphError,
@ -425,21 +421,6 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
assert g.is_valid() is False
def test_graph_invalid_if_subgraph_invalid():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1", "image", "2", "image")
n1.graph.edges.append(e1)
g.nodes[n1.id] = n1
assert g.is_valid() is False
def test_graph_invalid_if_has_cycle():
g = Graph()
n1 = ESRGANInvocation(id="1")
@ -466,110 +447,6 @@ def test_graph_invalid_with_invalid_connection():
assert g.is_valid() is False
# TODO: Subgraph operations
def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
result = g.get_node("1.1")
assert result is not None
assert result.id == "1"
assert result == n1_1
def test_graph_expands_subgraph():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1_1 = AddInvocation(id="1", a=1, b=2)
n1_2 = SubtractInvocation(id="2", b=3)
n1.graph.add_node(n1_1)
n1.graph.add_node(n1_2)
n1.graph.add_edge(create_edge("1", "value", "2", "a"))
g.add_node(n1)
n2 = AddInvocation(id="2", b=5)
g.add_node(n2)
g.add_edge(create_edge("1.2", "value", "2", "a"))
dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.1", "1.2", "2"}
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
def test_graph_subgraph_t2i():
g = Graph()
n1 = GraphInvocation(id="1")
# Get text to image default graph
lg = create_text_to_image()
n1.graph = lg.graph
g.add_node(n1)
n2 = IntegerInvocation(id="2", value=512)
n3 = IntegerInvocation(id="3", value=256)
g.add_node(n2)
g.add_node(n3)
g.add_edge(create_edge("2", "value", "1.width", "value"))
g.add_edge(create_edge("3", "value", "1.height", "value"))
n4 = ShowImageInvocation(id="4")
g.add_node(n4)
g.add_edge(create_edge("1.8", "image", "4", "image"))
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
print(expected_edges)
print(list(dg.edges))
assert set(dg.edges) == set(expected_edges)
def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
with pytest.raises(NodeNotFoundError):
_ = g.get_node("1.2")
def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
n2 = ESRGANInvocation(id="2")
g.add_node(n2)
with pytest.raises(NodeNotFoundError):
_ = g.get_node("2.1")
def test_graph_gets_networkx_graph():
g = Graph()
n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")

View File

@ -1,15 +1,16 @@
from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
# Define test invocations before importing anything that uses invocations
@ -116,25 +117,22 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
)
class TestEvent:
event_name: str
payload: Any
class TestEvent(BaseModel):
__test__ = False # not a pytest test case
def __init__(self, event_name: str, payload: Any):
self.event_name = event_name
self.payload = payload
event_name: str
payload: Any
class TestEventService(EventServiceBase):
events: list
__test__ = False # not a pytest test case
def __init__(self):
super().__init__()
self.events = []
self.events: list[TestEvent] = []
def dispatch(self, event_name: str, payload: Any) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
pass

View File

@ -0,0 +1,172 @@
import tempfile
from dataclasses import dataclass
from pathlib import Path
import pytest
import torch
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@dataclass
class MockDataclass:
foo: str
def count_files(path: Path):
return len(list(path.iterdir()))
@pytest.fixture
def obj_serializer(tmp_path: Path):
return ObjectSerializerDisk[MockDataclass](tmp_path)
@pytest.fixture
def fwd_cache(tmp_path: Path):
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
def test_obj_serializer_disk_initializes(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
assert obj_serializer._output_dir == tmp_path
def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
assert Path(obj_serializer._output_dir, obj_1_name).exists()
obj_2 = MockDataclass(foo="baz")
obj_2_name = obj_serializer.save(obj_2)
assert Path(obj_serializer._output_dir, obj_2_name).exists()
def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
assert obj_serializer.load(obj_1_name).foo == "bar"
obj_2 = MockDataclass(foo="baz")
obj_2_name = obj_serializer.save(obj_2)
assert obj_serializer.load(obj_2_name).foo == "baz"
with pytest.raises(ObjectNotFoundError):
obj_serializer.load("nonexistent_object_name")
def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_2 = MockDataclass(foo="bar")
obj_2_name = obj_serializer.save(obj_2)
obj_serializer.delete(obj_1_name)
assert not Path(obj_serializer._output_dir, obj_1_name).exists()
assert Path(obj_serializer._output_dir, obj_2_name).exists()
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
assert obj_serializer._base_output_dir == tmp_path
assert obj_serializer._output_dir != tmp_path
assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name)
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
tempdir_path = obj_serializer._output_dir
del obj_serializer
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
tempdir_path = obj_serializer._output_dir
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
assert Path(obj_serializer._output_dir, obj_1_name).exists()
assert not Path(tmp_path, obj_1_name).exists()
def test_obj_serializer_disk_different_types(tmp_path: Path):
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer_1.save(obj_1)
obj_1_loaded = obj_serializer_1.load(obj_1_name)
assert obj_serializer_1._obj_class_name == "MockDataclass"
assert isinstance(obj_1_loaded, MockDataclass)
assert obj_1_loaded.foo == "bar"
assert obj_1_name.startswith("MockDataclass_")
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
obj_2_name = obj_serializer_2.save(9001)
assert obj_serializer_2._obj_class_name == "int"
assert obj_serializer_2.load(obj_2_name) == 9001
assert obj_2_name.startswith("int_")
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
obj_3_name = obj_serializer_3.save("foo")
assert obj_serializer_3._obj_class_name == "str"
assert obj_serializer_3.load(obj_3_name) == "foo"
assert obj_3_name.startswith("str_")
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer_4.load(obj_4_name)
assert obj_serializer_4._obj_class_name == "Tensor"
assert isinstance(obj_4_loaded, torch.Tensor)
assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
assert obj_4_name.startswith("Tensor_")
def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
fwd_cache = ObjectSerializerForwardCache(obj_serializer)
assert fwd_cache._underlying_storage == obj_serializer
def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
obj = MockDataclass(foo="bar")
obj_name = fwd_cache.save(obj)
obj_loaded = fwd_cache.load(obj_name)
obj_underlying = fwd_cache._underlying_storage.load(obj_name)
assert obj_loaded == obj_underlying
assert obj_loaded.foo == "bar"
def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
obj_1 = MockDataclass(foo="bar")
obj_1_name = fwd_cache.save(obj_1)
obj_2 = MockDataclass(foo="baz")
obj_2_name = fwd_cache.save(obj_2)
obj_3 = MockDataclass(foo="qux")
obj_3_name = fwd_cache.save(obj_3)
assert obj_1_name not in fwd_cache._cache
assert obj_2_name in fwd_cache._cache
assert obj_3_name in fwd_cache._cache
# apparently qsize is "not reliable"?
assert fwd_cache._cache_ids.qsize() == 2
def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
called_name = None
obj_1 = MockDataclass(foo="bar")
def on_deleted(name: str):
nonlocal called_name
called_name = name
fwd_cache.on_deleted(on_deleted)
obj_1_name = fwd_cache.save(obj_1)
fwd_cache.delete(obj_1_name)
assert called_name == obj_1_name

View File

@ -2,6 +2,7 @@
Not really a test, but a way to verify that the paths are existing
and fail early if they are not.
"""
import pathlib
import unittest
from os import path as osp

View File

@ -8,11 +8,11 @@ from invokeai.app.services.session_queue.session_queue_common import (
NodeFieldValue,
calc_session_count,
create_session_nfv_tuples,
populate_graph,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
from tests.aa_nodes.test_nodes import PromptTestInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from .test_nodes import PromptTestInvocation
@pytest.fixture
@ -39,30 +39,6 @@ def batch_graph() -> Graph:
return g
def test_populate_graph_with_subgraph():
g1 = Graph()
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
n1 = PromptTestInvocation(id="1", prompt="Banana snake")
subgraph = Graph()
subgraph.add_node(n1)
g1.add_node(GraphInvocation(id="3", graph=subgraph))
nfvs = [
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
]
g2 = populate_graph(g1, nfvs)
# do not mutate g1
assert g1 is not g2
assert g2.get_node("1").prompt == "Strawberry sushi"
assert g2.get_node("2").prompt == "Strawberry sunday"
assert g2.get_node("3.1").prompt == "Strawberry snake"
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
t = list(create_session_nfv_tuples(batch=b, maximum=1000))