Fix batch_manager test

This commit is contained in:
Brandon Rising 2023-08-16 15:33:15 -04:00
parent 796ff34c8a
commit 29fceb960d

View File

@ -12,7 +12,7 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
from invokeai.app.services.batch_manager import (
Batch,
BatchManager,
@ -36,50 +36,37 @@ def simple_graph():
@pytest.fixture
def simple_batches():
batches = [
Batch(
node_id=1,
data=[
{
"prompt": "Tomato sushi",
},
{
"prompt": "Strawberry sushi",
},
{
"prompt": "Broccoli sushi",
},
{
"prompt": "Asparagus sushi",
},
{
"prompt": "Tea sushi",
},
def simple_batch():
return Batch(
data=[
[
BatchData(
node_id="1",
field_name="prompt",
items=[
"Tomato sushi",
"Strawberry sushi",
"Broccoli sushi",
"Asparagus sushi",
"Tea sushi",
],
)
],
),
Batch(
node_id="2",
data=[
{
"prompt2": "Ume sushi",
},
{
"prompt2": "Ichigo sushi",
},
{
"prompt2": "Momo sushi",
},
{
"prompt2": "Mikan sushi",
},
{
"prompt2": "Cha sushi",
},
[
BatchData(
node_id="2",
field_name="prompt",
items=[
"Ume sushi",
"Ichigo sushi",
"Momo sushi",
"Mikan sushi",
"Cha sushi",
],
)
],
),
]
return batches
]
)
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
@ -89,9 +76,7 @@ def simple_batches():
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
conn=db_conn, table_name="graph_executions"
)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
return InvocationServices(
model_manager=None, # type: ignore
@ -188,9 +173,9 @@ def test_handles_errors(mock_invoker: Invoker):
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batches=simple_batches,
batch=simple_batch,
graph=simple_graph,
)
assert batch_process_res.batch_id