[nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180)

* Add latent to latent (img2img equivalent)
Fix a CLI bug with multiple links per node

* Using "latents" instead of "latent"

* [nodes] In-progress implementation of graph library

* Add linking to CLI for graph nodes (still broken)

* Fix subgraph execution, fix subgraph linking in CLI

* Fix LatentsToLatents
This commit is contained in:
Kyle Schouviller 2023-04-13 23:41:06 -07:00 committed by GitHub
parent 024fd54d0b
commit 23d65e7162
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 471 additions and 103 deletions

View File

@ -3,12 +3,14 @@
import os import os
from argparse import Namespace from argparse import Namespace
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
@ -69,6 +71,9 @@ class ApiDependencies:
latents=latents, latents=latents,
images=images, images=images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
), ),
@ -76,6 +81,8 @@ class ApiDependencies:
restoration=RestorationServices(config), restoration=RestorationServices(config),
) )
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services) ApiDependencies.invoker = Invoker(services)
@staticmethod @staticmethod

View File

@ -7,11 +7,40 @@ from pydantic import BaseModel, Field
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from ..models.image import ImageField from ..invocations.baseinvocation import BaseInvocation
from ..services.graph import GraphExecutionState from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
from ..services.invoker import Invoker from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
help=field.field_info.description,
)
def add_parsers( def add_parsers(
subparsers, subparsers,
commands: list[type], commands: list[type],
@ -36,30 +65,26 @@ def add_parsers(
if name in exclude_fields: if name in exclude_fields:
continue continue
if get_origin(field.type_) == Literal: add_field_argument(command_parser, name, field)
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}", def add_graph_parsers(
dest=name, subparsers,
type=field_type, graphs: list[LibraryGraph],
default=field.default if field.default_factory is None else field.default_factory(), add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
choices=allowed_values, ):
help=field.field_info.description, for graph in graphs:
) command_parser = subparsers.add_parser(graph.name, help=graph.description)
else:
command_parser.add_argument( if add_arguments is not None:
f"--{name}", add_arguments(command_parser)
dest=name,
type=field.type_, # Add arguments for inputs
default=field.default if field.default_factory is None else field.default_factory(), for exposed_input in graph.exposed_inputs:
help=field.field_info.description, node = graph.graph.get_node(exposed_input.node_path)
) field = node.__fields__[exposed_input.field]
default_override = getattr(node, exposed_input.field)
add_field_argument(command_parser, exposed_input.alias, field, default_override)
class CliContext: class CliContext:
@ -67,17 +92,38 @@ class CliContext:
session: GraphExecutionState session: GraphExecutionState
parser: argparse.ArgumentParser parser: argparse.ArgumentParser
defaults: dict[str, Any] defaults: dict[str, Any]
graph_nodes: dict[str, str]
nodes_added: list[str]
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser): def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
self.invoker = invoker self.invoker = invoker
self.session = session self.session = session
self.parser = parser self.parser = parser
self.defaults = dict() self.defaults = dict()
self.graph_nodes = dict()
self.nodes_added = list()
def get_session(self): def get_session(self):
self.session = self.invoker.services.graph_execution_manager.get(self.session.id) self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
return self.session return self.session
def reset(self):
self.session = self.invoker.create_execution_state()
self.graph_nodes = dict()
self.nodes_added = list()
# Leave defaults unchanged
def add_node(self, node: BaseInvocation):
self.get_session()
self.session.graph.add_node(node)
self.nodes_added.append(node.id)
self.invoker.services.graph_execution_manager.set(self.session)
def add_edge(self, edge: Edge):
self.get_session()
self.session.add_edge(edge)
self.invoker.services.graph_execution_manager.set(self.session)
class ExitCli(Exception): class ExitCli(Exception):
"""Exception to exit the CLI""" """Exception to exit the CLI"""

View File

