fix unit tests

This commit is contained in:
Lincoln Stein 2023-08-02 18:31:10 -04:00
parent 8fc75a71ee
commit 3fc789a7ee
2 changed files with 12 additions and 6 deletions

View File

@ -16,6 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.graph import ( from invokeai.app.services.graph import (
Graph, Graph,
CollectInvocation, CollectInvocation,
@ -41,6 +42,9 @@ def simple_graph():
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions"
)
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
events=TestEventService(), events=TestEventService(),
@ -51,9 +55,8 @@ def mock_services() -> InvocationServices:
board_images=None, # type: ignore board_images=None, # type: ignore
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=graph_execution_manager,
filename=sqlite_memory, table_name="graph_executions" performance_statistics=InvocationStatsService(graph_execution_manager),
),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=None, # type: ignore configuration=None, # type: ignore
) )

View File

@ -11,6 +11,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.graph import ( from invokeai.app.services.graph import (
Graph, Graph,
GraphExecutionState, GraphExecutionState,
@ -34,6 +35,9 @@ def simple_graph():
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions"
)
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
events=TestEventService(), events=TestEventService(),
@ -44,10 +48,9 @@ def mock_services() -> InvocationServices:
board_images=None, # type: ignore board_images=None, # type: ignore
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=graph_execution_manager,
filename=sqlite_memory, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
configuration=None, # type: ignore configuration=None, # type: ignore
) )