fix(test): fix tests

This commit is contained in:
psychedelicious 2023-09-20 18:40:40 +10:00
parent c1aa2b82eb
commit bfed08673a
2 changed files with 10 additions and 9 deletions

View File

@ -3,8 +3,6 @@ import threading
import pytest import pytest
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
# This import must happen before other invoke imports or test in other files(!!) break # This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
PromptCollectionTestInvocation, PromptCollectionTestInvocation,
@ -17,7 +15,9 @@ import sqlite3
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.invocation_stats import InvocationStatsService
@ -61,7 +61,7 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=None, # type: ignore configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore session_queue=None, # type: ignore
session_processor=None, # type: ignore session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(), # type: ignore invocation_cache=MemoryInvocationCache(), # type: ignore

View File

@ -4,6 +4,8 @@ import threading
import pytest import pytest
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
# This import must happen before other invoke imports or test in other files(!!) break # This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
ErrorInvocation, ErrorInvocation,
@ -14,7 +16,6 @@ from .test_nodes import ( # isort: split
wait_until, wait_until,
) )
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@ -70,10 +71,10 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),
configuration=InvokeAIAppConfig(), configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore session_queue=None, # type: ignore
session_processor=None, # type: ignore session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(), invocation_cache=MemoryInvocationCache(max_cache_size=0),
) )
@ -102,7 +103,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") # @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) invocation_id = mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g)
assert invocation_id is not None assert invocation_id is not None
def has_executed_any(g: GraphExecutionState): def has_executed_any(g: GraphExecutionState):
@ -120,7 +121,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
def test_can_invoke_all(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke( invocation_id = mock_invoker.invoke(
queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True
) )
assert invocation_id is not None assert invocation_id is not None
@ -140,7 +141,7 @@ def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state() g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1")) g.graph.add_node(ErrorInvocation(id="1"))
mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True)
def has_executed_all(g: GraphExecutionState): def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)