mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add test for batch manager
This commit is contained in:
parent
69f541075c
commit
846e52f2ea
@ -12,6 +12,11 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
|
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||||
|
from invokeai.app.services.batch_manager import (
|
||||||
|
Batch,
|
||||||
|
BatchManager,
|
||||||
|
)
|
||||||
from invokeai.app.services.graph import (
|
from invokeai.app.services.graph import (
|
||||||
Graph,
|
Graph,
|
||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
@ -29,6 +34,53 @@ def simple_graph():
|
|||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def simple_batches():
|
||||||
|
batches = [
|
||||||
|
Batch(
|
||||||
|
node_id=1,
|
||||||
|
data= [
|
||||||
|
{
|
||||||
|
"prompt": "Tomato sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "Strawberry sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "Broccoli sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "Asparagus sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "Tea sushi",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
Batch(
|
||||||
|
node_id="2",
|
||||||
|
data= [
|
||||||
|
{
|
||||||
|
"prompt2": "Ume sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt2": "Ichigo sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt2": "Momo sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt2": "Mikan sushi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt2": "Cha sushi",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|
||||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||||
# the test invocations.
|
# the test invocations.
|
||||||
@ -38,13 +90,14 @@ def mock_services() -> InvocationServices:
|
|||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=sqlite_memory, table_name="graph_executions"
|
filename=sqlite_memory, table_name="graph_executions"
|
||||||
)
|
)
|
||||||
|
batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory)
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
events=TestEventService(),
|
events=TestEventService(),
|
||||||
logger=None, # type: ignore
|
logger=None, # type: ignore
|
||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
batch_manager=None, # type: ignore
|
batch_manager=BatchManager(batch_manager_storage),
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
@ -131,3 +184,13 @@ def test_handles_errors(mock_invoker: Invoker):
|
|||||||
assert g.is_complete()
|
assert g.is_complete()
|
||||||
|
|
||||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||||
|
|
||||||
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
||||||
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
|
batches=simple_batches,
|
||||||
|
graph=simple_graph,
|
||||||
|
)
|
||||||
|
assert batch_process_res.batch_id
|
||||||
|
assert len(batch_process_res.session_ids) == 25
|
||||||
|
for session in batch_process_res.session_ids:
|
||||||
|
assert mock_invoker.services.graph_execution_manager.get(session)
|
||||||
|
@ -54,6 +54,7 @@ class TextToImageTestInvocation(BaseInvocation):
|
|||||||
type: Literal["test_text_to_image"] = "test_text_to_image"
|
type: Literal["test_text_to_image"] = "test_text_to_image"
|
||||||
|
|
||||||
prompt: str = Field(default="")
|
prompt: str = Field(default="")
|
||||||
|
prompt2: str = Field(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
Loading…
Reference in New Issue
Block a user