|
|
|
@ -6,12 +6,14 @@ from .test_nodes import (
|
|
|
|
|
create_edge,
|
|
|
|
|
wait_until,
|
|
|
|
|
)
|
|
|
|
|
# from fastapi_events.handlers.local import
|
|
|
|
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
|
|
|
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
|
|
|
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|
|
|
|
from invokeai.app.services.invoker import Invoker
|
|
|
|
|
from invokeai.app.services.invocation_services import InvocationServices
|
|
|
|
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
|
|
|
|
from invokeai.app.api.events import FastAPIEventService
|
|
|
|
|
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
|
|
|
|
from invokeai.app.services.batch_manager import (
|
|
|
|
|
Batch,
|
|
|
|
@ -24,7 +26,14 @@ from invokeai.app.services.graph import (
|
|
|
|
|
LibraryGraph,
|
|
|
|
|
)
|
|
|
|
|
import pytest
|
|
|
|
|
import pytest_asyncio
|
|
|
|
|
import sqlite3
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
from httpx import AsyncClient
|
|
|
|
|
from fastapi import FastAPI
|
|
|
|
|
from fastapi_events.handlers.local import local_handler
|
|
|
|
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
@ -114,19 +123,42 @@ def batch_with_subgraph():
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# @pytest_asyncio.fixture(scope="module")
|
|
|
|
|
# def event_loop():
|
|
|
|
|
# import asyncio
|
|
|
|
|
# try:
|
|
|
|
|
# loop = asyncio.get_running_loop()
|
|
|
|
|
# except RuntimeError as e:
|
|
|
|
|
# loop = asyncio.new_event_loop()
|
|
|
|
|
# asyncio.set_event_loop(loop)
|
|
|
|
|
# # fastapi_events.event_store
|
|
|
|
|
# yield loop
|
|
|
|
|
# loop.close()
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
|
|
|
def db_conn():
|
|
|
|
|
return sqlite3.connect(sqlite_memory, check_same_thread=False)
|
|
|
|
|
|
|
|
|
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
|
|
|
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
|
|
|
|
# the test invocations.
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def mock_services() -> InvocationServices:
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
async def mock_services(db_conn : sqlite3.Connection) -> InvocationServices:
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
event_handler_id: int = id(app)
|
|
|
|
|
app.add_middleware(
|
|
|
|
|
EventHandlerASGIMiddleware,
|
|
|
|
|
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
|
|
|
|
middleware_id=event_handler_id,
|
|
|
|
|
)
|
|
|
|
|
client = AsyncClient(app=app)
|
|
|
|
|
# NOTE: none of these are actually called by the test invocations
|
|
|
|
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
|
|
|
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
|
|
|
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
|
|
|
|
events = FastAPIEventService(event_handler_id)
|
|
|
|
|
return InvocationServices(
|
|
|
|
|
model_manager=None, # type: ignore
|
|
|
|
|
events=TestEventService(),
|
|
|
|
|
events=events,
|
|
|
|
|
logger=None, # type: ignore
|
|
|
|
|
images=None, # type: ignore
|
|
|
|
|
latents=None, # type: ignore
|
|
|
|
@ -238,6 +270,21 @@ def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgra
|
|
|
|
|
|
|
|
|
|
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
|
|
|
|
|
|
|
|
|
async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
|
|
|
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
|
|
|
|
batch=batch_with_subgraph,
|
|
|
|
|
graph=graph_with_subgraph,
|
|
|
|
|
)
|
|
|
|
|
mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
|
|
|
|
sessions = []
|
|
|
|
|
attempts = 0
|
|
|
|
|
import asyncio
|
|
|
|
|
while len(sessions) != 25 and attempts < 20:
|
|
|
|
|
batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
|
|
|
|
|
sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
|
|
|
|
|
await asyncio.sleep(1)
|
|
|
|
|
attempts += 1
|
|
|
|
|
assert len(sessions) == 25
|
|
|
|
|
|
|
|
|
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
|
|
|
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
|
|
|
|