@ -13,17 +13,20 @@ from typing import (
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import Field from pydantic.fields import Field
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
from .invocations import * from .invocations import *
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
@ -58,7 +61,7 @@ def add_invocation_args(command_parser):
) )
def get_command_parser() -> argparse.ArgumentParser: def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
# Create invocation parser # Create invocation parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
commands = BaseCommand.get_all_subclasses() commands = BaseCommand.get_all_subclasses()
add_parsers(subparsers, commands, exclude_fields=["type"]) add_parsers(subparsers, commands, exclude_fields=["type"])
# Create subparsers for exposed CLI graphs
# TODO: add a way to identify these graphs
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
return parser return parser
class NodeField():
alias: str
node_path: str
field: str
field_type: type
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
self.alias = alias
self.node_path = node_path
self.field = field
self.field_type = field_type
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the inputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the outputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
def generate_matching_edges( def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]: ) -> list[Edge]:
"""Generates all possible edges between two invocations""" """Generates all possible edges between two invocations"""
atype = type(a) afields = get_node_outputs(a, context)
btype = type(b) bfields = get_node_inputs(b, context)
aoutputtype = atype.get_output_type()
afields = get_type_hints(aoutputtype)
bfields = get_type_hints(btype)
matching_fields = set(afields.keys()).intersection(bfields.keys()) matching_fields = set(afields.keys()).intersection(bfields.keys())
@ -98,14 +153,14 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields) matching_fields = matching_fields.difference(invalid_fields)
# Validate types # Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])] matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
edges = [ edges = [
Edge( Edge(
source=EdgeConnection(node_id=a.id, field=field), source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=b.id, field=field) destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
) )
for field in matching_fields for alias in matching_fields
] ]
return edges return edges
@ -158,6 +213,9 @@ def invoke_cli():
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'), images=DiskImageStorage(f'{output_folder}/images'),
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
), ),
@ -165,9 +223,12 @@ def invoke_cli():
restoration=RestorationServices(config), restoration=RestorationServices(config),
) )
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
invoker = Invoker(services) invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser() parser = get_command_parser(services)
re_negid = re.compile('^-[0-9]+$') re_negid = re.compile('^-[0-9]+$')
@ -185,11 +246,12 @@ def invoke_cli():
try: try:
# Refresh the state of the session # Refresh the state of the session
history = list(get_graph_execution_history(context.session)) #history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))
# Split the command for piping # Split the command for piping
cmds = cmd_input.split("|") cmds = cmd_input.split("|")
start_id = len(history) start_id = len(context.nodes_added)
current_id = start_id current_id = start_id
new_invocations = list() new_invocations = list()
for cmd in cmds: for cmd in cmds:
@ -205,8 +267,24 @@ def invoke_cli():
args[field_name] = field_default args[field_name] = field_default
# Parse invocation # Parse invocation
args["id"] = current_id command: CliCommand = None # type:ignore
command = CliCommand(command=args) system_graph: LibraryGraph|None = None
if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command = invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
command = CliCommand(command=args)
if command is None:
continue
# Run any CLI commands immediately # Run any CLI commands immediately
if isinstance(command.command, BaseCommand): if isinstance(command.command, BaseCommand):
@ -217,6 +295,7 @@ def invoke_cli():
command.command.run(context) command.command.run(context)
continue continue
# TODO: handle linking with library graphs
# Pipe previous command output (if there was a previous command) # Pipe previous command output (if there was a previous command)
edges: list[Edge] = list() edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id: if len(history) > 0 or current_id != start_id:
@ -229,7 +308,7 @@ def invoke_cli():
else context.session.graph.get_node(from_id) else context.session.graph.get_node(from_id)
) )
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
from_node, command.command from_node, command.command, context
) )
edges.extend(matching_edges) edges.extend(matching_edges)
@ -242,7 +321,7 @@ def invoke_cli():
link_node = context.session.graph.get_node(node_id) link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
link_node, command.command link_node, command.command, context
) )
matching_destinations = [e.destination for e in matching_edges] matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations] edges = [e for e in edges if e.destination not in matching_destinations]
@ -256,12 +335,14 @@ def invoke_cli():
if re_negid.match(node_id): if re_negid.match(node_id):
node_id = str(current_id + int(node_id)) node_id = str(current_id + int(node_id))
# TODO: handle missing input/output
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
node_input = get_node_inputs(command.command, context)[link[2]]
edges.append( edges.append(
Edge( Edge(
source=EdgeConnection(node_id=node_id, field=link[1]), source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection( destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
node_id=command.command.id, field=link[2]
)
) )
) )
@ -270,10 +351,10 @@ def invoke_cli():
current_id = current_id + 1 current_id = current_id + 1
# Add the node to the session # Add the node to the session
context.session.add_node(command.command) context.add_node(command.command)
for edge in edges: for edge in edges:
print(edge) print(edge)
context.session.add_edge(edge) context.add_edge(edge)
# Execute all remaining nodes # Execute all remaining nodes
invoke_all(context) invoke_all(context)
@ -285,7 +366,7 @@ def invoke_cli():
except SessionError: except SessionError:
# Start a new session # Start a new session
print("Session error: creating a new session") print("Session error: creating a new session")
context.session = context.invoker.create_execution_state() context.reset()
except ExitCli: except ExitCli:
break break

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional from typing import Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import torch import torch
@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
return x return x
def random_seed():
return random.randint(0, np.iinfo(np.uint32).max)
class NoiseInvocation(BaseInvocation): class NoiseInvocation(BaseInvocation):
"""Generates latent noise.""" """Generates latent noise."""
type: Literal["noise"] = "noise" type: Literal["noise"] = "noise"
# Inputs # Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", ) seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", ) width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", ) height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
@ -313,6 +318,56 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
) )
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image # Latent to image
class LatentsToImageInvocation(BaseInvocation): class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents.""" """Generates an image from latents."""

View File

@ -0,0 +1,18 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
# Pass-through parameter nodes - used by subgraphs
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
#fmt: off
type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)

View File

@ -0,0 +1,56 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
def create_text_to_image() -> LibraryGraph:
return LibraryGraph(
id=default_text_to_image_graph_id,
name='t2i',
description='Converts text to an image',
graph=Graph(
nodes={
'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512),
'3': NoiseInvocation(id='3'),
'4': TextToLatentsInvocation(id='4'),
'5': LatentsToImageInvocation(id='5')
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height')
],
exposed_outputs=[
ExposedNodeOutput(node_path='5', field='image', alias='image')
])
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)
graphs.append(text_to_image)
return graphs

View File

@ -17,7 +17,7 @@ from typing import (
) )
import networkx as nx import networkx as nx
from pydantic import BaseModel, validator from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field from pydantic.fields import Field
from ..invocations import * from ..invocations import *
@ -283,7 +283,8 @@ class Graph(BaseModel):
:raises InvalidEdgeError: the provided edge is invalid. :raises InvalidEdgeError: the provided edge is invalid.
""" """
if self._is_edge_valid(edge) and edge not in self.edges: self._validate_edge(edge)
if edge not in self.edges:
self.edges.append(edge) self.edges.append(edge)
else: else:
raise InvalidEdgeError() raise InvalidEdgeError()
@ -354,7 +355,7 @@ class Graph(BaseModel):
return True return True
def _is_edge_valid(self, edge: Edge) -> bool: def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph""" """Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
@ -362,54 +363,53 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id) from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id) to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError: except NodeNotFoundError:
return False raise InvalidEdgeError("One or both nodes don't exist")
# Validate that an edge to this node+field doesn't already exist # Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
# Validate that no cycles would be created # Validate that no cycles would be created
g = self.nx_graph_flat() g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id) g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g): if not nx.is_directed_acyclic_graph(g):
return False raise InvalidEdgeError(f'Edge creates a cycle in the graph')
# Validate that the field types are compatible # Validate that the field types are compatible
if not are_connections_compatible( if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field from_node, edge.source.field, to_node, edge.destination.field
): ):
return False raise InvalidEdgeError(f'Fields are incompatible')
# Validate if iterator output type matches iterator input type (if this edge results in both being set) # Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source edge.destination.node_id, new_input=edge.source
): ):
return False raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
# Validate if iterator input type matches output type (if this edge results in both being set) # Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination edge.source.node_id, new_output=edge.destination
): ):
return False raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
# Validate if collector input type matches output type (if this edge results in both being set) # Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source edge.destination.node_id, new_input=edge.source
): ):
return False raise InvalidEdgeError(f'Collector output type does not match collector input type')
# Validate if collector output type matches input type (if this edge results in both being set) # Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination edge.source.node_id, new_output=edge.destination
): ):
return False raise InvalidEdgeError(f'Collector input type does not match collector output type')
return True
def has_node(self, node_path: str) -> bool: def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph.""" """Determines whether or not a node exists in the graph."""
@ -733,7 +733,7 @@ class Graph(BaseModel):
for sgn in ( for sgn in (
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation) gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
): ):
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded # TODO: figure out if iteration nodes need to be expanded
@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel):
def is_complete(self) -> bool: def is_complete(self) -> bool:
"""Returns true if the graph is complete""" """Returns true if the graph is complete"""
return self.has_error() or all((k in self.executed for k in self.graph.nodes)) node_ids = set(self.graph.nx_graph_flat().nodes)
return self.has_error() or all((k in self.executed for k in node_ids))
def has_error(self) -> bool: def has_error(self) -> bool:
"""Returns true if the graph has any errors""" """Returns true if the graph has any errors"""
@ -946,11 +947,11 @@ class GraphExecutionState(BaseModel):
def _iterator_graph(self) -> nx.DiGraph: def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph() g = self.graph.nx_graph_flat()
collectors = ( collectors = (
n n
for n in self.graph.nodes for n in self.graph.nodes
if isinstance(self.graph.nodes[n], CollectInvocation) if isinstance(self.graph.get_node(n), CollectInvocation)
) )
for c in collectors: for c in collectors:
g.remove_edges_from(list(g.in_edges(c))) g.remove_edges_from(list(g.in_edges(c)))
@ -962,7 +963,7 @@ class GraphExecutionState(BaseModel):
iterators = [ iterators = [
n n
for n in nx.ancestors(g, node_id) for n in nx.ancestors(g, node_id)
if isinstance(self.graph.nodes[n], IterateInvocation) if isinstance(self.graph.get_node(n), IterateInvocation)
] ]
return iterators return iterators
@ -1098,7 +1099,9 @@ class GraphExecutionState(BaseModel):
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: Edge) -> bool: def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge): try:
self.graph._validate_edge(edge)
except InvalidEdgeError:
return False return False
# Invalid if destination has already been prepared or executed # Invalid if destination has already been prepared or executed
@ -1144,4 +1147,52 @@ class GraphExecutionState(BaseModel):
self.graph.delete_edge(edge) self.graph.delete_edge(edge)
class ExposedNodeInput(BaseModel):
node_path: str = Field(description="The node path to the node with the input")
field: str = Field(description="The field name of the input")
alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
node_path: str = Field(description="The node path to the node with the output")
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
return v
@root_validator
def validate_exposed_nodes(cls, values):
graph = values['graph']
# Validate exposed inputs
for exposed_input in values['exposed_inputs']:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
# Validate exposed outputs
for exposed_output in values['exposed_outputs']:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
return values
GraphInvocation.update_forward_refs() GraphInvocation.update_forward_refs()

