mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(tests): fix tests
This commit is contained in:
parent
f268ea4e39
commit
4a14ee0e01
@ -28,7 +28,7 @@ from invokeai.app.services.shared.graph import (
|
||||
IterateInvocation,
|
||||
LibraryGraph,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .test_invoker import create_edge
|
||||
@ -77,7 +77,6 @@ def mock_services() -> InvocationServices:
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
workflow_image_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@ -94,6 +93,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
|
||||
queue_id=DEFAULT_QUEUE_ID,
|
||||
services=services,
|
||||
graph_execution_state_id="1",
|
||||
workflow=None,
|
||||
)
|
||||
)
|
||||
g.complete(n.id, o)
|
||||
|
@ -82,7 +82,6 @@ def mock_services() -> InvocationServices:
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
workflow_image_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -206,9 +206,9 @@ def test_deny_nodes(patch_rootdir):
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
|
Loading…
Reference in New Issue
Block a user