mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180)
* Add latent to latent (img2img equivalent) Fix a CLI bug with multiple links per node * Using "latents" instead of "latent" * [nodes] In-progress implementation of graph library * Add linking to CLI for graph nodes (still broken) * Fix subgraph execution, fix subgraph linking in CLI * Fix LatentsToLatents
This commit is contained in:
parent
024fd54d0b
commit
23d65e7162
@ -3,12 +3,14 @@
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -69,6 +71,9 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@ -76,6 +81,8 @@ class ApiDependencies:
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
|
@ -7,11 +7,40 @@ from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..models.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
@ -36,30 +65,26 @@ def add_parsers(
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
|
||||
class CliContext:
|
||||
@ -67,17 +92,38 @@ class CliContext:
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
@ -13,17 +13,20 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@ -58,7 +61,7 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@ -98,14 +153,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
)
|
||||
for field in matching_fields
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@ -158,6 +213,9 @@ def invoke_cli():
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@ -165,9 +223,12 @@ def invoke_cli():
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
@ -185,11 +246,12 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(history)
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@ -205,8 +267,24 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@ -217,6 +295,7 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@ -229,7 +308,7 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
from_node, command.command, context
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
@ -242,7 +321,7 @@ def invoke_cli():
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@ -256,12 +335,14 @@ def invoke_cli():
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
)
|
||||
)
|
||||
|
||||
@ -270,10 +351,10 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.session.add_edge(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
@ -285,7 +366,7 @@ def invoke_cli():
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
@ -313,6 +318,56 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
# Latent to image
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
18
invokeai/app/invocations/params.py
Normal file
18
invokeai/app/invocations/params.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
56
invokeai/app/services/default_graphs.py
Normal file
56
invokeai/app/services/default_graphs.py
Normal file
@ -0,0 +1,56 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': TextToLatentsInvocation(id='4'),
|
||||
'5': LatentsToImageInvocation(id='5')
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
@ -17,7 +17,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@ -283,7 +283,8 @@ class Graph(BaseModel):
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self._validate_edge(edge)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
@ -354,7 +355,7 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
@ -362,54 +363,53 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@ -733,7 +733,7 @@ class Graph(BaseModel):
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
@ -946,11 +947,11 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph()
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
@ -962,7 +963,7 @@ class GraphExecutionState(BaseModel):
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
return iterators
|
||||
|
||||
@ -1098,7 +1099,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
try:
|
||||
self.graph._validate_edge(edge)
|
||||
except InvalidEdgeError:
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
@ -1144,4 +1147,52 @@ class GraphExecutionState(BaseModel):
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
@ -19,6 +19,7 @@ class InvocationServices:
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
@ -29,6 +30,7 @@ class InvocationServices:
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
@ -38,6 +40,7 @@ class InvocationServices:
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
|
@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
@ -45,34 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._conn.commit()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Union[T, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
@ -80,19 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||
(per_page, page * per_page),
|
||||
@ -103,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
@ -115,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
@ -130,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
|
@ -40,7 +40,7 @@ dependencies = [
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==1.0.5",
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.14",
|
||||
"diffusers[torch]==0.14",
|
||||
"dnspython==2.2.1",
|
||||
"einops",
|
||||
"eventlet",
|
||||
|
@ -7,7 +7,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
import pytest
|
||||
|
||||
|
||||
@ -28,6 +28,9 @@ def mock_services():
|
||||
images = None, # type: ignore
|
||||
latents = None, # type: ignore
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=sqlite_memory, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor(),
|
||||
restoration = None, # type: ignore
|
||||
|
@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
import pytest
|
||||
|
||||
|
||||
@ -26,6 +26,9 @@ def mock_services() -> InvocationServices:
|
||||
images = None, # type: ignore
|
||||
latents = None, # type: ignore
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=sqlite_memory, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor(),
|
||||
restoration = None, # type: ignore
|
||||
|
@ -1,9 +1,11 @@
|
||||
from invokeai.app.invocations.image import *
|
||||
|
||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||
from invokeai.app.invocations.image import *
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.params import ParamIntInvocation
|
||||
from invokeai.app.services.default_graphs import create_text_to_image
|
||||
import pytest
|
||||
|
||||
|
||||
@ -417,6 +419,66 @@ def test_graph_gets_subgraph_node():
|
||||
assert result.id == '1'
|
||||
assert result == n1_1
|
||||
|
||||
|
||||
def test_graph_expands_subgraph():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id = "1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = AddInvocation(id = "1", a = 1, b = 2)
|
||||
n1_2 = SubtractInvocation(id = "2", b = 3)
|
||||
n1.graph.add_node(n1_1)
|
||||
n1.graph.add_node(n1_2)
|
||||
n1.graph.add_edge(create_edge("1","a","2","a"))
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = AddInvocation(id = "2", b = 5)
|
||||
g.add_node(n2)
|
||||
g.add_edge(create_edge("1.2","a","2","a"))
|
||||
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == set(['1.1', '1.2', '2'])
|
||||
assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')])
|
||||
|
||||
|
||||
def test_graph_subgraph_t2i():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id = "1")
|
||||
|
||||
# Get text to image default graph
|
||||
lg = create_text_to_image()
|
||||
n1.graph = lg.graph
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = ParamIntInvocation(id = "2", a = 512)
|
||||
n3 = ParamIntInvocation(id = "3", a = 256)
|
||||
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
|
||||
g.add_edge(create_edge("2","a","1.width","a"))
|
||||
g.add_edge(create_edge("3","a","1.height","a"))
|
||||
|
||||
n4 = ShowImageInvocation(id = "4")
|
||||
g.add_node(n4)
|
||||
g.add_edge(create_edge("1.5","image","4","image"))
|
||||
|
||||
# Validate
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
|
||||
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
|
||||
expected_edges.extend([
|
||||
('2','1.width'),
|
||||
('3','1.height'),
|
||||
('1.5','4')
|
||||
])
|
||||
print(expected_edges)
|
||||
print(list(dg.edges))
|
||||
assert set(dg.edges) == set(expected_edges)
|
||||
|
||||
|
||||
def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id = "1")
|
||||
|
Loading…
Reference in New Issue
Block a user