resolve merge conflicts

This commit is contained in:
Lincoln Stein
2023-08-20 15:26:52 -04:00
105 changed files with 473 additions and 453 deletions

View File

@ -90,17 +90,17 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
def test_graph_is_complete(simple_graph, mock_services):
g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = g.next()
assert g.is_complete()
def test_graph_is_not_complete(simple_graph, mock_services):
g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services)
n2 = g.next()
_ = invoke_next(g, mock_services)
_ = g.next()
assert not g.is_complete()
@ -140,11 +140,11 @@ def test_graph_state_collects(mock_services):
graph.add_edge(create_edge("3", "prompt", "4", "item"))
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
n5 = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
n6 = invoke_next(g, mock_services)
assert isinstance(n6[0], CollectInvocation)
@ -195,10 +195,10 @@ def test_graph_executes_depth_first(mock_services):
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
# Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results.
@ -211,17 +211,17 @@ def test_graph_executes_depth_first(mock_services):
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 0
n5 = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 1
n6 = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 1
n7 = invoke_next(g, mock_services)
_ = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 2

View File

@ -17,7 +17,8 @@ from invokeai.app.services.graph import (
IterateInvocation,
)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import *
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image
@ -41,7 +42,7 @@ def test_connections_are_compatible():
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == True
assert result is True
def test_connections_are_incompatible():
@ -52,7 +53,7 @@ def test_connections_are_incompatible():
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
assert result is False
def test_connections_incompatible_with_invalid_fields():
@ -63,14 +64,14 @@ def test_connections_incompatible_with_invalid_fields():
# From field is invalid
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
assert result is False
# To field is invalid
from_field = "image"
to_field = "invalid_field"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
assert result is False
def test_graph_can_add_node():
@ -394,7 +395,7 @@ def test_graph_validates():
e1 = create_edge("1", "image", "2", "image")
g.add_edge(e1)
assert g.is_valid() == True
assert g.is_valid() is True
def test_graph_invalid_if_edges_reference_missing_nodes():
@ -404,7 +405,7 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
e1 = create_edge("1", "image", "2", "image")
g.edges.append(e1)
assert g.is_valid() == False
assert g.is_valid() is False
def test_graph_invalid_if_subgraph_invalid():
@ -419,7 +420,7 @@ def test_graph_invalid_if_subgraph_invalid():
g.nodes[n1.id] = n1
assert g.is_valid() == False
assert g.is_valid() is False
def test_graph_invalid_if_has_cycle():
@ -433,7 +434,7 @@ def test_graph_invalid_if_has_cycle():
g.edges.append(e1)
g.edges.append(e2)
assert g.is_valid() == False
assert g.is_valid() is False
def test_graph_invalid_with_invalid_connection():
@ -445,7 +446,7 @@ def test_graph_invalid_with_invalid_connection():
e1 = create_edge("1", "image", "2", "strength")
g.edges.append(e1)
assert g.is_valid() == False
assert g.is_valid() is False
# TODO: Subgraph operations
@ -536,7 +537,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
g.add_node(n1)
with pytest.raises(NodeNotFoundError):
result = g.get_node("1.2")
_ = g.get_node("1.2")
def test_graph_fails_to_enumerate_non_subgraph_node():
@ -554,7 +555,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
g.add_node(n2)
with pytest.raises(NodeNotFoundError):
result = g.get_node("2.1")
_ = g.get_node("2.1")
def test_graph_gets_networkx_graph():
@ -584,7 +585,7 @@ def test_graph_can_serialize():
g.add_edge(e)
# Not throwing on this line is sufficient
json = g.json()
_ = g.json()
def test_graph_can_deserialize():
@ -612,4 +613,4 @@ def test_graph_can_deserialize():
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
schema = Graph.schema_json(indent=2)
_ = Graph.schema_json(indent=2)

View File

@ -1,9 +1,7 @@
from typing import Any, Callable, Literal, Union
from pydantic import Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.invocation_services import InvocationServices
from pydantic import Field
import pytest
# Define test invocations before importing anything that uses invocations
@ -82,8 +80,9 @@ class PromptCollectionTestInvocation(BaseInvocation):
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Edge, EdgeConnection
# Importing these at the top breaks previous tests
from invokeai.app.services.events import EventServiceBase # noqa: E402
from invokeai.app.services.graph import Edge, EdgeConnection # noqa: E402
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:

View File

@ -1,14 +1,18 @@
import os
import pytest
import sys
from typing import Any
import pytest
from omegaconf import OmegaConf
from pathlib import Path
os.environ["INVOKEAI_ROOT"] = "/tmp"
from invokeai.app.services.config import InvokeAIAppConfig
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
"""This may be overkill since the current tests don't need the root dir to exist"""
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
init1 = OmegaConf.create(
"""
InvokeAI:
@ -47,10 +51,12 @@ InvokeAI:
)
def test_use_init():
def test_use_init(patch_rootdir):
# note that we explicitly set omegaconf dict and argv here
# so that the values aren't read from ~invokeai/invokeai.yaml and
# sys.argv respectively.
from invokeai.app.services.config import InvokeAIAppConfig
conf1 = InvokeAIAppConfig.get_config()
assert conf1
conf1.parse_args(conf=init1, argv=[])
@ -85,14 +91,16 @@ def test_argv_override():
assert conf.outdir == Path("outputs") # this is the default
def test_env_override():
def test_env_override(patch_rootdir):
from invokeai.app.services.config import InvokeAIAppConfig
# argv overrides
conf = InvokeAIAppConfig()
conf.parse_args(conf=init1, argv=["--max_cache=10"])
assert conf.always_use_cpu == False
assert conf.always_use_cpu is False
os.environ["INVOKEAI_always_use_cpu"] = "True"
conf.parse_args(conf=init1, argv=["--max_cache=10"])
assert conf.always_use_cpu == True
assert conf.always_use_cpu is True
# environment variables should be case insensitive
os.environ["InvokeAI_Max_Cache_Size"] = "15"
@ -102,7 +110,7 @@ def test_env_override():
conf = InvokeAIAppConfig()
conf.parse_args(conf=init1, argv=["--no-always_use_cpu", "--max_cache=10"])
assert conf.always_use_cpu == False
assert conf.always_use_cpu is False
assert conf.max_cache_size == 10
conf = InvokeAIAppConfig.get_config(max_cache_size=20)
@ -110,7 +118,9 @@ def test_env_override():
assert conf.max_cache_size == 20
def test_root_resists_cwd():
def test_root_resists_cwd(patch_rootdir):
from invokeai.app.services.config import InvokeAIAppConfig
previous = os.environ["INVOKEAI_ROOT"]
cwd = Path(os.getcwd()).resolve()
@ -125,7 +135,9 @@ def test_root_resists_cwd():
os.chdir(cwd)
def test_type_coercion():
def test_type_coercion(patch_rootdir):
from invokeai.app.services.config import InvokeAIAppConfig
conf = InvokeAIAppConfig().get_config()
conf.parse_args(argv=["--root=/tmp/foobar"])
assert conf.root == Path("/tmp/foobar")