Add test for batch manager

This commit is contained in:
Brandon Rising 2023-08-14 10:57:18 -04:00
parent 69f541075c
commit 846e52f2ea
2 changed files with 65 additions and 1 deletions

View File

@ -12,6 +12,11 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
from invokeai.app.services.batch_manager import (
Batch,
BatchManager,
)
from invokeai.app.services.graph import (
Graph,
GraphExecutionState,
@ -29,6 +34,53 @@ def simple_graph():
return g
@pytest.fixture
def simple_batches():
batches = [
Batch(
node_id=1,
data= [
{
"prompt": "Tomato sushi",
},
{
"prompt": "Strawberry sushi",
},
{
"prompt": "Broccoli sushi",
},
{
"prompt": "Asparagus sushi",
},
{
"prompt": "Tea sushi",
},
]
),
Batch(
node_id="2",
data= [
{
"prompt2": "Ume sushi",
},
{
"prompt2": "Ichigo sushi",
},
{
"prompt2": "Momo sushi",
},
{
"prompt2": "Mikan sushi",
},
{
"prompt2": "Cha sushi",
},
]
)
]
return batches
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@ -38,13 +90,14 @@ def mock_services() -> InvocationServices:
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions"
)
batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory)
return InvocationServices(
model_manager=None, # type: ignore
events=TestEventService(),
logger=None, # type: ignore
images=None, # type: ignore
latents=None, # type: ignore
batch_manager=None, # type: ignore
batch_manager=BatchManager(batch_manager_storage),
boards=None, # type: ignore
board_images=None, # type: ignore
queue=MemoryInvocationQueue(),
@ -131,3 +184,13 @@ def test_handles_errors(mock_invoker: Invoker):
assert g.is_complete()
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batches=simple_batches,
graph=simple_graph,
)
assert batch_process_res.batch_id
assert len(batch_process_res.session_ids) == 25
for session in batch_process_res.session_ids:
assert mock_invoker.services.graph_execution_manager.get(session)

View File

@ -54,6 +54,7 @@ class TextToImageTestInvocation(BaseInvocation):
type: Literal["test_text_to_image"] = "test_text_to_image"
prompt: str = Field(default="")
prompt2: str = Field(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))