mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix batch_manager test
This commit is contained in:
parent
796ff34c8a
commit
29fceb960d
@ -12,7 +12,7 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
||||||
from invokeai.app.services.batch_manager import (
|
from invokeai.app.services.batch_manager import (
|
||||||
Batch,
|
Batch,
|
||||||
BatchManager,
|
BatchManager,
|
||||||
@ -36,50 +36,37 @@ def simple_graph():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def simple_batches():
|
def simple_batch():
|
||||||
batches = [
|
return Batch(
|
||||||
Batch(
|
|
||||||
node_id=1,
|
|
||||||
data=[
|
data=[
|
||||||
{
|
[
|
||||||
"prompt": "Tomato sushi",
|
BatchData(
|
||||||
},
|
node_id="1",
|
||||||
{
|
field_name="prompt",
|
||||||
"prompt": "Strawberry sushi",
|
items=[
|
||||||
},
|
"Tomato sushi",
|
||||||
{
|
"Strawberry sushi",
|
||||||
"prompt": "Broccoli sushi",
|
"Broccoli sushi",
|
||||||
},
|
"Asparagus sushi",
|
||||||
{
|
"Tea sushi",
|
||||||
"prompt": "Asparagus sushi",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"prompt": "Tea sushi",
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
),
|
)
|
||||||
Batch(
|
],
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
node_id="2",
|
node_id="2",
|
||||||
data=[
|
field_name="prompt",
|
||||||
{
|
items=[
|
||||||
"prompt2": "Ume sushi",
|
"Ume sushi",
|
||||||
},
|
"Ichigo sushi",
|
||||||
{
|
"Momo sushi",
|
||||||
"prompt2": "Ichigo sushi",
|
"Mikan sushi",
|
||||||
},
|
"Cha sushi",
|
||||||
{
|
],
|
||||||
"prompt2": "Momo sushi",
|
)
|
||||||
},
|
|
||||||
{
|
|
||||||
"prompt2": "Mikan sushi",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"prompt2": "Cha sushi",
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
),
|
|
||||||
]
|
]
|
||||||
return batches
|
)
|
||||||
|
|
||||||
|
|
||||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
@ -89,9 +76,7 @@ def simple_batches():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
conn=db_conn, table_name="graph_executions"
|
|
||||||
)
|
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
@ -188,9 +173,9 @@ def test_handles_errors(mock_invoker: Invoker):
|
|||||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||||
|
|
||||||
|
|
||||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
batches=simple_batches,
|
batch=simple_batch,
|
||||||
graph=simple_graph,
|
graph=simple_graph,
|
||||||
)
|
)
|
||||||
assert batch_process_res.batch_id
|
assert batch_process_res.batch_id
|
||||||
|
Loading…
Reference in New Issue
Block a user