Playing with eventservice tests

This commit is contained in:
Brandon Rising 2023-09-08 12:33:12 -04:00
parent ed7deee8f1
commit 732780c376
3 changed files with 54 additions and 4 deletions

View File

@ -104,6 +104,7 @@ class BatchManager(BatchManagerBase):
batch_session.session_id, batch_session.session_id,
changes, changes,
) )
sessions = self.get_sessions(batch_session.batch_id)
batch_process = self.__batch_process_storage.get(batch_session.batch_id) batch_process = self.__batch_process_storage.get(batch_session.batch_id)
if not batch_process.canceled: if not batch_process.canceled:
self.run_batch_process(batch_process.batch_id) self.run_batch_process(batch_process.batch_id)

View File

@ -102,6 +102,7 @@ dependencies = [
"flake8", "flake8",
"Flake8-pyproject", "Flake8-pyproject",
"pytest>6.0.0", "pytest>6.0.0",
"pytest-asyncio",
"pytest-cov", "pytest-cov",
"pytest-datadir", "pytest-datadir",
] ]
@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" }
#=== Begin: PyTest and Coverage #=== Begin: PyTest and Coverage
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--cov-report term --cov-report html --cov-report xml" addopts = "--cov-report term --cov-report html --cov-report xml"
asyncio_mode = "auto"
[tool.coverage.run] [tool.coverage.run]
branch = true branch = true
source = ["invokeai"] source = ["invokeai"]

View File

@ -6,12 +6,14 @@ from .test_nodes import (
create_edge, create_edge,
wait_until, wait_until,
) )
# from fastapi_events.handlers.local import
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.api.events import FastAPIEventService
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
from invokeai.app.services.batch_manager import ( from invokeai.app.services.batch_manager import (
Batch, Batch,
@ -24,7 +26,14 @@ from invokeai.app.services.graph import (
LibraryGraph, LibraryGraph,
) )
import pytest import pytest
import pytest_asyncio
import sqlite3 import sqlite3
import time
from httpx import AsyncClient
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
@pytest.fixture @pytest.fixture
@ -114,19 +123,42 @@ def batch_with_subgraph():
] ]
) )
# @pytest_asyncio.fixture(scope="module")
# def event_loop():
# import asyncio
# try:
# loop = asyncio.get_running_loop()
# except RuntimeError as e:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# # fastapi_events.event_store
# yield loop
# loop.close()
@pytest.fixture(scope="session")
def db_conn():
return sqlite3.connect(sqlite_memory, check_same_thread=False)
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations. # the test invocations.
@pytest.fixture @pytest.fixture(autouse=True)
def mock_services() -> InvocationServices: async def mock_services(db_conn : sqlite3.Connection) -> InvocationServices:
app = FastAPI()
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id,
)
client = AsyncClient(app=app)
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions") graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn) batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
events = FastAPIEventService(event_handler_id)
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
events=TestEventService(), events=events,
logger=None, # type: ignore logger=None, # type: ignore
images=None, # type: ignore images=None, # type: ignore
latents=None, # type: ignore latents=None, # type: ignore
@ -238,6 +270,21 @@ def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgra
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1) # wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batch=batch_with_subgraph,
graph=graph_with_subgraph,
)
mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
sessions = []
attempts = 0
import asyncio
while len(sessions) != 25 and attempts < 20:
batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
await asyncio.sleep(1)
attempts += 1
assert len(sessions) == 25
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch): def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process( batch_process_res = mock_invoker.services.batch_manager.create_batch_process(