mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix batch_manager test
This commit is contained in:
parent
796ff34c8a
commit
29fceb960d
@ -12,7 +12,7 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
||||
from invokeai.app.services.batch_manager import (
|
||||
Batch,
|
||||
BatchManager,
|
||||
@ -36,50 +36,37 @@ def simple_graph():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_batches():
|
||||
batches = [
|
||||
Batch(
|
||||
node_id=1,
|
||||
data=[
|
||||
{
|
||||
"prompt": "Tomato sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Strawberry sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Broccoli sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Asparagus sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Tea sushi",
|
||||
},
|
||||
def simple_batch():
|
||||
return Batch(
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
"Strawberry sushi",
|
||||
"Broccoli sushi",
|
||||
"Asparagus sushi",
|
||||
"Tea sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
Batch(
|
||||
node_id="2",
|
||||
data=[
|
||||
{
|
||||
"prompt2": "Ume sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Ichigo sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Momo sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Mikan sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Cha sushi",
|
||||
},
|
||||
[
|
||||
BatchData(
|
||||
node_id="2",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
"Ichigo sushi",
|
||||
"Momo sushi",
|
||||
"Mikan sushi",
|
||||
"Cha sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
return batches
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
@ -89,9 +76,7 @@ def simple_batches():
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
conn=db_conn, table_name="graph_executions"
|
||||
)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
@ -188,9 +173,9 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||
|
||||
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batches=simple_batches,
|
||||
batch=simple_batch,
|
||||
graph=simple_graph,
|
||||
)
|
||||
assert batch_process_res.batch_id
|
||||
|
Loading…
Reference in New Issue
Block a user