fix(tests): fix tests

This commit is contained in:
psychedelicious 2023-12-02 20:00:07 +11:00
parent f268ea4e39
commit 4a14ee0e01
3 changed files with 5 additions and 6 deletions

View File

@ -28,7 +28,7 @@ from invokeai.app.services.shared.graph import (
IterateInvocation, IterateInvocation,
LibraryGraph, LibraryGraph,
) )
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .test_invoker import create_edge from .test_invoker import create_edge
@ -77,7 +77,6 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore session_queue=None, # type: ignore
urls=None, # type: ignore urls=None, # type: ignore
workflow_records=None, # type: ignore workflow_records=None, # type: ignore
workflow_image_records=None, # type: ignore
) )
@ -94,6 +93,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
queue_id=DEFAULT_QUEUE_ID, queue_id=DEFAULT_QUEUE_ID,
services=services, services=services,
graph_execution_state_id="1", graph_execution_state_id="1",
workflow=None,
) )
) )
g.complete(n.id, o) g.complete(n.id, o)

View File

@ -82,7 +82,6 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore session_queue=None, # type: ignore
urls=None, # type: ignore urls=None, # type: ignore
workflow_records=None, # type: ignore workflow_records=None, # type: ignore
workflow_image_records=None, # type: ignore
) )

View File

@ -206,9 +206,9 @@ def test_deny_nodes(patch_rootdir):
# confirm invocations union will not have denied nodes # confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations() all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1 has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1 has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1 has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
assert has_integer assert has_integer
assert has_string assert has_string