From 796ee1246b532d89b66e2555f5275017e5f67ae1 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 16 Aug 2023 15:42:45 -0400 Subject: [PATCH] Add a batch validation test --- tests/nodes/test_invoker.py | 77 +++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index a0b4de5943..4ef35e6aaf 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -182,3 +182,80 @@ def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch): assert len(batch_process_res.session_ids) == 25 for session in batch_process_res.session_ids: assert mock_invoker.services.graph_execution_manager.get(session) + + +def test_can_create_bad_batches(): + batch = None + try: + batch = Batch( # This batch has a duplicate node_id|fieldname combo + data=[ + [ + BatchData( + node_id="1", + field_name="prompt", + items=[ + "Tomato sushi", + ], + ) + ], + [ + BatchData( + node_id="1", + field_name="prompt", + items=[ + "Ume sushi", + ], + ) + ], + ] + ) + except Exception as e: + assert e + try: + batch = Batch( # This batch has different item list lengths in the same group + data=[ + [ + BatchData( + node_id="1", + field_name="prompt", + items=[ + "Tomato sushi", + ], + ), + BatchData( + node_id="1", + field_name="prompt", + items=[ + "Tomato sushi", + "Courgette sushi", + ], + ), + ], + [ + BatchData( + node_id="1", + field_name="prompt", + items=[ + "Ume sushi", + ], + ) + ], + ] + ) + except Exception as e: + assert e + try: + batch = Batch( # This batch has a type mismatch in single items list + data=[ + [ + BatchData( + node_id="1", + field_name="prompt", + items=["Tomato sushi", 5], + ), + ], + ] + ) + except Exception as e: + assert e + assert not batch