author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800
committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800

Adding base node architecture

Fix type annotation errors

Runs and generates, but breaks in saving session

Fix default model value setting. Fix deprecation warning.

Fixed node api

Adding markdown docs

Simplifying Generate construction in apps

[nodes] A few minor changes (#2510)

* Pin api-related requirements

* Remove confusing extra CORS origins list

* Adds response models for HTTP 200

[nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest.

Minor typing fixes

[nodes] Fix some small output query hookups

[node] Fixing some additional typing issues

[nodes] Move and expand graph code. Add base item storage and sqlite implementation.

Update startup to match new code

[nodes] Add callbacks to item storage

[nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility

[nodes] New execution model that handles iteration

[nodes] Fixing the CLI

[nodes] Adding a note to the CLI

[nodes] Split processing thread into separate service

[node] Add error message on node processing failure

Removing old files and duplicated packages

Adding python-multipart
This commit is contained in:
Kyle Schouviller
2022-11-30 21:33:20 -08:00
parent 49ffb64ef3
commit 34e3aa1f88
42 changed files with 4510 additions and 0 deletions

0
tests/__init__.py Normal file
View File

0
tests/nodes/__init__.py Normal file
View File

View File

@ -0,0 +1,114 @@
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = None, images = None)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()
if n is None:
return (None, None)
print(f'invoking {n.id}: {type(n)}')
o = n.invoke(InvocationContext(services, "1"))
g.complete(n.id, o)
return (n, o)
def test_graph_state_executes_in_order(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()
assert g.prepared_source_mapping[n1[0].id] == "1"
assert g.prepared_source_mapping[n2[0].id] == "2"
assert n3 is None
assert g.results[n1[0].id].prompt == n1[0].prompt
assert n2[0].prompt == n1[0].prompt
def test_graph_is_complete(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()
assert g.is_complete()
def test_graph_is_not_complete(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = g.next()
assert not g.is_complete()
# TODO: test completion with iterators/subgraphs
def test_graph_state_expands_iterator(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
graph.add_node(IterateInvocation(id = "2"))
graph.add_node(ImageTestInvocation(id = "3"))
graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt"))
g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
n5 = invoke_next(g, mock_services)
assert g.prepared_source_mapping[n1[0].id] == "1"
assert g.prepared_source_mapping[n2[0].id] == "2"
assert g.prepared_source_mapping[n3[0].id] == "2"
assert g.prepared_source_mapping[n4[0].id] == "3"
assert g.prepared_source_mapping[n5[0].id] == "3"
assert isinstance(n4[0], ImageTestInvocation)
assert isinstance(n5[0], ImageTestInvocation)
prompts = [n4[0].prompt, n5[0].prompt]
assert sorted(prompts) == sorted(test_prompts)
def test_graph_state_collects(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
graph.add_node(IterateInvocation(id = "2"))
graph.add_node(PromptTestInvocation(id = "3"))
graph.add_node(CollectInvocation(id = "4"))
graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt"))
graph.add_edge(create_edge("3", "prompt", "4", "item"))
g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
n5 = invoke_next(g, mock_services)
n6 = invoke_next(g, mock_services)
assert isinstance(n6[0], CollectInvocation)
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)

View File

@ -0,0 +1,85 @@
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = TestEventService(), images = None)
@pytest.fixture()
def mock_invoker_services() -> InvokerServices:
return InvokerServices(
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
return Invoker(
services = mock_services,
invoker_services = mock_invoker_services
)
def test_can_create_graph_state(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g)
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all = True)
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
assert g.is_complete()

View File

@ -0,0 +1,501 @@
from ldm.invoke.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
# Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
# Tests
def test_connections_are_compatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_field = "image"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == True
def test_connections_are_incompatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_field = "strength"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2")
to_field = "image"
# From field is invalid
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
# To field is invalid
from_field = "image"
to_field = "invalid_field"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
def test_graph_can_add_node():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
assert n.id in g.nodes
def test_graph_fails_to_add_node_with_duplicate_id():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second")
with pytest.raises(NodeAlreadyInGraphError):
g.add_node(n2)
def test_graph_updates_node():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2)
nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated")
g.update_node("1", nu)
assert g.nodes["1"].prompt == "Banana sushi updated"
def test_graph_fails_to_update_node_if_type_changes():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
g.add_node(n2)
nu = UpscaleInvocation(id = "1")
with pytest.raises(TypeError):
g.update_node("1", nu)
def test_graph_allows_non_conflicting_id_change():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1)
nu = TextToImageInvocation(id = "3", prompt = "Banana sushi")
g.update_node("1", nu)
with pytest.raises(NodeNotFoundError):
g.get_node("1")
assert g.get_node("3").prompt == "Banana sushi"
assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
def test_graph_fails_to_update_node_id_if_conflict():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2)
nu = TextToImageInvocation(id = "2", prompt = "Banana sushi")
with pytest.raises(NodeAlreadyInGraphError):
g.update_node("1", nu)
def test_graph_adds_edge():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
assert e in g.edges
def test_graph_fails_to_add_edge_with_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
g.add_node(n1)
e = create_edge(n1.id,"image",n1.id,"image")
with pytest.raises(InvalidEdgeError):
g.add_edge(e)
def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge(n1.id,"image",n2.id,"image")
e2 = create_edge(n2.id,"image",n3.id,"image")
e3 = create_edge(n3.id,"image",n2.id,"image")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","3","image")
e2 = create_edge("3","image","1","image")
with pytest.raises(InvalidEdgeError):
g.add_edge(e1)
with pytest.raises(InvalidEdgeError):
g.add_edge(e2)
def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge(n1.id,"image",n2.id,"image")
e2 = create_edge(n1.id,"image",n3.id,"image")
e3 = create_edge(n2.id,"image",n3.id,"image")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","strength")
with pytest.raises(InvalidEdgeError):
g.add_edge(e1)
def test_graph_connects_collector():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3")
n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","image","3","item")
e2 = create_edge("2","image","3","item")
e3 = create_edge("3","collection","4","collection")
g.add_edge(e1)
g.add_edge(e2)
g.add_edge(e3)
# TODO: test that derived types mixed with base types are compatible
def test_graph_collector_invalid_with_varying_input_types():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
n3 = CollectInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge("1","image","3","item")
e2 = create_edge("2","prompt","3","item")
g.add_edge(e1)
with pytest.raises(InvalidEdgeError):
g.add_edge(e2)
def test_graph_collector_invalid_with_varying_input_output():
g = Graph()
n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3")
n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","prompt","3","item")
e2 = create_edge("2","prompt","3","item")
e3 = create_edge("3","collection","4","collection")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_collector_invalid_with_non_list_output():
g = Graph()
n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3")
n4 = PromptTestInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","prompt","3","item")
e2 = create_edge("2","prompt","3","item")
e3 = create_edge("3","collection","4","prompt")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_connects_iterator():
g = Graph()
n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge("1","collection","2","collection")
e2 = create_edge("2","item","3","image")
g.add_edge(e1)
g.add_edge(e2)
# TODO: TEST INVALID ITERATOR SCENARIOS
def test_graph_iterator_invalid_if_multiple_inputs():
g = Graph()
n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","collection","2","collection")
e2 = create_edge("2","item","3","image")
e3 = create_edge("4","collection","2","collection")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_iterator_invalid_if_input_not_list():
g = Graph()
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
n2 = IterateInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","collection","2","collection")
with pytest.raises(InvalidEdgeError):
g.add_edge(e1)
def test_graph_iterator_invalid_if_output_and_input_types_different():
g = Graph()
n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2")
n3 = PromptTestInvocation(id = "3", prompt = "Banana sushi")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge("1","collection","2","collection")
e2 = create_edge("2","item","3","prompt")
g.add_edge(e1)
with pytest.raises(InvalidEdgeError):
g.add_edge(e2)
def test_graph_validates():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","image")
g.add_edge(e1)
assert g.is_valid() == True
def test_graph_invalid_if_edges_reference_missing_nodes():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.nodes[n1.id] = n1
e1 = create_edge("1","image","2","image")
g.edges.append(e1)
assert g.is_valid() == False
def test_graph_invalid_if_subgraph_invalid():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1","image","2","image")
n1.graph.edges.append(e1)
g.nodes[n1.id] = n1
assert g.is_valid() == False
def test_graph_invalid_if_has_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
n2 = UpscaleInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","image")
e2 = create_edge("2","image","1","image")
g.edges.append(e1)
g.edges.append(e2)
assert g.is_valid() == False
def test_graph_invalid_with_invalid_connection():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","strength")
g.edges.append(e1)
assert g.is_valid() == False
# TODO: Subgraph operations
def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
result = g.get_node('1.1')
assert result is not None
assert result.id == '1'
assert result == n1_1
def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
with pytest.raises(NodeNotFoundError):
result = g.get_node('1.2')
def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
n2 = UpscaleInvocation(id = "2")
g.add_node(n2)
with pytest.raises(NodeNotFoundError):
result = g.get_node('2.1')
def test_graph_gets_networkx_graph():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
nxg = g.nx_graph()
assert '1' in nxg.nodes
assert '2' in nxg.nodes
assert ('1','2') in nxg.edges
# TODO: Graph serializes and deserializes
def test_graph_can_serialize():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
# Not throwing on this line is sufficient
json = g.json()
def test_graph_can_deserialize():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
json = g.json()
g2 = Graph.parse_raw(json)
assert g2 is not None
assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None
assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1'
assert g2.edges[0][0].field == 'image'
assert g2.edges[0][1].node_id == '2'
assert g2.edges[0][1].field == 'image'
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
schema = Graph.schema_json(indent = 2)

92
tests/nodes/test_nodes.py Normal file
View File

@ -0,0 +1,92 @@
from typing import Any, Callable, Literal
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.invocations.image import ImageField
from ldm.invoke.app.services.invocation_services import InvocationServices
from pydantic import Field
import pytest
# Define test invocations before importing anything that uses invocations
class ListPassThroughInvocationOutput(BaseInvocationOutput):
type: Literal['test_list_output'] = 'test_list_output'
collection: list[ImageField] = Field(default_factory=list)
class ListPassThroughInvocation(BaseInvocation):
type: Literal['test_list'] = 'test_list'
collection: list[ImageField] = Field(default_factory=list)
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection = self.collection)
class PromptTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_prompt_output'] = 'test_prompt_output'
prompt: str = Field(default = "")
class PromptTestInvocation(BaseInvocation):
type: Literal['test_prompt'] = 'test_prompt'
prompt: str = Field(default = "")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt = self.prompt)
class ImageTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_image_output'] = 'test_image_output'
image: ImageField = Field()
class ImageTestInvocation(BaseInvocation):
type: Literal['test_image'] = 'test_image'
prompt: str = Field(default = "")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
collection: list[str] = Field(default_factory=list)
class PromptCollectionTestInvocation(BaseInvocation):
type: Literal['test_prompt_collection'] = 'test_prompt_collection'
collection: list[str] = Field()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
from ldm.invoke.app.services.events import EventServiceBase
from ldm.invoke.app.services.graph import EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
class TestEvent:
event_name: str
payload: Any
def __init__(self, event_name: str, payload: Any):
self.event_name = event_name
self.payload = payload
class TestEventService(EventServiceBase):
events: list
def __init__(self):
super().__init__()
self.events = list()
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None:
import time
start_time = time.time()
while time.time() - start_time < timeout:
if condition():
return
time.sleep(interval)
raise TimeoutError("Condition not met")

112
tests/nodes/test_sqlite.py Normal file
View File

@ -0,0 +1,112 @@
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field
class TestModel(BaseModel):
id: str = Field(description = "ID")
name: str = Field(description = "Name")
def test_sqlite_service_can_create_and_get():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
assert db.get('1') == TestModel(id = '1', name = 'Test')
def test_sqlite_service_can_list():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list()
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_delete():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.delete('1')
assert db.get('1') is None
def test_sqlite_service_calls_set_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
called = False
def on_changed(item: TestModel):
nonlocal called
called = True
db.on_changed(on_changed)
db.set(TestModel(id = '1', name = 'Test'))
assert called
def test_sqlite_service_calls_delete_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
called = False
def on_deleted(item_id: str):
nonlocal called
called = True
db.on_deleted(on_deleted)
db.set(TestModel(id = '1', name = 'Test'))
db.delete('1')
assert called
def test_sqlite_service_can_list_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list(page = 0, per_page = 2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')]
def test_sqlite_service_can_list_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list(page = 1, per_page = 2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_search():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test')
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_search_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test', page = 0, per_page = 2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')]
def test_sqlite_service_can_search_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test', page = 1, per_page = 2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '3', name = 'Test')]