mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve merge conflicts
This commit is contained in:
@ -90,17 +90,17 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
|
||||
|
||||
def test_graph_is_complete(simple_graph, mock_services):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n3 = g.next()
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = g.next()
|
||||
|
||||
assert g.is_complete()
|
||||
|
||||
|
||||
def test_graph_is_not_complete(simple_graph, mock_services):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = g.next()
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = g.next()
|
||||
|
||||
assert not g.is_complete()
|
||||
|
||||
@ -140,11 +140,11 @@ def test_graph_state_collects(mock_services):
|
||||
graph.add_edge(create_edge("3", "prompt", "4", "item"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n3 = invoke_next(g, mock_services)
|
||||
n4 = invoke_next(g, mock_services)
|
||||
n5 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
n6 = invoke_next(g, mock_services)
|
||||
|
||||
assert isinstance(n6[0], CollectInvocation)
|
||||
@ -195,10 +195,10 @@ def test_graph_executes_depth_first(mock_services):
|
||||
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n3 = invoke_next(g, mock_services)
|
||||
n4 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
|
||||
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||
# Instead, we must count the number of results.
|
||||
@ -211,17 +211,17 @@ def test_graph_executes_depth_first(mock_services):
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 0
|
||||
|
||||
n5 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
n6 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
n7 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 2
|
||||
|
@ -17,7 +17,8 @@ from invokeai.app.services.graph import (
|
||||
IterateInvocation,
|
||||
)
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.invocations.image import *
|
||||
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import IntegerInvocation
|
||||
from invokeai.app.services.default_graphs import create_text_to_image
|
||||
@ -41,7 +42,7 @@ def test_connections_are_compatible():
|
||||
|
||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||
|
||||
assert result == True
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_connections_are_incompatible():
|
||||
@ -52,7 +53,7 @@ def test_connections_are_incompatible():
|
||||
|
||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||
|
||||
assert result == False
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_connections_incompatible_with_invalid_fields():
|
||||
@ -63,14 +64,14 @@ def test_connections_incompatible_with_invalid_fields():
|
||||
|
||||
# From field is invalid
|
||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||
assert result == False
|
||||
assert result is False
|
||||
|
||||
# To field is invalid
|
||||
from_field = "image"
|
||||
to_field = "invalid_field"
|
||||
|
||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||
assert result == False
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_graph_can_add_node():
|
||||
@ -394,7 +395,7 @@ def test_graph_validates():
|
||||
e1 = create_edge("1", "image", "2", "image")
|
||||
g.add_edge(e1)
|
||||
|
||||
assert g.is_valid() == True
|
||||
assert g.is_valid() is True
|
||||
|
||||
|
||||
def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||
@ -404,7 +405,7 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||
e1 = create_edge("1", "image", "2", "image")
|
||||
g.edges.append(e1)
|
||||
|
||||
assert g.is_valid() == False
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_subgraph_invalid():
|
||||
@ -419,7 +420,7 @@ def test_graph_invalid_if_subgraph_invalid():
|
||||
|
||||
g.nodes[n1.id] = n1
|
||||
|
||||
assert g.is_valid() == False
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_has_cycle():
|
||||
@ -433,7 +434,7 @@ def test_graph_invalid_if_has_cycle():
|
||||
g.edges.append(e1)
|
||||
g.edges.append(e2)
|
||||
|
||||
assert g.is_valid() == False
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_with_invalid_connection():
|
||||
@ -445,7 +446,7 @@ def test_graph_invalid_with_invalid_connection():
|
||||
e1 = create_edge("1", "image", "2", "strength")
|
||||
g.edges.append(e1)
|
||||
|
||||
assert g.is_valid() == False
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
# TODO: Subgraph operations
|
||||
@ -536,7 +537,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g.add_node(n1)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
result = g.get_node("1.2")
|
||||
_ = g.get_node("1.2")
|
||||
|
||||
|
||||
def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
@ -554,7 +555,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
g.add_node(n2)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
result = g.get_node("2.1")
|
||||
_ = g.get_node("2.1")
|
||||
|
||||
|
||||
def test_graph_gets_networkx_graph():
|
||||
@ -584,7 +585,7 @@ def test_graph_can_serialize():
|
||||
g.add_edge(e)
|
||||
|
||||
# Not throwing on this line is sufficient
|
||||
json = g.json()
|
||||
_ = g.json()
|
||||
|
||||
|
||||
def test_graph_can_deserialize():
|
||||
@ -612,4 +613,4 @@ def test_graph_can_deserialize():
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||
schema = Graph.schema_json(indent=2)
|
||||
_ = Graph.schema_json(indent=2)
|
||||
|
@ -1,9 +1,7 @@
|
||||
from typing import Any, Callable, Literal, Union
|
||||
from pydantic import Field
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from pydantic import Field
|
||||
import pytest
|
||||
|
||||
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
@ -82,8 +80,9 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Edge, EdgeConnection
|
||||
# Importing these at the top breaks previous tests
|
||||
from invokeai.app.services.events import EventServiceBase # noqa: E402
|
||||
from invokeai.app.services.graph import Edge, EdgeConnection # noqa: E402
|
||||
|
||||
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
|
@ -1,14 +1,18 @@
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["INVOKEAI_ROOT"] = "/tmp"
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
@pytest.fixture
|
||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||
"""This may be overkill since the current tests don't need the root dir to exist"""
|
||||
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
|
||||
|
||||
|
||||
init1 = OmegaConf.create(
|
||||
"""
|
||||
InvokeAI:
|
||||
@ -47,10 +51,12 @@ InvokeAI:
|
||||
)
|
||||
|
||||
|
||||
def test_use_init():
|
||||
def test_use_init(patch_rootdir):
|
||||
# note that we explicitly set omegaconf dict and argv here
|
||||
# so that the values aren't read from ~invokeai/invokeai.yaml and
|
||||
# sys.argv respectively.
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
conf1 = InvokeAIAppConfig.get_config()
|
||||
assert conf1
|
||||
conf1.parse_args(conf=init1, argv=[])
|
||||
@ -85,14 +91,16 @@ def test_argv_override():
|
||||
assert conf.outdir == Path("outputs") # this is the default
|
||||
|
||||
|
||||
def test_env_override():
|
||||
def test_env_override(patch_rootdir):
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# argv overrides
|
||||
conf = InvokeAIAppConfig()
|
||||
conf.parse_args(conf=init1, argv=["--max_cache=10"])
|
||||
assert conf.always_use_cpu == False
|
||||
assert conf.always_use_cpu is False
|
||||
os.environ["INVOKEAI_always_use_cpu"] = "True"
|
||||
conf.parse_args(conf=init1, argv=["--max_cache=10"])
|
||||
assert conf.always_use_cpu == True
|
||||
assert conf.always_use_cpu is True
|
||||
|
||||
# environment variables should be case insensitive
|
||||
os.environ["InvokeAI_Max_Cache_Size"] = "15"
|
||||
@ -102,7 +110,7 @@ def test_env_override():
|
||||
|
||||
conf = InvokeAIAppConfig()
|
||||
conf.parse_args(conf=init1, argv=["--no-always_use_cpu", "--max_cache=10"])
|
||||
assert conf.always_use_cpu == False
|
||||
assert conf.always_use_cpu is False
|
||||
assert conf.max_cache_size == 10
|
||||
|
||||
conf = InvokeAIAppConfig.get_config(max_cache_size=20)
|
||||
@ -110,7 +118,9 @@ def test_env_override():
|
||||
assert conf.max_cache_size == 20
|
||||
|
||||
|
||||
def test_root_resists_cwd():
|
||||
def test_root_resists_cwd(patch_rootdir):
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
previous = os.environ["INVOKEAI_ROOT"]
|
||||
cwd = Path(os.getcwd()).resolve()
|
||||
|
||||
@ -125,7 +135,9 @@ def test_root_resists_cwd():
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
def test_type_coercion():
|
||||
def test_type_coercion(patch_rootdir):
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
conf = InvokeAIAppConfig().get_config()
|
||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
||||
assert conf.root == Path("/tmp/foobar")
|
||||
|
Reference in New Issue
Block a user