[nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180)

* Add latent to latent (img2img equivalent)
Fix a CLI bug with multiple links per node

* Using "latents" instead of "latent"

* [nodes] In-progress implementation of graph library

* Add linking to CLI for graph nodes (still broken)

* Fix subgraph execution, fix subgraph linking in CLI

* Fix LatentsToLatents
This commit is contained in:
Kyle Schouviller 2023-04-13 23:41:06 -07:00 committed by GitHub
parent 024fd54d0b
commit 23d65e7162
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 471 additions and 103 deletions

View File

@ -3,12 +3,14 @@
import os
from argparse import Namespace
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
@ -69,6 +71,9 @@ class ApiDependencies:
latents=latents,
images=images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
@ -76,6 +81,8 @@ class ApiDependencies:
restoration=RestorationServices(config),
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
@staticmethod

View File

@ -7,11 +7,40 @@ from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
from ..models.image import ImageField
from ..services.graph import GraphExecutionState
from ..invocations.baseinvocation import BaseInvocation
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
help=field.field_info.description,
)
def add_parsers(
subparsers,
commands: list[type],
@ -36,30 +65,26 @@ def add_parsers(
if name in exclude_fields:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
add_field_argument(command_parser, name, field)
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default if field.default_factory is None else field.default_factory(),
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default if field.default_factory is None else field.default_factory(),
help=field.field_info.description,
)
def add_graph_parsers(
subparsers,
graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description)
if add_arguments is not None:
add_arguments(command_parser)
# Add arguments for inputs
for exposed_input in graph.exposed_inputs:
node = graph.graph.get_node(exposed_input.node_path)
field = node.__fields__[exposed_input.field]
default_override = getattr(node, exposed_input.field)
add_field_argument(command_parser, exposed_input.alias, field, default_override)
class CliContext:
@ -67,17 +92,38 @@ class CliContext:
session: GraphExecutionState
parser: argparse.ArgumentParser
defaults: dict[str, Any]
graph_nodes: dict[str, str]
nodes_added: list[str]
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
self.invoker = invoker
self.session = session
self.parser = parser
self.defaults = dict()
self.graph_nodes = dict()
self.nodes_added = list()
def get_session(self):
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
return self.session
def reset(self):
self.session = self.invoker.create_execution_state()
self.graph_nodes = dict()
self.nodes_added = list()
# Leave defaults unchanged
def add_node(self, node: BaseInvocation):
self.get_session()
self.session.graph.add_node(node)
self.nodes_added.append(node.id)
self.invoker.services.graph_execution_manager.set(self.session)
def add_edge(self, edge: Edge):
self.get_session()
self.session.add_edge(edge)
self.invoker.services.graph_execution_manager.set(self.session)
class ExitCli(Exception):
"""Exception to exit the CLI"""

View File

@ -13,17 +13,20 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@ -58,7 +61,7 @@ def add_invocation_args(command_parser):
)
def get_command_parser() -> argparse.ArgumentParser:
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
commands = BaseCommand.get_all_subclasses()
add_parsers(subparsers, commands, exclude_fields=["type"])
# Create subparsers for exposed CLI graphs
# TODO: add a way to identify these graphs
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
return parser
class NodeField():
alias: str
node_path: str
field: str
field_type: type
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
self.alias = alias
self.node_path = node_path
self.field = field
self.field_type = field_type
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the inputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the outputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
aoutputtype = atype.get_output_type()
afields = get_type_hints(aoutputtype)
bfields = get_type_hints(btype)
afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context)
matching_fields = set(afields.keys()).intersection(bfields.keys())
@ -98,14 +153,14 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
edges = [
Edge(
source=EdgeConnection(node_id=a.id, field=field),
destination=EdgeConnection(node_id=b.id, field=field)
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
)
for field in matching_fields
for alias in matching_fields
]
return edges
@ -158,6 +213,9 @@ def invoke_cli():
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'),
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
@ -165,9 +223,12 @@ def invoke_cli():
restoration=RestorationServices(config),
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser()
parser = get_command_parser(services)
re_negid = re.compile('^-[0-9]+$')
@ -185,11 +246,12 @@ def invoke_cli():
try:
# Refresh the state of the session
history = list(get_graph_execution_history(context.session))
#history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))
# Split the command for piping
cmds = cmd_input.split("|")
start_id = len(history)
start_id = len(context.nodes_added)
current_id = start_id
new_invocations = list()
for cmd in cmds:
@ -205,8 +267,24 @@ def invoke_cli():
args[field_name] = field_default
# Parse invocation
args["id"] = current_id
command = CliCommand(command=args)
command: CliCommand = None # type:ignore
system_graph: LibraryGraph|None = None
if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command = invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
command = CliCommand(command=args)
if command is None:
continue
# Run any CLI commands immediately
if isinstance(command.command, BaseCommand):
@ -217,6 +295,7 @@ def invoke_cli():
command.command.run(context)
continue
# TODO: handle linking with library graphs
# Pipe previous command output (if there was a previous command)
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
@ -229,7 +308,7 @@ def invoke_cli():
else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.command
from_node, command.command, context
)
edges.extend(matching_edges)
@ -242,7 +321,7 @@ def invoke_cli():
link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges(
link_node, command.command
link_node, command.command, context
)
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
@ -256,12 +335,14 @@ def invoke_cli():
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
# TODO: handle missing input/output
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
node_input = get_node_inputs(command.command, context)[link[2]]
edges.append(
Edge(
source=EdgeConnection(node_id=node_id, field=link[1]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
)
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
)
)
@ -270,10 +351,10 @@ def invoke_cli():
current_id = current_id + 1
# Add the node to the session
context.session.add_node(command.command)
context.add_node(command.command)
for edge in edges:
print(edge)
context.session.add_edge(edge)
context.add_edge(edge)
# Execute all remaining nodes
invoke_all(context)
@ -285,7 +366,7 @@ def invoke_cli():
except SessionError:
# Start a new session
print("Session error: creating a new session")
context.session = context.invoker.create_execution_state()
context.reset()
except ExitCli:
break

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional
from pydantic import BaseModel, Field
import torch
@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
return x
def random_seed():
return random.randint(0, np.iinfo(np.uint32).max)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
@ -313,6 +318,56 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""

