Fix batch_manager test

This commit is contained in:
Brandon Rising 2023-08-16 15:33:15 -04:00
parent 796ff34c8a
commit 29fceb960d

View File

@ -12,7 +12,7 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
from invokeai.app.services.batch_manager import ( from invokeai.app.services.batch_manager import (
Batch, Batch,
BatchManager, BatchManager,
@ -36,50 +36,37 @@ def simple_graph():
@pytest.fixture @pytest.fixture
def simple_batches(): def simple_batch():
batches = [ return Batch(
Batch( data=[
node_id=1, [
data=[ BatchData(
{ node_id="1",
"prompt": "Tomato sushi", field_name="prompt",
}, items=[
{ "Tomato sushi",
"prompt": "Strawberry sushi", "Strawberry sushi",
}, "Broccoli sushi",
{ "Asparagus sushi",
"prompt": "Broccoli sushi", "Tea sushi",
}, ],
{ )
"prompt": "Asparagus sushi",
},
{
"prompt": "Tea sushi",
},
], ],
), [
Batch( BatchData(
node_id="2", node_id="2",
data=[ field_name="prompt",
{ items=[
"prompt2": "Ume sushi", "Ume sushi",
}, "Ichigo sushi",
{ "Momo sushi",
"prompt2": "Ichigo sushi", "Mikan sushi",
}, "Cha sushi",
{ ],
"prompt2": "Momo sushi", )
},
{
"prompt2": "Mikan sushi",
},
{
"prompt2": "Cha sushi",
},
], ],
), ]
] )
return batches
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
@ -89,9 +76,7 @@ def simple_batches():
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
conn=db_conn, table_name="graph_executions"
)
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn) batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
@ -188,9 +173,9 @@ def test_handles_errors(mock_invoker: Invoker):
assert all((i in g.errors for i in g.source_prepared_mapping["1"])) assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches): def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process( batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batches=simple_batches, batch=simple_batch,
graph=simple_graph, graph=simple_graph,
) )
assert batch_process_res.batch_id assert batch_process_res.batch_id