From 4a14ee0e011f96cf52fc7843166b5370060d1dcf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 2 Dec 2023 20:00:07 +1100 Subject: [PATCH] fix(tests): fix tests --- tests/nodes/test_graph_execution_state.py | 4 ++-- tests/nodes/test_invoker.py | 1 - tests/test_config.py | 6 +++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index cc40970ace..203c470469 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -28,7 +28,7 @@ from invokeai.app.services.shared.graph import ( IterateInvocation, LibraryGraph, ) -from invokeai.app.services.shared.sqlite import SqliteDatabase +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.util.logging import InvokeAILogger from .test_invoker import create_edge @@ -77,7 +77,6 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore - workflow_image_records=None, # type: ignore ) @@ -94,6 +93,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B queue_id=DEFAULT_QUEUE_ID, services=services, graph_execution_state_id="1", + workflow=None, ) ) g.complete(n.id, o) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index f79c205232..88186d5448 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -82,7 +82,6 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore - workflow_image_records=None, # type: ignore ) diff --git a/tests/test_config.py b/tests/test_config.py index 6d76872a0d..740a4866dd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -206,9 +206,9 @@ def test_deny_nodes(patch_rootdir): # confirm invocations union will not have denied nodes all_invocations = BaseInvocation.get_invocations() - has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1 - has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1 - has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1 + has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 + has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 assert has_integer assert has_string