View File

@ -19,6 +19,7 @@ class InvocationServices:
restoration: RestorationServices restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services # NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"] graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
@ -29,6 +30,7 @@ class InvocationServices:
latents: LatentsStorageBase, latents: LatentsStorageBase,
images: ImageStorageBase, images: ImageStorageBase,
queue: InvocationQueueABC, queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
restoration: RestorationServices, restoration: RestorationServices,
@ -38,6 +40,7 @@ class InvocationServices:
self.latents = latents self.latents = latents
self.images = images self.images = images
self.queue = queue self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
self.processor = processor self.processor = processor
self.restoration = restoration self.restoration = restoration

View File

@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._create_table() self._create_table()
def _create_table(self): def _create_table(self):
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} ( f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT, item TEXT,
@ -45,34 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute( self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);""" f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
) )
finally: self._conn.commit()
self._lock.release()
def _parse_item(self, item: str) -> T: def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0] item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item) return parse_raw_as(item_type, item)
def set(self, item: T): def set(self, item: T):
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),), (item.json(),),
) )
self._conn.commit() self._conn.commit()
finally:
self._lock.release()
self._on_changed(item) self._on_changed(item)
def get(self, id: str) -> Union[T, None]: def get(self, id: str) -> Union[T, None]:
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),) f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
) )
result = self._cursor.fetchone() result = self._cursor.fetchone()
finally:
self._lock.release()
if not result: if not result:
return None return None
@ -80,19 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0]) return self._parse_item(result[0])
def delete(self, id: str): def delete(self, id: str):
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
) )
self._conn.commit() self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id) self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""", f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page), (per_page, page * per_page),
@ -103,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""") self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1 pageCount = int(count / per_page) + 1
@ -115,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search( def search(
self, query: str, page: int = 0, per_page: int = 10 self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]: ) -> PaginatedResults[T]:
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""", f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page), (f"%{query}%", per_page, page * per_page),
@ -130,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
(f"%{query}%",), (f"%{query}%",),
) )
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1 pageCount = int(count / per_page) + 1

