mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Playing with eventservice tests
This commit is contained in:
parent
ed7deee8f1
commit
732780c376
@ -104,6 +104,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
batch_session.session_id,
|
batch_session.session_id,
|
||||||
changes,
|
changes,
|
||||||
)
|
)
|
||||||
|
sessions = self.get_sessions(batch_session.batch_id)
|
||||||
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||||
if not batch_process.canceled:
|
if not batch_process.canceled:
|
||||||
self.run_batch_process(batch_process.batch_id)
|
self.run_batch_process(batch_process.batch_id)
|
||||||
|
@ -102,6 +102,7 @@ dependencies = [
|
|||||||
"flake8",
|
"flake8",
|
||||||
"Flake8-pyproject",
|
"Flake8-pyproject",
|
||||||
"pytest>6.0.0",
|
"pytest>6.0.0",
|
||||||
|
"pytest-asyncio",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"pytest-datadir",
|
"pytest-datadir",
|
||||||
]
|
]
|
||||||
@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" }
|
|||||||
#=== Begin: PyTest and Coverage
|
#=== Begin: PyTest and Coverage
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--cov-report term --cov-report html --cov-report xml"
|
addopts = "--cov-report term --cov-report html --cov-report xml"
|
||||||
|
asyncio_mode = "auto"
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
branch = true
|
branch = true
|
||||||
source = ["invokeai"]
|
source = ["invokeai"]
|
||||||
|
@ -6,12 +6,14 @@ from .test_nodes import (
|
|||||||
create_edge,
|
create_edge,
|
||||||
wait_until,
|
wait_until,
|
||||||
)
|
)
|
||||||
|
# from fastapi_events.handlers.local import
|
||||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
|
from invokeai.app.api.events import FastAPIEventService
|
||||||
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
||||||
from invokeai.app.services.batch_manager import (
|
from invokeai.app.services.batch_manager import (
|
||||||
Batch,
|
Batch,
|
||||||
@ -24,7 +26,14 @@ from invokeai.app.services.graph import (
|
|||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -114,19 +123,42 @@ def batch_with_subgraph():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# @pytest_asyncio.fixture(scope="module")
|
||||||
|
# def event_loop():
|
||||||
|
# import asyncio
|
||||||
|
# try:
|
||||||
|
# loop = asyncio.get_running_loop()
|
||||||
|
# except RuntimeError as e:
|
||||||
|
# loop = asyncio.new_event_loop()
|
||||||
|
# asyncio.set_event_loop(loop)
|
||||||
|
# # fastapi_events.event_store
|
||||||
|
# yield loop
|
||||||
|
# loop.close()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def db_conn():
|
||||||
|
return sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
|
|
||||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||||
# the test invocations.
|
# the test invocations.
|
||||||
@pytest.fixture
|
@pytest.fixture(autouse=True)
|
||||||
def mock_services() -> InvocationServices:
|
async def mock_services(db_conn : sqlite3.Connection) -> InvocationServices:
|
||||||
|
app = FastAPI()
|
||||||
|
event_handler_id: int = id(app)
|
||||||
|
app.add_middleware(
|
||||||
|
EventHandlerASGIMiddleware,
|
||||||
|
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||||
|
middleware_id=event_handler_id,
|
||||||
|
)
|
||||||
|
client = AsyncClient(app=app)
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
|
events = FastAPIEventService(event_handler_id)
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
events=TestEventService(),
|
events=events,
|
||||||
logger=None, # type: ignore
|
logger=None, # type: ignore
|
||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
@ -238,6 +270,21 @@ def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgra
|
|||||||
|
|
||||||
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||||
|
|
||||||
|
async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||||
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
|
batch=batch_with_subgraph,
|
||||||
|
graph=graph_with_subgraph,
|
||||||
|
)
|
||||||
|
mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||||
|
sessions = []
|
||||||
|
attempts = 0
|
||||||
|
import asyncio
|
||||||
|
while len(sessions) != 25 and attempts < 20:
|
||||||
|
batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
|
||||||
|
sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
attempts += 1
|
||||||
|
assert len(sessions) == 25
|
||||||
|
|
||||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
|
Loading…
Reference in New Issue
Block a user