diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 9cff502acf..2e88178424 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -1,9 +1,9 @@ import logging +from typing import Optional +from unittest.mock import Mock import pytest -from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory - # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split PromptCollectionTestInvocation, @@ -17,8 +17,6 @@ from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor -from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.shared.graph import ( @@ -28,11 +26,11 @@ from invokeai.app.services.shared.graph import ( IterateInvocation, ) -from .test_invoker import create_edge +from .test_nodes import create_edge @pytest.fixture -def simple_graph(): +def simple_graph() -> Graph: g = Graph() g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) g.add_node(TextToImageTestInvocation(id="2")) @@ -47,7 +45,6 @@ def simple_graph(): def mock_services() -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) # NOTE: none of these are actually called by the test invocations - graph_execution_manager = ItemStorageMemory[GraphExecutionState]() return InvocationServices( board_image_records=None, # type: ignore board_images=None, # type: ignore @@ -55,7 +52,6 @@ def mock_services() -> InvocationServices: boards=None, # type: ignore configuration=configuration, events=TestEventService(), - graph_execution_manager=graph_execution_manager, image_files=None, # type: ignore image_records=None, # type: ignore images=None, # type: ignore @@ -65,47 +61,32 @@ def mock_services() -> InvocationServices: download_queue=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), - processor=DefaultInvocationProcessor(), - queue=MemoryInvocationQueue(), session_processor=None, # type: ignore session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore - tensors=None, - conditioning=None, + tensors=None, # type: ignore + conditioning=None, # type: ignore ) -def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: +def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]: n = g.next() if n is None: return (None, None) print(f"invoking {n.id}: {type(n)}") - o = n.invoke( - InvocationContext( - conditioning=None, - config=None, - data=None, - images=None, - tensors=None, - logger=None, - models=None, - util=None, - boards=None, - services=None, - ) - ) + o = n.invoke(Mock(InvocationContext)) g.complete(n.id, o) return (n, o) -def test_graph_state_executes_in_order(simple_graph, mock_services): +def test_graph_state_executes_in_order(simple_graph: Graph): g = GraphExecutionState(graph=simple_graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) + n1 = invoke_next(g) + n2 = invoke_next(g) n3 = g.next() assert g.prepared_source_mapping[n1[0].id] == "1" @@ -115,18 +96,18 @@ def test_graph_state_executes_in_order(simple_graph, mock_services): assert n2[0].prompt == n1[0].prompt -def test_graph_is_complete(simple_graph, mock_services): +def test_graph_is_complete(simple_graph: Graph): g = GraphExecutionState(graph=simple_graph) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) + _ = invoke_next(g) + _ = invoke_next(g) _ = g.next() assert g.is_complete() -def test_graph_is_not_complete(simple_graph, mock_services): +def test_graph_is_not_complete(simple_graph: Graph): g = GraphExecutionState(graph=simple_graph) - _ = invoke_next(g, mock_services) + _ = invoke_next(g) _ = g.next() assert not g.is_complete() @@ -135,7 +116,7 @@ def test_graph_is_not_complete(simple_graph, mock_services): # TODO: test completion with iterators/subgraphs -def test_graph_state_expands_iterator(mock_services): +def test_graph_state_expands_iterator(): graph = Graph() graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1)) graph.add_node(IterateInvocation(id="1")) @@ -147,7 +128,7 @@ def test_graph_state_expands_iterator(mock_services): g = GraphExecutionState(graph=graph) while not g.is_complete(): - invoke_next(g, mock_services) + invoke_next(g) prepared_add_nodes = g.source_prepared_mapping["3"] results = {g.results[n].value for n in prepared_add_nodes} @@ -155,7 +136,7 @@ def test_graph_state_expands_iterator(mock_services): assert results == expected -def test_graph_state_collects(mock_services): +def test_graph_state_collects(): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts))) @@ -167,19 +148,19 @@ def test_graph_state_collects(mock_services): graph.add_edge(create_edge("3", "prompt", "4", "item")) g = GraphExecutionState(graph=graph) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - n6 = invoke_next(g, mock_services) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + n6 = invoke_next(g) assert isinstance(n6[0], CollectInvocation) assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) -def test_graph_state_prepares_eagerly(mock_services): +def test_graph_state_prepares_eagerly(): """Tests that all prepareable nodes are prepared""" graph = Graph() @@ -208,7 +189,7 @@ def test_graph_state_prepares_eagerly(mock_services): assert "prompt_iterated" not in g.source_prepared_mapping -def test_graph_executes_depth_first(mock_services): +def test_graph_executes_depth_first(): """Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch""" graph = Graph() @@ -222,14 +203,14 @@ def test_graph_executes_depth_first(mock_services): graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) g = GraphExecutionState(graph=graph) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) # Because ordering is not guaranteed, we cannot compare results directly. # Instead, we must count the number of results. - def get_completed_count(g, id): + def get_completed_count(g: GraphExecutionState, id: str): ids = list(g.source_prepared_mapping[id]) completed_ids = [i for i in g.executed if i in ids] return len(completed_ids) @@ -238,17 +219,17 @@ def test_graph_executes_depth_first(mock_services): assert get_completed_count(g, "prompt_iterated") == 1 assert get_completed_count(g, "prompt_successor") == 0 - _ = invoke_next(g, mock_services) + _ = invoke_next(g) assert get_completed_count(g, "prompt_iterated") == 1 assert get_completed_count(g, "prompt_successor") == 1 - _ = invoke_next(g, mock_services) + _ = invoke_next(g) assert get_completed_count(g, "prompt_iterated") == 2 assert get_completed_count(g, "prompt_successor") == 1 - _ = invoke_next(g, mock_services) + _ = invoke_next(g) assert get_completed_count(g, "prompt_iterated") == 2 assert get_completed_count(g, "prompt_successor") == 2 diff --git a/tests/test_invoker.py b/tests/test_invoker.py deleted file mode 100644 index 38fcf859a5..0000000000 --- a/tests/test_invoker.py +++ /dev/null @@ -1,163 +0,0 @@ -import logging -from unittest.mock import Mock - -import pytest - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory - -# This import must happen before other invoke imports or test in other files(!!) break -from .test_nodes import ( # isort: split - ErrorInvocation, - PromptTestInvocation, - TestEventService, - TextToImageTestInvocation, - create_edge, - wait_until, -) - -from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor -from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue -from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID -from invokeai.app.services.shared.graph import Graph, GraphExecutionState - - -@pytest.fixture -def simple_graph(): - g = Graph() - g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - g.add_node(TextToImageTestInvocation(id="2")) - g.add_edge(create_edge("1", "prompt", "2", "prompt")) - return g - - -# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types -# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate -# the test invocations. -@pytest.fixture -def mock_services() -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=None, # type: ignore - boards=None, # type: ignore - configuration=configuration, - events=TestEventService(), - graph_execution_manager=ItemStorageMemory[GraphExecutionState](), - image_files=None, # type: ignore - image_records=None, # type: ignore - images=None, # type: ignore - invocation_cache=MemoryInvocationCache(max_cache_size=0), - logger=logging, # type: ignore - model_manager=Mock(), # type: ignore - download_queue=None, # type: ignore - names=None, # type: ignore - performance_statistics=InvocationStatsService(), - processor=DefaultInvocationProcessor(), - queue=MemoryInvocationQueue(), - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - tensors=None, - conditioning=None, - ) - - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker(services=mock_services) - - -def test_can_create_graph_state(mock_invoker: Invoker): - g = mock_invoker.create_execution_state() - mock_invoker.stop() - - assert g is not None - assert isinstance(g, GraphExecutionState) - - -def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph=simple_graph) - mock_invoker.stop() - - assert g is not None - assert isinstance(g, GraphExecutionState) - assert g.graph == simple_graph - - -# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") -def test_can_invoke(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke( - session_queue_batch_id="1", - session_queue_item_id=1, - session_queue_id=DEFAULT_QUEUE_ID, - graph_execution_state=g, - ) - assert invocation_id is not None - - def has_executed_any(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) - return len(g.executed) > 0 - - wait_until(lambda: has_executed_any(g), timeout=5, interval=1) - mock_invoker.stop() - - g = mock_invoker.services.graph_execution_manager.get(g.id) - assert len(g.executed) > 0 - - -# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") -def test_can_invoke_all(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke( - session_queue_batch_id="1", - session_queue_item_id=1, - session_queue_id=DEFAULT_QUEUE_ID, - graph_execution_state=g, - invoke_all=True, - ) - assert invocation_id is not None - - def has_executed_all(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) - return g.is_complete() - - wait_until(lambda: has_executed_all(g), timeout=5, interval=1) - mock_invoker.stop() - - g = mock_invoker.services.graph_execution_manager.get(g.id) - assert g.is_complete() - - -# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") -def test_handles_errors(mock_invoker: Invoker): - g = mock_invoker.create_execution_state() - g.graph.add_node(ErrorInvocation(id="1")) - - mock_invoker.invoke( - session_queue_batch_id="1", - session_queue_item_id=1, - session_queue_id=DEFAULT_QUEUE_ID, - graph_execution_state=g, - invoke_all=True, - ) - - def has_executed_all(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) - return g.is_complete() - - wait_until(lambda: has_executed_all(g), timeout=5, interval=1) - mock_invoker.stop() - - g = mock_invoker.services.graph_execution_manager.get(g.id) - assert g.has_error() - assert g.is_complete() - - assert all((i in g.errors for i in g.source_prepared_mapping["1"]))