View File

@ -40,7 +40,7 @@ dependencies = [
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==1.0.5", "compel==1.0.5",
"datasets", "datasets",
"diffusers[torch]~=0.14", "diffusers[torch]==0.14",
"dnspython==2.2.1", "dnspython==2.2.1",
"einops", "einops",
"eventlet", "eventlet",

View File

@ -7,7 +7,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest import pytest
@ -28,6 +28,9 @@ def mock_services():
images = None, # type: ignore images = None, # type: ignore
latents = None, # type: ignore latents = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(), processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore restoration = None, # type: ignore

View File

@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest import pytest
@ -26,6 +26,9 @@ def mock_services() -> InvocationServices:
images = None, # type: ignore images = None, # type: ignore
latents = None, # type: ignore latents = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(), processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore restoration = None, # type: ignore

View File

@ -1,9 +1,11 @@
from invokeai.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.upscale import UpscaleInvocation
from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.params import ParamIntInvocation
from invokeai.app.services.default_graphs import create_text_to_image
import pytest import pytest
@ -417,6 +419,66 @@ def test_graph_gets_subgraph_node():
assert result.id == '1' assert result.id == '1'
assert result == n1_1 assert result == n1_1
def test_graph_expands_subgraph():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1_1 = AddInvocation(id = "1", a = 1, b = 2)
n1_2 = SubtractInvocation(id = "2", b = 3)
n1.graph.add_node(n1_1)
n1.graph.add_node(n1_2)
n1.graph.add_edge(create_edge("1","a","2","a"))
g.add_node(n1)
n2 = AddInvocation(id = "2", b = 5)
g.add_node(n2)
g.add_edge(create_edge("1.2","a","2","a"))
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(['1.1', '1.2', '2'])
assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')])
def test_graph_subgraph_t2i():
g = Graph()
n1 = GraphInvocation(id = "1")
# Get text to image default graph
lg = create_text_to_image()
n1.graph = lg.graph
g.add_node(n1)
n2 = ParamIntInvocation(id = "2", a = 512)
n3 = ParamIntInvocation(id = "3", a = 256)
g.add_node(n2)
g.add_node(n3)
g.add_edge(create_edge("2","a","1.width","a"))
g.add_edge(create_edge("3","a","1.height","a"))
n4 = ShowImageInvocation(id = "4")
g.add_node(n4)
g.add_edge(create_edge("1.5","image","4","image"))
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
expected_edges.extend([
('2','1.width'),
('3','1.height'),
('1.5','4')
])
print(expected_edges)
print(list(dg.edges))
assert set(dg.edges) == set(expected_edges)
def test_graph_fails_to_get_missing_subgraph_node(): def test_graph_fails_to_get_missing_subgraph_node():
g = Graph() g = Graph()
n1 = GraphInvocation(id = "1") n1 = GraphInvocation(id = "1")