diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index d6432d5fc7..671a323e01 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -104,6 +104,7 @@ class BatchManager(BatchManagerBase): batch_session.session_id, changes, ) + sessions = self.get_sessions(batch_session.batch_id) batch_process = self.__batch_process_storage.get(batch_session.batch_id) if not batch_process.canceled: self.run_batch_process(batch_process.batch_id) diff --git a/pyproject.toml b/pyproject.toml index 4b06944b33..b66ae6fbef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ dependencies = [ "flake8", "Flake8-pyproject", "pytest>6.0.0", + "pytest-asyncio", "pytest-cov", "pytest-datadir", ] @@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" } #=== Begin: PyTest and Coverage [tool.pytest.ini_options] addopts = "--cov-report term --cov-report html --cov-report xml" +asyncio_mode = "auto" [tool.coverage.run] branch = true source = ["invokeai"] diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 905675115a..e519d5a51d 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -6,12 +6,14 @@ from .test_nodes import ( create_edge, wait_until, ) +# from fastapi_events.handlers.local import from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invoker import Invoker from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats import InvocationStatsService +from invokeai.app.api.events import FastAPIEventService from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage from invokeai.app.services.batch_manager import ( Batch, @@ -24,7 +26,14 @@ from invokeai.app.services.graph import ( LibraryGraph, ) import pytest +import pytest_asyncio import sqlite3 +import time + +from httpx import AsyncClient +from fastapi import FastAPI +from fastapi_events.handlers.local import local_handler +from fastapi_events.middleware import EventHandlerASGIMiddleware @pytest.fixture @@ -114,19 +123,42 @@ def batch_with_subgraph(): ] ) +# @pytest_asyncio.fixture(scope="module") +# def event_loop(): +# import asyncio +# try: +# loop = asyncio.get_running_loop() +# except RuntimeError as e: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# # fastapi_events.event_store +# yield loop +# loop.close() + +@pytest.fixture(scope="session") +def db_conn(): + return sqlite3.connect(sqlite_memory, check_same_thread=False) # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. -@pytest.fixture -def mock_services() -> InvocationServices: +@pytest.fixture(autouse=True) +async def mock_services(db_conn : sqlite3.Connection) -> InvocationServices: + app = FastAPI() + event_handler_id: int = id(app) + app.add_middleware( + EventHandlerASGIMiddleware, + handlers=[local_handler], # TODO: consider doing this in services to support different configurations + middleware_id=event_handler_id, + ) + client = AsyncClient(app=app) # NOTE: none of these are actually called by the test invocations - db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions") batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn) + events = FastAPIEventService(event_handler_id) return InvocationServices( model_manager=None, # type: ignore - events=TestEventService(), + events=events, logger=None, # type: ignore images=None, # type: ignore latents=None, # type: ignore @@ -238,6 +270,21 @@ def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgra # wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1) +async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph): + batch_process_res = mock_invoker.services.batch_manager.create_batch_process( + batch=batch_with_subgraph, + graph=graph_with_subgraph, + ) + mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id) + sessions = [] + attempts = 0 + import asyncio + while len(sessions) != 25 and attempts < 20: + batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id) + sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id) + await asyncio.sleep(1) + attempts += 1 + assert len(sessions) == 25 def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch): batch_process_res = mock_invoker.services.batch_manager.create_batch_process(