mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Playing with eventservice tests
This commit is contained in:
parent
ed7deee8f1
commit
732780c376
@ -104,6 +104,7 @@ class BatchManager(BatchManagerBase):
|
||||
batch_session.session_id,
|
||||
changes,
|
||||
)
|
||||
sessions = self.get_sessions(batch_session.batch_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||
if not batch_process.canceled:
|
||||
self.run_batch_process(batch_process.batch_id)
|
||||
|
@ -102,6 +102,7 @@ dependencies = [
|
||||
"flake8",
|
||||
"Flake8-pyproject",
|
||||
"pytest>6.0.0",
|
||||
"pytest-asyncio",
|
||||
"pytest-cov",
|
||||
"pytest-datadir",
|
||||
]
|
||||
@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" }
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml"
|
||||
asyncio_mode = "auto"
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["invokeai"]
|
||||
|
@ -6,12 +6,14 @@ from .test_nodes import (
|
||||
create_edge,
|
||||
wait_until,
|
||||
)
|
||||
# from fastapi_events.handlers.local import
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.api.events import FastAPIEventService
|
||||
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
||||
from invokeai.app.services.batch_manager import (
|
||||
Batch,
|
||||
@ -24,7 +26,14 @@ from invokeai.app.services.graph import (
|
||||
LibraryGraph,
|
||||
)
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -114,19 +123,42 @@ def batch_with_subgraph():
|
||||
]
|
||||
)
|
||||
|
||||
# @pytest_asyncio.fixture(scope="module")
|
||||
# def event_loop():
|
||||
# import asyncio
|
||||
# try:
|
||||
# loop = asyncio.get_running_loop()
|
||||
# except RuntimeError as e:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# # fastapi_events.event_store
|
||||
# yield loop
|
||||
# loop.close()
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db_conn():
|
||||
return sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_services(db_conn : sqlite3.Connection) -> InvocationServices:
|
||||
app = FastAPI()
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
client = AsyncClient(app=app)
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
events=events,
|
||||
logger=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
@ -238,6 +270,21 @@ def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgra
|
||||
|
||||
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||
|
||||
async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batch=batch_with_subgraph,
|
||||
graph=graph_with_subgraph,
|
||||
)
|
||||
mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||
sessions = []
|
||||
attempts = 0
|
||||
import asyncio
|
||||
while len(sessions) != 25 and attempts < 20:
|
||||
batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
|
||||
sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
|
||||
await asyncio.sleep(1)
|
||||
attempts += 1
|
||||
assert len(sessions) == 25
|
||||
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
|
Loading…
Reference in New Issue
Block a user