mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
115 lines
4.4 KiB
Python
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest


@pytest.fixture
def simple_graph():
    g = Graph()
    g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
    g.add_node(ImageTestInvocation(id="2"))
    g.add_edge(create_edge("1", "prompt", "2", "prompt"))
    return g

@pytest.fixture
def mock_services():
    # NOTE: none of these are actually called by the test invocations
    return InvocationServices(generate=None, events=None, images=None)

def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
    # Pull the next prepared node from the execution state, invoke it, and record
    # its output; returns (None, None) once there is nothing left to execute.
    n = g.next()
    if n is None:
        return (None, None)

    print(f'invoking {n.id}: {type(n)}')
    o = n.invoke(InvocationContext(services, "1"))
    g.complete(n.id, o)

    return (n, o)

def test_graph_state_executes_in_order(simple_graph, mock_services):
    g = GraphExecutionState(graph=simple_graph)

    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = g.next()

    assert g.prepared_source_mapping[n1[0].id] == "1"
    assert g.prepared_source_mapping[n2[0].id] == "2"
    assert n3 is None
    assert g.results[n1[0].id].prompt == n1[0].prompt
    assert n2[0].prompt == n1[0].prompt

def test_graph_is_complete(simple_graph, mock_services):
    g = GraphExecutionState(graph=simple_graph)
    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = g.next()

    assert g.is_complete()

def test_graph_is_not_complete(simple_graph, mock_services):
    g = GraphExecutionState(graph=simple_graph)
    n1 = invoke_next(g, mock_services)
    n2 = g.next()

    assert not g.is_complete()


# TODO: test completion with iterators/subgraphs

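# A minimal sketch toward the TODO above (completion with an iterator). This test is
# not part of the original suite: it assumes is_complete() only returns True once every
# prepared (expanded) node has a result. The graph mirrors the one used in
# test_graph_state_expands_iterator below, so the two-prompt collection should expand
# into five prepared nodes in total.
def test_graph_state_is_complete_with_iterator(mock_services):
    graph = Graph()
    test_prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="2"))
    graph.add_node(ImageTestInvocation(id="3"))
    graph.add_edge(create_edge("1", "collection", "2", "collection"))
    graph.add_edge(create_edge("2", "item", "3", "prompt"))

    g = GraphExecutionState(graph=graph)

    # One collection node, two expanded iterate nodes, two expanded image nodes.
    for _ in range(5):
        invoke_next(g, mock_services)

    assert g.next() is None
    assert g.is_complete()
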
def test_graph_state_expands_iterator(mock_services):
    graph = Graph()
    test_prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="2"))
    graph.add_node(ImageTestInvocation(id="3"))
    graph.add_edge(create_edge("1", "collection", "2", "collection"))
    graph.add_edge(create_edge("2", "item", "3", "prompt"))

    g = GraphExecutionState(graph=graph)
    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = invoke_next(g, mock_services)
    n4 = invoke_next(g, mock_services)
    n5 = invoke_next(g, mock_services)

    assert g.prepared_source_mapping[n1[0].id] == "1"
    assert g.prepared_source_mapping[n2[0].id] == "2"
    assert g.prepared_source_mapping[n3[0].id] == "2"
    assert g.prepared_source_mapping[n4[0].id] == "3"
    assert g.prepared_source_mapping[n5[0].id] == "3"

    assert isinstance(n4[0], ImageTestInvocation)
    assert isinstance(n5[0], ImageTestInvocation)

    prompts = [n4[0].prompt, n5[0].prompt]
    assert sorted(prompts) == sorted(test_prompts)

def test_graph_state_collects(mock_services):
    graph = Graph()
    test_prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="2"))
    graph.add_node(PromptTestInvocation(id="3"))
    graph.add_node(CollectInvocation(id="4"))
    graph.add_edge(create_edge("1", "collection", "2", "collection"))
    graph.add_edge(create_edge("2", "item", "3", "prompt"))
    graph.add_edge(create_edge("3", "prompt", "4", "item"))

    g = GraphExecutionState(graph=graph)
    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = invoke_next(g, mock_services)
    n4 = invoke_next(g, mock_services)
    n5 = invoke_next(g, mock_services)
    n6 = invoke_next(g, mock_services)

    assert isinstance(n6[0], CollectInvocation)

    assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)