View File

@ -0,0 +1,18 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
# Pass-through parameter nodes - used by subgraphs
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
#fmt: off
type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)

View File

@ -0,0 +1,56 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
def create_text_to_image() -> LibraryGraph:
return LibraryGraph(
id=default_text_to_image_graph_id,
name='t2i',
description='Converts text to an image',
graph=Graph(
nodes={
'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512),
'3': NoiseInvocation(id='3'),
'4': TextToLatentsInvocation(id='4'),
'5': LatentsToImageInvocation(id='5')
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height')
],
exposed_outputs=[
ExposedNodeOutput(node_path='5', field='image', alias='image')
])
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)
graphs.append(text_to_image)
return graphs

View File

@ -17,7 +17,7 @@ from typing import (
)
import networkx as nx
from pydantic import BaseModel, validator
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from ..invocations import *
@ -283,7 +283,8 @@ class Graph(BaseModel):
:raises InvalidEdgeError: the provided edge is invalid.
"""
if self._is_edge_valid(edge) and edge not in self.edges:
self._validate_edge(edge)
if edge not in self.edges:
self.edges.append(edge)
else:
raise InvalidEdgeError()
@ -354,7 +355,7 @@ class Graph(BaseModel):
return True
def _is_edge_valid(self, edge: Edge) -> bool:
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
@ -362,54 +363,53 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
return False
raise InvalidEdgeError("One or both nodes don't exist")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
return False
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
return False
raise InvalidEdgeError(f'Fields are incompatible')
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
return False
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
return False
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
return False
raise InvalidEdgeError(f'Collector output type does not match collector input type')
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
return False
raise InvalidEdgeError(f'Collector input type does not match collector output type')
return True
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@ -733,7 +733,7 @@ class Graph(BaseModel):
for sgn in (
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
):
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel):
def is_complete(self) -> bool:
"""Returns true if the graph is complete"""
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
node_ids = set(self.graph.nx_graph_flat().nodes)
return self.has_error() or all((k in self.executed for k in node_ids))
def has_error(self) -> bool:
"""Returns true if the graph has any errors"""
@ -946,11 +947,11 @@ class GraphExecutionState(BaseModel):
def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph()
g = self.graph.nx_graph_flat()
collectors = (
n
for n in self.graph.nodes
if isinstance(self.graph.nodes[n], CollectInvocation)
if isinstance(self.graph.get_node(n), CollectInvocation)
)
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
@ -962,7 +963,7 @@ class GraphExecutionState(BaseModel):
iterators = [
n
for n in nx.ancestors(g, node_id)
if isinstance(self.graph.nodes[n], IterateInvocation)
if isinstance(self.graph.get_node(n), IterateInvocation)
]
return iterators
@ -1098,7 +1099,9 @@ class GraphExecutionState(BaseModel):
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge):
try:
self.graph._validate_edge(edge)
except InvalidEdgeError:
return False
# Invalid if destination has already been prepared or executed
@ -1144,4 +1147,52 @@ class GraphExecutionState(BaseModel):
self.graph.delete_edge(edge)
class ExposedNodeInput(BaseModel):
node_path: str = Field(description="The node path to the node with the input")
field: str = Field(description="The field name of the input")
alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
node_path: str = Field(description="The node path to the node with the output")
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
return v
@root_validator
def validate_exposed_nodes(cls, values):
graph = values['graph']
# Validate exposed inputs
for exposed_input in values['exposed_inputs']:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
# Validate exposed outputs
for exposed_output in values['exposed_outputs']:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
return values
GraphInvocation.update_forward_refs()

View File

@ -19,6 +19,7 @@ class InvocationServices:
restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC"
@ -29,6 +30,7 @@ class InvocationServices:
latents: LatentsStorageBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
@ -38,6 +40,7 @@ class InvocationServices:
self.latents = latents
self.images = images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration

View File

@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
@ -45,34 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
finally:
self._lock.release()
self._conn.commit()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
@ -80,19 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0])
def delete(self, id: str):
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
@ -103,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
@ -115,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
@ -130,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1

View File

@ -40,7 +40,7 @@ dependencies = [
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==1.0.5",
"datasets",
"diffusers[torch]~=0.14",
"diffusers[torch]==0.14",
"dnspython==2.2.1",
"einops",
"eventlet",

View File

@ -7,7 +7,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest
@ -28,6 +28,9 @@ def mock_services():
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore

View File

@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invoker import Invoker
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest
@ -26,6 +26,9 @@ def mock_services() -> InvocationServices:
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore

View File

@ -1,9 +1,11 @@
from invokeai.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.params import ParamIntInvocation
from invokeai.app.services.default_graphs import create_text_to_image
import pytest
@ -417,6 +419,66 @@ def test_graph_gets_subgraph_node():
assert result.id == '1'
assert result == n1_1
def test_graph_expands_subgraph():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1_1 = AddInvocation(id = "1", a = 1, b = 2)
n1_2 = SubtractInvocation(id = "2", b = 3)
n1.graph.add_node(n1_1)
n1.graph.add_node(n1_2)
n1.graph.add_edge(create_edge("1","a","2","a"))
g.add_node(n1)
n2 = AddInvocation(id = "2", b = 5)
g.add_node(n2)
g.add_edge(create_edge("1.2","a","2","a"))
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(['1.1', '1.2', '2'])
assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')])
def test_graph_subgraph_t2i():
g = Graph()
n1 = GraphInvocation(id = "1")
# Get text to image default graph
lg = create_text_to_image()
n1.graph = lg.graph
g.add_node(n1)
n2 = ParamIntInvocation(id = "2", a = 512)
n3 = ParamIntInvocation(id = "3", a = 256)
g.add_node(n2)
g.add_node(n3)
g.add_edge(create_edge("2","a","1.width","a"))
g.add_edge(create_edge("3","a","1.height","a"))
n4 = ShowImageInvocation(id = "4")
g.add_node(n4)
g.add_edge(create_edge("1.5","image","4","image"))
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
expected_edges.extend([
('2','1.width'),
('3','1.height'),
('1.5','4')
])
print(expected_edges)
print(list(dg.edges))
assert set(dg.edges) == set(expected_edges)
def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")