Merge branch 'main' into refactor/model-manager-2

This commit is contained in:
psychedelicious
2023-11-14 07:51:57 +11:00
committed by GitHub
253 changed files with 1712 additions and 3981 deletions

View File

@ -150,8 +150,8 @@ def test_graph_state_expands_iterator(mock_services):
invoke_next(g, mock_services)
prepared_add_nodes = g.source_prepared_mapping["3"]
results = set([g.results[n].value for n in prepared_add_nodes])
expected = set([1, 11, 21])
results = {g.results[n].value for n in prepared_add_nodes}
expected = {1, 11, 21}
assert results == expected
@ -230,7 +230,7 @@ def test_graph_executes_depth_first(mock_services):
# Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results.
def get_completed_count(g, id):
ids = [i for i in g.source_prepared_mapping[id]]
ids = list(g.source_prepared_mapping[id])
completed_ids = [i for i in g.executed if i in ids]
return len(completed_ids)

View File

@ -471,7 +471,6 @@ def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
@ -503,8 +502,8 @@ def test_graph_expands_subgraph():
g.add_edge(create_edge("1.2", "value", "2", "a"))
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(["1.1", "1.2", "2"])
assert set(dg.edges) == set([("1.1", "1.2"), ("1.2", "2")])
assert set(dg.nodes) == {"1.1", "1.2", "2"}
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
def test_graph_subgraph_t2i():
@ -532,9 +531,7 @@ def test_graph_subgraph_t2i():
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(
["1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"]
)
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
print(expected_edges)
@ -546,7 +543,6 @@ def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
@ -561,7 +557,6 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)

View File

@ -130,7 +130,7 @@ class TestEventService(EventServiceBase):
def __init__(self):
super().__init__()
self.events = list()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
pass

View File

@ -169,7 +169,7 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
# should have 3 node field values
assert type(values[0].field_values) is str
assert isinstance(values[0].field_values, str)
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
# should have batch id and priority