2022-12-01 05:33:20 +00:00
|
|
|
import pytest
|
|
|
|
|
2023-06-26 15:55:24 +00:00
|
|
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
|
|
|
from invokeai.app.services.invocation_services import InvocationServices
|
|
|
|
from invokeai.app.services.invoker import Invoker
|
|
|
|
|
|
|
|
from .test_nodes import (ErrorInvocation, ImageTestInvocation,
|
|
|
|
PromptTestInvocation, create_edge, wait_until)
|
|
|
|
|
2022-12-01 05:33:20 +00:00
|
|
|
|
|
|
|
@pytest.fixture
def simple_graph():
    """Provide a two-node graph: a prompt node "1" linked to an image node "2".

    The single edge is built with ``create_edge("1", "prompt", "2", "prompt")``;
    presumably that connects node "1"'s ``prompt`` field to node "2"'s
    ``prompt`` field (argument order assumed source-then-destination).
    """
    graph = Graph()
    prompt_node = PromptTestInvocation(id="1", prompt="Banana sushi")
    image_node = ImageTestInvocation(id="2")
    for node in (prompt_node, image_node):
        graph.add_node(node)
    graph.add_edge(create_edge("1", "prompt", "2", "prompt"))
    return graph
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_invoker(mock_services: InvocationServices) -> Invoker:
    """Build an ``Invoker`` backed by the ``mock_services`` fixture.

    ``mock_services`` is not defined in this module (presumably it comes from
    a conftest). This fixture registers no teardown, so each test that uses
    it must call ``stop()`` on the invoker itself.
    """
    invoker = Invoker(services=mock_services)
    return invoker
|
|
|
|
|
|
|
|
def test_can_create_graph_state(mock_invoker: Invoker):
    """``create_execution_state()`` with no graph argument yields a GraphExecutionState."""
    execution_state = mock_invoker.create_execution_state()
    # Nothing is invoked in this test, so the invoker can be shut down before checking.
    mock_invoker.stop()

    assert execution_state is not None
    assert isinstance(execution_state, GraphExecutionState)
|
|
|
|
|
|
|
|
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
    """An execution state created from a graph keeps an equal ``graph``."""
    execution_state = mock_invoker.create_execution_state(graph=simple_graph)
    # Nothing is invoked in this test, so the invoker can be shut down before checking.
    mock_invoker.stop()

    assert execution_state is not None
    assert isinstance(execution_state, GraphExecutionState)
    assert execution_state.graph == simple_graph
|
|
|
|
|
2023-06-26 15:55:24 +00:00
|
|
|
@pytest.mark.xfail(reason="Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
    """Invoking a graph (without ``invoke_all``) executes at least one node.

    ``stop()`` runs in a ``finally`` clause. The invoker is therefore shut down
    even if ``invoke``, the ``invocation_id`` assertion, or ``wait_until``
    raises, so a failure here does not leave it running into later tests.
    """
    state = mock_invoker.create_execution_state(graph=simple_graph)

    def has_executed_any(target: GraphExecutionState) -> bool:
        # Re-fetch from the execution manager instead of reading ``target``
        # directly. Presumably the invoker records progress there rather than
        # mutating the caller's object (verify).
        latest = mock_invoker.services.graph_execution_manager.get(target.id)
        return len(latest.executed) > 0

    try:
        invocation_id = mock_invoker.invoke(state)
        assert invocation_id is not None
        wait_until(lambda: has_executed_any(state), timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    state = mock_invoker.services.graph_execution_manager.get(state.id)
    assert len(state.executed) > 0
|
|
|
|
|
2023-06-26 15:55:24 +00:00
|
|
|
@pytest.mark.xfail(reason="Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
    """Invoking a graph with ``invoke_all=True`` runs it to completion.

    ``stop()`` runs in a ``finally`` clause. The invoker is therefore shut down
    even if ``invoke``, the ``invocation_id`` assertion, or ``wait_until``
    raises.
    """
    state = mock_invoker.create_execution_state(graph=simple_graph)

    def has_executed_all(target: GraphExecutionState) -> bool:
        # Re-fetch from the execution manager instead of reading ``target``
        # directly. Presumably the invoker records progress there rather than
        # mutating the caller's object (verify).
        latest = mock_invoker.services.graph_execution_manager.get(target.id)
        return latest.is_complete()

    try:
        invocation_id = mock_invoker.invoke(state, invoke_all=True)
        assert invocation_id is not None
        wait_until(lambda: has_executed_all(state), timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    state = mock_invoker.services.graph_execution_manager.get(state.id)
    assert state.is_complete()
|
2023-02-27 18:01:07 +00:00
|
|
|
|
2023-06-26 15:55:24 +00:00
|
|
|
@pytest.mark.xfail(reason="Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker):
    """A graph whose only node is an ``ErrorInvocation`` completes with errors recorded.

    Every prepared instance of source node "1" must appear in ``errors``.
    ``stop()`` runs in a ``finally`` clause so the invoker is shut down even if
    ``invoke`` or ``wait_until`` raises.
    """
    state = mock_invoker.create_execution_state()
    state.graph.add_node(ErrorInvocation(id="1"))

    def has_executed_all(target: GraphExecutionState) -> bool:
        # Re-fetch from the execution manager instead of reading ``target``
        # directly. Presumably the invoker records progress there rather than
        # mutating the caller's object (verify).
        latest = mock_invoker.services.graph_execution_manager.get(target.id)
        return latest.is_complete()

    try:
        mock_invoker.invoke(state, invoke_all=True)
        wait_until(lambda: has_executed_all(state), timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    state = mock_invoker.services.graph_execution_manager.get(state.id)
    assert state.has_error()
    assert state.is_complete()

    prepared_ids = state.source_prepared_mapping["1"]
    # Guard against a vacuous pass: all() over an empty collection is True.
    assert prepared_ids, "node '1' was never prepared"
    assert all(prepared_id in state.errors for prepared_id in prepared_ids)
|