diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 677b0a372d..9051fcb403 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -12,6 +12,11 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invoker import Invoker from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats import InvocationStatsService +from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage +from invokeai.app.services.batch_manager import ( + Batch, + BatchManager, +) from invokeai.app.services.graph import ( Graph, GraphExecutionState, @@ -29,6 +34,53 @@ def simple_graph(): return g +@pytest.fixture +def simple_batches(): + batches = [ + Batch( + node_id=1, + data= [ + { + "prompt": "Tomato sushi", + }, + { + "prompt": "Strawberry sushi", + }, + { + "prompt": "Broccoli sushi", + }, + { + "prompt": "Asparagus sushi", + }, + { + "prompt": "Tea sushi", + }, + ] + ), + Batch( + node_id="2", + data= [ + { + "prompt2": "Ume sushi", + }, + { + "prompt2": "Ichigo sushi", + }, + { + "prompt2": "Momo sushi", + }, + { + "prompt2": "Mikan sushi", + }, + { + "prompt2": "Cha sushi", + }, + ] + ) + ] + return batches + + # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. @@ -38,13 +90,14 @@ def mock_services() -> InvocationServices: graph_execution_manager = SqliteItemStorage[GraphExecutionState]( filename=sqlite_memory, table_name="graph_executions" ) + batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory) return InvocationServices( model_manager=None, # type: ignore events=TestEventService(), logger=None, # type: ignore images=None, # type: ignore latents=None, # type: ignore - batch_manager=None, # type: ignore + batch_manager=BatchManager(batch_manager_storage), boards=None, # type: ignore board_images=None, # type: ignore queue=MemoryInvocationQueue(), @@ -131,3 +184,13 @@ def test_handles_errors(mock_invoker: Invoker): assert g.is_complete() assert all((i in g.errors for i in g.source_prepared_mapping["1"])) + +def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches): + batch_process_res = mock_invoker.services.batch_manager.create_batch_process( + batches=simple_batches, + graph=simple_graph, + ) + assert batch_process_res.batch_id + assert len(batch_process_res.session_ids) == 25 + for session in batch_process_res.session_ids: + assert mock_invoker.services.graph_execution_manager.get(session) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index 13338e9261..adfb0a556b 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -54,6 +54,7 @@ class TextToImageTestInvocation(BaseInvocation): type: Literal["test_text_to_image"] = "test_text_to_image" prompt: str = Field(default="") + prompt2: str = Field(default="") def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id))