fix(test): fix tests

This commit is contained in:
psychedelicious 2023-09-20 18:40:40 +10:00
parent c1aa2b82eb
commit bfed08673a
2 changed files with 10 additions and 9 deletions

View File

@ -3,8 +3,6 @@ import threading
import pytest
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
PromptCollectionTestInvocation,
@ -17,7 +15,9 @@ import sqlite3
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
@ -61,7 +61,7 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager,
performance_statistics=InvocationStatsService(graph_execution_manager),
processor=DefaultInvocationProcessor(),
configuration=None, # type: ignore
configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore
session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(), # type: ignore

View File

@ -4,6 +4,8 @@ import threading
import pytest
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
ErrorInvocation,
@ -14,7 +16,6 @@ from .test_nodes import ( # isort: split
wait_until,
)
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@ -70,10 +71,10 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
configuration=InvokeAIAppConfig(),
configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore
session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(),
invocation_cache=MemoryInvocationCache(max_cache_size=0),
)
@ -102,7 +103,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g)
invocation_id = mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g)
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
@ -120,7 +121,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(
queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True
queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True
)
assert invocation_id is not None
@ -140,7 +141,7 @@ def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1"))
mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True)
mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True)
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)