mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(nodes): fix tests following removal of services
This commit is contained in:
parent
d53a2a2d4e
commit
0788a27a80
@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
PromptCollectionTestInvocation,
|
||||
@ -17,8 +17,6 @@ from invokeai.app.invocations.collections import RangeInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.shared.graph import (
|
||||
@ -28,11 +26,11 @@ from invokeai.app.services.shared.graph import (
|
||||
IterateInvocation,
|
||||
)
|
||||
|
||||
from .test_invoker import create_edge
|
||||
from .test_nodes import create_edge
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_graph():
|
||||
def simple_graph() -> Graph:
|
||||
g = Graph()
|
||||
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g.add_node(TextToImageTestInvocation(id="2"))
|
||||
@ -47,7 +45,6 @@ def simple_graph():
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||
return InvocationServices(
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
@ -55,7 +52,6 @@ def mock_services() -> InvocationServices:
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
@ -65,47 +61,32 @@ def mock_services() -> InvocationServices:
|
||||
download_queue=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
queue=MemoryInvocationQueue(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
tensors=None,
|
||||
conditioning=None,
|
||||
tensors=None, # type: ignore
|
||||
conditioning=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
|
||||
n = g.next()
|
||||
if n is None:
|
||||
return (None, None)
|
||||
|
||||
print(f"invoking {n.id}: {type(n)}")
|
||||
o = n.invoke(
|
||||
InvocationContext(
|
||||
conditioning=None,
|
||||
config=None,
|
||||
data=None,
|
||||
images=None,
|
||||
tensors=None,
|
||||
logger=None,
|
||||
models=None,
|
||||
util=None,
|
||||
boards=None,
|
||||
services=None,
|
||||
)
|
||||
)
|
||||
o = n.invoke(Mock(InvocationContext))
|
||||
g.complete(n.id, o)
|
||||
|
||||
return (n, o)
|
||||
|
||||
|
||||
def test_graph_state_executes_in_order(simple_graph, mock_services):
|
||||
def test_graph_state_executes_in_order(simple_graph: Graph):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n1 = invoke_next(g)
|
||||
n2 = invoke_next(g)
|
||||
n3 = g.next()
|
||||
|
||||
assert g.prepared_source_mapping[n1[0].id] == "1"
|
||||
@ -115,18 +96,18 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
|
||||
assert n2[0].prompt == n1[0].prompt
|
||||
|
||||
|
||||
def test_graph_is_complete(simple_graph, mock_services):
|
||||
def test_graph_is_complete(simple_graph: Graph):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = g.next()
|
||||
|
||||
assert g.is_complete()
|
||||
|
||||
|
||||
def test_graph_is_not_complete(simple_graph, mock_services):
|
||||
def test_graph_is_not_complete(simple_graph: Graph):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = g.next()
|
||||
|
||||
assert not g.is_complete()
|
||||
@ -135,7 +116,7 @@ def test_graph_is_not_complete(simple_graph, mock_services):
|
||||
# TODO: test completion with iterators/subgraphs
|
||||
|
||||
|
||||
def test_graph_state_expands_iterator(mock_services):
|
||||
def test_graph_state_expands_iterator():
|
||||
graph = Graph()
|
||||
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
|
||||
graph.add_node(IterateInvocation(id="1"))
|
||||
@ -147,7 +128,7 @@ def test_graph_state_expands_iterator(mock_services):
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
while not g.is_complete():
|
||||
invoke_next(g, mock_services)
|
||||
invoke_next(g)
|
||||
|
||||
prepared_add_nodes = g.source_prepared_mapping["3"]
|
||||
results = {g.results[n].value for n in prepared_add_nodes}
|
||||
@ -155,7 +136,7 @@ def test_graph_state_expands_iterator(mock_services):
|
||||
assert results == expected
|
||||
|
||||
|
||||
def test_graph_state_collects(mock_services):
|
||||
def test_graph_state_collects():
|
||||
graph = Graph()
|
||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||
graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
|
||||
@ -167,19 +148,19 @@ def test_graph_state_collects(mock_services):
|
||||
graph.add_edge(create_edge("3", "prompt", "4", "item"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
n6 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
n6 = invoke_next(g)
|
||||
|
||||
assert isinstance(n6[0], CollectInvocation)
|
||||
|
||||
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
||||
|
||||
|
||||
def test_graph_state_prepares_eagerly(mock_services):
|
||||
def test_graph_state_prepares_eagerly():
|
||||
"""Tests that all prepareable nodes are prepared"""
|
||||
graph = Graph()
|
||||
|
||||
@ -208,7 +189,7 @@ def test_graph_state_prepares_eagerly(mock_services):
|
||||
assert "prompt_iterated" not in g.source_prepared_mapping
|
||||
|
||||
|
||||
def test_graph_executes_depth_first(mock_services):
|
||||
def test_graph_executes_depth_first():
|
||||
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
|
||||
graph = Graph()
|
||||
|
||||
@ -222,14 +203,14 @@ def test_graph_executes_depth_first(mock_services):
|
||||
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
|
||||
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||
# Instead, we must count the number of results.
|
||||
def get_completed_count(g, id):
|
||||
def get_completed_count(g: GraphExecutionState, id: str):
|
||||
ids = list(g.source_prepared_mapping[id])
|
||||
completed_ids = [i for i in g.executed if i in ids]
|
||||
return len(completed_ids)
|
||||
@ -238,17 +219,17 @@ def test_graph_executes_depth_first(mock_services):
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 0
|
||||
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 2
|
||||
|
@ -1,163 +0,0 @@
|
||||
import logging
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
ErrorInvocation,
|
||||
PromptTestInvocation,
|
||||
TestEventService,
|
||||
TextToImageTestInvocation,
|
||||
create_edge,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_graph():
|
||||
g = Graph()
|
||||
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g.add_node(TextToImageTestInvocation(id="2"))
|
||||
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||
return g
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
return InvocationServices(
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
board_records=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
logger=logging, # type: ignore
|
||||
model_manager=Mock(), # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
queue=MemoryInvocationQueue(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
tensors=None,
|
||||
conditioning=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
return Invoker(services=mock_services)
|
||||
|
||||
|
||||
def test_can_create_graph_state(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
mock_invoker.stop()
|
||||
|
||||
assert g is not None
|
||||
assert isinstance(g, GraphExecutionState)
|
||||
|
||||
|
||||
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
mock_invoker.stop()
|
||||
|
||||
assert g is not None
|
||||
assert isinstance(g, GraphExecutionState)
|
||||
assert g.graph == simple_graph
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
invocation_id = mock_invoker.invoke(
|
||||
session_queue_batch_id="1",
|
||||
session_queue_item_id=1,
|
||||
session_queue_id=DEFAULT_QUEUE_ID,
|
||||
graph_execution_state=g,
|
||||
)
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_any(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return len(g.executed) > 0
|
||||
|
||||
wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert len(g.executed) > 0
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
invocation_id = mock_invoker.invoke(
|
||||
session_queue_batch_id="1",
|
||||
session_queue_item_id=1,
|
||||
session_queue_id=DEFAULT_QUEUE_ID,
|
||||
graph_execution_state=g,
|
||||
invoke_all=True,
|
||||
)
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.is_complete()
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_handles_errors(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
g.graph.add_node(ErrorInvocation(id="1"))
|
||||
|
||||
mock_invoker.invoke(
|
||||
session_queue_batch_id="1",
|
||||
session_queue_item_id=1,
|
||||
session_queue_id=DEFAULT_QUEUE_ID,
|
||||
graph_execution_state=g,
|
||||
invoke_all=True,
|
||||
)
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.has_error()
|
||||
assert g.is_complete()
|
||||
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
Loading…
Reference in New Issue
Block a user