mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add test for batch manager
This commit is contained in:
parent
69f541075c
commit
846e52f2ea
@ -12,6 +12,11 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from invokeai.app.services.batch_manager import (
|
||||
Batch,
|
||||
BatchManager,
|
||||
)
|
||||
from invokeai.app.services.graph import (
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
@ -29,6 +34,53 @@ def simple_graph():
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_batches():
|
||||
batches = [
|
||||
Batch(
|
||||
node_id=1,
|
||||
data= [
|
||||
{
|
||||
"prompt": "Tomato sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Strawberry sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Broccoli sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Asparagus sushi",
|
||||
},
|
||||
{
|
||||
"prompt": "Tea sushi",
|
||||
},
|
||||
]
|
||||
),
|
||||
Batch(
|
||||
node_id="2",
|
||||
data= [
|
||||
{
|
||||
"prompt2": "Ume sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Ichigo sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Momo sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Mikan sushi",
|
||||
},
|
||||
{
|
||||
"prompt2": "Cha sushi",
|
||||
},
|
||||
]
|
||||
)
|
||||
]
|
||||
return batches
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@ -38,13 +90,14 @@ def mock_services() -> InvocationServices:
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=sqlite_memory, table_name="graph_executions"
|
||||
)
|
||||
batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory)
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
logger=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
batch_manager=None, # type: ignore
|
||||
batch_manager=BatchManager(batch_manager_storage),
|
||||
boards=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
@ -131,3 +184,13 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
assert g.is_complete()
|
||||
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batches=simple_batches,
|
||||
graph=simple_graph,
|
||||
)
|
||||
assert batch_process_res.batch_id
|
||||
assert len(batch_process_res.session_ids) == 25
|
||||
for session in batch_process_res.session_ids:
|
||||
assert mock_invoker.services.graph_execution_manager.get(session)
|
||||
|
@ -54,6 +54,7 @@ class TextToImageTestInvocation(BaseInvocation):
|
||||
type: Literal["test_text_to_image"] = "test_text_to_image"
|
||||
|
||||
prompt: str = Field(default="")
|
||||
prompt2: str = Field(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
Loading…
Reference in New Issue
Block a user