mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180)
* Add latent to latent (img2img equivalent) Fix a CLI bug with multiple links per node * Using "latents" instead of "latent" * [nodes] In-progress implementation of graph library * Add linking to CLI for graph nodes (still broken) * Fix subgraph execution, fix subgraph linking in CLI * Fix LatentsToLatents
This commit is contained in:
parent
024fd54d0b
commit
23d65e7162
@ -3,12 +3,14 @@
|
|||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from ..services.default_graphs import create_system_graphs
|
||||||
|
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
from ..services.model_manager_initializer import get_model_manager
|
from ..services.model_manager_initializer import get_model_manager
|
||||||
from ..services.restoration_services import RestorationServices
|
from ..services.restoration_services import RestorationServices
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.image_storage import DiskImageStorage
|
from ..services.image_storage import DiskImageStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
@ -69,6 +71,9 @@ class ApiDependencies:
|
|||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=db_location, table_name="graphs"
|
||||||
|
),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
@ -76,6 +81,8 @@ class ApiDependencies:
|
|||||||
restoration=RestorationServices(config),
|
restoration=RestorationServices(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
create_system_graphs(services.graph_library)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -7,11 +7,40 @@ from pydantic import BaseModel, Field
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from ..models.image import ImageField
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
from ..services.graph import GraphExecutionState
|
from ..invocations.image import ImageField
|
||||||
|
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||||
|
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||||
|
if get_origin(field.type_) == Literal:
|
||||||
|
allowed_values = get_args(field.type_)
|
||||||
|
allowed_types = set()
|
||||||
|
for val in allowed_values:
|
||||||
|
allowed_types.add(type(val))
|
||||||
|
allowed_types_list = list(allowed_types)
|
||||||
|
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||||
|
|
||||||
|
command_parser.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
type=field_type,
|
||||||
|
default=default,
|
||||||
|
choices=allowed_values,
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
command_parser.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
type=field.type_,
|
||||||
|
default=default,
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_parsers(
|
def add_parsers(
|
||||||
subparsers,
|
subparsers,
|
||||||
commands: list[type],
|
commands: list[type],
|
||||||
@ -36,30 +65,26 @@ def add_parsers(
|
|||||||
if name in exclude_fields:
|
if name in exclude_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if get_origin(field.type_) == Literal:
|
add_field_argument(command_parser, name, field)
|
||||||
allowed_values = get_args(field.type_)
|
|
||||||
allowed_types = set()
|
|
||||||
for val in allowed_values:
|
|
||||||
allowed_types.add(type(val))
|
|
||||||
allowed_types_list = list(allowed_types)
|
|
||||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
|
||||||
|
|
||||||
command_parser.add_argument(
|
|
||||||
f"--{name}",
|
def add_graph_parsers(
|
||||||
dest=name,
|
subparsers,
|
||||||
type=field_type,
|
graphs: list[LibraryGraph],
|
||||||
default=field.default if field.default_factory is None else field.default_factory(),
|
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||||
choices=allowed_values,
|
):
|
||||||
help=field.field_info.description,
|
for graph in graphs:
|
||||||
)
|
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||||
else:
|
|
||||||
command_parser.add_argument(
|
if add_arguments is not None:
|
||||||
f"--{name}",
|
add_arguments(command_parser)
|
||||||
dest=name,
|
|
||||||
type=field.type_,
|
# Add arguments for inputs
|
||||||
default=field.default if field.default_factory is None else field.default_factory(),
|
for exposed_input in graph.exposed_inputs:
|
||||||
help=field.field_info.description,
|
node = graph.graph.get_node(exposed_input.node_path)
|
||||||
)
|
field = node.__fields__[exposed_input.field]
|
||||||
|
default_override = getattr(node, exposed_input.field)
|
||||||
|
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||||
|
|
||||||
|
|
||||||
class CliContext:
|
class CliContext:
|
||||||
@ -67,17 +92,38 @@ class CliContext:
|
|||||||
session: GraphExecutionState
|
session: GraphExecutionState
|
||||||
parser: argparse.ArgumentParser
|
parser: argparse.ArgumentParser
|
||||||
defaults: dict[str, Any]
|
defaults: dict[str, Any]
|
||||||
|
graph_nodes: dict[str, str]
|
||||||
|
nodes_added: list[str]
|
||||||
|
|
||||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||||
self.invoker = invoker
|
self.invoker = invoker
|
||||||
self.session = session
|
self.session = session
|
||||||
self.parser = parser
|
self.parser = parser
|
||||||
self.defaults = dict()
|
self.defaults = dict()
|
||||||
|
self.graph_nodes = dict()
|
||||||
|
self.nodes_added = list()
|
||||||
|
|
||||||
def get_session(self):
|
def get_session(self):
|
||||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||||
return self.session
|
return self.session
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.session = self.invoker.create_execution_state()
|
||||||
|
self.graph_nodes = dict()
|
||||||
|
self.nodes_added = list()
|
||||||
|
# Leave defaults unchanged
|
||||||
|
|
||||||
|
def add_node(self, node: BaseInvocation):
|
||||||
|
self.get_session()
|
||||||
|
self.session.graph.add_node(node)
|
||||||
|
self.nodes_added.append(node.id)
|
||||||
|
self.invoker.services.graph_execution_manager.set(self.session)
|
||||||
|
|
||||||
|
def add_edge(self, edge: Edge):
|
||||||
|
self.get_session()
|
||||||
|
self.session.add_edge(edge)
|
||||||
|
self.invoker.services.graph_execution_manager.set(self.session)
|
||||||
|
|
||||||
|
|
||||||
class ExitCli(Exception):
|
class ExitCli(Exception):
|
||||||
"""Exception to exit the CLI"""
|
"""Exception to exit the CLI"""
|
||||||
|
@ -13,17 +13,20 @@ from typing import (
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from .services.default_graphs import create_system_graphs
|
||||||
|
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
from ..backend import Args
|
from ..backend import Args
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||||
from .cli.completer import set_autocompleter
|
from .cli.completer import set_autocompleter
|
||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.model_manager_initializer import get_model_manager
|
from .services.model_manager_initializer import get_model_manager
|
||||||
from .services.restoration_services import RestorationServices
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||||
|
from .services.default_graphs import default_text_to_image_graph_id
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
@ -58,7 +61,7 @@ def add_invocation_args(command_parser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_command_parser() -> argparse.ArgumentParser:
|
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||||
# Create invocation parser
|
# Create invocation parser
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
|||||||
commands = BaseCommand.get_all_subclasses()
|
commands = BaseCommand.get_all_subclasses()
|
||||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||||
|
|
||||||
|
# Create subparsers for exposed CLI graphs
|
||||||
|
# TODO: add a way to identify these graphs
|
||||||
|
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||||
|
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class NodeField():
|
||||||
|
alias: str
|
||||||
|
node_path: str
|
||||||
|
field: str
|
||||||
|
field_type: type
|
||||||
|
|
||||||
|
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||||
|
self.alias = alias
|
||||||
|
self.node_path = node_path
|
||||||
|
self.field = field
|
||||||
|
self.field_type = field_type
|
||||||
|
|
||||||
|
|
||||||
|
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||||
|
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||||
|
"""Gets the node field for the specified field alias"""
|
||||||
|
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||||
|
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||||
|
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||||
|
"""Gets the node field for the specified field alias"""
|
||||||
|
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||||
|
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||||
|
node_output_type = node_type.get_output_type()
|
||||||
|
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||||
|
"""Gets the inputs for the specified invocation from the context"""
|
||||||
|
node_type = type(invocation)
|
||||||
|
if node_type is not GraphInvocation:
|
||||||
|
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||||
|
else:
|
||||||
|
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||||
|
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||||
|
"""Gets the outputs for the specified invocation from the context"""
|
||||||
|
node_type = type(invocation)
|
||||||
|
if node_type is not GraphInvocation:
|
||||||
|
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||||
|
else:
|
||||||
|
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||||
|
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||||
|
|
||||||
|
|
||||||
def generate_matching_edges(
|
def generate_matching_edges(
|
||||||
a: BaseInvocation, b: BaseInvocation
|
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||||
) -> list[Edge]:
|
) -> list[Edge]:
|
||||||
"""Generates all possible edges between two invocations"""
|
"""Generates all possible edges between two invocations"""
|
||||||
atype = type(a)
|
afields = get_node_outputs(a, context)
|
||||||
btype = type(b)
|
bfields = get_node_inputs(b, context)
|
||||||
|
|
||||||
aoutputtype = atype.get_output_type()
|
|
||||||
|
|
||||||
afields = get_type_hints(aoutputtype)
|
|
||||||
bfields = get_type_hints(btype)
|
|
||||||
|
|
||||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||||
|
|
||||||
@ -98,14 +153,14 @@ def generate_matching_edges(
|
|||||||
matching_fields = matching_fields.difference(invalid_fields)
|
matching_fields = matching_fields.difference(invalid_fields)
|
||||||
|
|
||||||
# Validate types
|
# Validate types
|
||||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||||
|
|
||||||
edges = [
|
edges = [
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=a.id, field=field),
|
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||||
destination=EdgeConnection(node_id=b.id, field=field)
|
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||||
)
|
)
|
||||||
for field in matching_fields
|
for alias in matching_fields
|
||||||
]
|
]
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@ -158,6 +213,9 @@ def invoke_cli():
|
|||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||||
images=DiskImageStorage(f'{output_folder}/images'),
|
images=DiskImageStorage(f'{output_folder}/images'),
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=db_location, table_name="graphs"
|
||||||
|
),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
@ -165,9 +223,12 @@ def invoke_cli():
|
|||||||
restoration=RestorationServices(config),
|
restoration=RestorationServices(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
system_graphs = create_system_graphs(services.graph_library)
|
||||||
|
system_graph_names = set([g.name for g in system_graphs])
|
||||||
|
|
||||||
invoker = Invoker(services)
|
invoker = Invoker(services)
|
||||||
session: GraphExecutionState = invoker.create_execution_state()
|
session: GraphExecutionState = invoker.create_execution_state()
|
||||||
parser = get_command_parser()
|
parser = get_command_parser(services)
|
||||||
|
|
||||||
re_negid = re.compile('^-[0-9]+$')
|
re_negid = re.compile('^-[0-9]+$')
|
||||||
|
|
||||||
@ -185,11 +246,12 @@ def invoke_cli():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Refresh the state of the session
|
# Refresh the state of the session
|
||||||
history = list(get_graph_execution_history(context.session))
|
#history = list(get_graph_execution_history(context.session))
|
||||||
|
history = list(reversed(context.nodes_added))
|
||||||
|
|
||||||
# Split the command for piping
|
# Split the command for piping
|
||||||
cmds = cmd_input.split("|")
|
cmds = cmd_input.split("|")
|
||||||
start_id = len(history)
|
start_id = len(context.nodes_added)
|
||||||
current_id = start_id
|
current_id = start_id
|
||||||
new_invocations = list()
|
new_invocations = list()
|
||||||
for cmd in cmds:
|
for cmd in cmds:
|
||||||
@ -205,9 +267,25 @@ def invoke_cli():
|
|||||||
args[field_name] = field_default
|
args[field_name] = field_default
|
||||||
|
|
||||||
# Parse invocation
|
# Parse invocation
|
||||||
|
command: CliCommand = None # type:ignore
|
||||||
|
system_graph: LibraryGraph|None = None
|
||||||
|
if args['type'] in system_graph_names:
|
||||||
|
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||||
|
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||||
|
for exposed_input in system_graph.exposed_inputs:
|
||||||
|
if exposed_input.alias in args:
|
||||||
|
node = invocation.graph.get_node(exposed_input.node_path)
|
||||||
|
field = exposed_input.field
|
||||||
|
setattr(node, field, args[exposed_input.alias])
|
||||||
|
command = CliCommand(command = invocation)
|
||||||
|
context.graph_nodes[invocation.id] = system_graph.id
|
||||||
|
else:
|
||||||
args["id"] = current_id
|
args["id"] = current_id
|
||||||
command = CliCommand(command=args)
|
command = CliCommand(command=args)
|
||||||
|
|
||||||
|
if command is None:
|
||||||
|
continue
|
||||||
|
|
||||||
# Run any CLI commands immediately
|
# Run any CLI commands immediately
|
||||||
if isinstance(command.command, BaseCommand):
|
if isinstance(command.command, BaseCommand):
|
||||||
# Invoke all current nodes to preserve operation order
|
# Invoke all current nodes to preserve operation order
|
||||||
@ -217,6 +295,7 @@ def invoke_cli():
|
|||||||
command.command.run(context)
|
command.command.run(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# TODO: handle linking with library graphs
|
||||||
# Pipe previous command output (if there was a previous command)
|
# Pipe previous command output (if there was a previous command)
|
||||||
edges: list[Edge] = list()
|
edges: list[Edge] = list()
|
||||||
if len(history) > 0 or current_id != start_id:
|
if len(history) > 0 or current_id != start_id:
|
||||||
@ -229,7 +308,7 @@ def invoke_cli():
|
|||||||
else context.session.graph.get_node(from_id)
|
else context.session.graph.get_node(from_id)
|
||||||
)
|
)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
from_node, command.command
|
from_node, command.command, context
|
||||||
)
|
)
|
||||||
edges.extend(matching_edges)
|
edges.extend(matching_edges)
|
||||||
|
|
||||||
@ -242,7 +321,7 @@ def invoke_cli():
|
|||||||
|
|
||||||
link_node = context.session.graph.get_node(node_id)
|
link_node = context.session.graph.get_node(node_id)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
link_node, command.command
|
link_node, command.command, context
|
||||||
)
|
)
|
||||||
matching_destinations = [e.destination for e in matching_edges]
|
matching_destinations = [e.destination for e in matching_edges]
|
||||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||||
@ -256,12 +335,14 @@ def invoke_cli():
|
|||||||
if re_negid.match(node_id):
|
if re_negid.match(node_id):
|
||||||
node_id = str(current_id + int(node_id))
|
node_id = str(current_id + int(node_id))
|
||||||
|
|
||||||
|
# TODO: handle missing input/output
|
||||||
|
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||||
|
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||||
|
|
||||||
edges.append(
|
edges.append(
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||||
destination=EdgeConnection(
|
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||||
node_id=command.command.id, field=link[2]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -270,10 +351,10 @@ def invoke_cli():
|
|||||||
current_id = current_id + 1
|
current_id = current_id + 1
|
||||||
|
|
||||||
# Add the node to the session
|
# Add the node to the session
|
||||||
context.session.add_node(command.command)
|
context.add_node(command.command)
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
print(edge)
|
print(edge)
|
||||||
context.session.add_edge(edge)
|
context.add_edge(edge)
|
||||||
|
|
||||||
# Execute all remaining nodes
|
# Execute all remaining nodes
|
||||||
invoke_all(context)
|
invoke_all(context)
|
||||||
@ -285,7 +366,7 @@ def invoke_cli():
|
|||||||
except SessionError:
|
except SessionError:
|
||||||
# Start a new session
|
# Start a new session
|
||||||
print("Session error: creating a new session")
|
print("Session error: creating a new session")
|
||||||
context.session = context.invoker.create_execution_state()
|
context.reset()
|
||||||
|
|
||||||
except ExitCli:
|
except ExitCli:
|
||||||
break
|
break
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import random
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import torch
|
import torch
|
||||||
@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def random_seed():
|
||||||
|
return random.randint(0, np.iinfo(np.uint32).max)
|
||||||
|
|
||||||
|
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
type: Literal["noise"] = "noise"
|
type: Literal["noise"] = "noise"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||||
|
|
||||||
@ -313,6 +318,56 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
|
"""Generates latents using latents as base image."""
|
||||||
|
|
||||||
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
|
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
def step_callback(state: PipelineIntermediateState):
|
||||||
|
self.dispatch_progress(context, state)
|
||||||
|
|
||||||
|
model = self.get_model(context.services.model_manager)
|
||||||
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
|
|
||||||
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
|
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||||
|
latent, device=model.device, dtype=latent.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps, _ = model.get_img2img_timesteps(
|
||||||
|
self.steps,
|
||||||
|
self.strength,
|
||||||
|
device=model.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||||
|
latents=initial_latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
noise=noise,
|
||||||
|
num_inference_steps=self.steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
callback=step_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
|
context.services.latents.set(name, result_latents)
|
||||||
|
return LatentsOutput(
|
||||||
|
latents=LatentsField(latents_name=name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
# Latent to image
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
18
invokeai/app/invocations/params.py
Normal file
18
invokeai/app/invocations/params.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
from pydantic import Field
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
|
from .math import IntOutput
|
||||||
|
|
||||||
|
# Pass-through parameter nodes - used by subgraphs
|
||||||
|
|
||||||
|
class ParamIntInvocation(BaseInvocation):
|
||||||
|
"""An integer parameter"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["param_int"] = "param_int"
|
||||||
|
a: int = Field(default=0, description="The integer value")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
return IntOutput(a=self.a)
|
56
invokeai/app/services/default_graphs.py
Normal file
56
invokeai/app/services/default_graphs.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||||
|
from ..invocations.params import ParamIntInvocation
|
||||||
|
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||||
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
|
|
||||||
|
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_to_image() -> LibraryGraph:
|
||||||
|
return LibraryGraph(
|
||||||
|
id=default_text_to_image_graph_id,
|
||||||
|
name='t2i',
|
||||||
|
description='Converts text to an image',
|
||||||
|
graph=Graph(
|
||||||
|
nodes={
|
||||||
|
'width': ParamIntInvocation(id='width', a=512),
|
||||||
|
'height': ParamIntInvocation(id='height', a=512),
|
||||||
|
'3': NoiseInvocation(id='3'),
|
||||||
|
'4': TextToLatentsInvocation(id='4'),
|
||||||
|
'5': LatentsToImageInvocation(id='5')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||||
|
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||||
|
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||||
|
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||||
|
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||||
|
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
exposed_inputs=[
|
||||||
|
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
||||||
|
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||||
|
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||||
|
],
|
||||||
|
exposed_outputs=[
|
||||||
|
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||||
|
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||||
|
|
||||||
|
graphs: list[LibraryGraph] = list()
|
||||||
|
|
||||||
|
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||||
|
|
||||||
|
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||||
|
#if text_to_image is None:
|
||||||
|
text_to_image = create_text_to_image()
|
||||||
|
graph_library.set(text_to_image)
|
||||||
|
|
||||||
|
graphs.append(text_to_image)
|
||||||
|
|
||||||
|
return graphs
|
@ -17,7 +17,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from ..invocations import *
|
from ..invocations import *
|
||||||
@ -283,7 +283,8 @@ class Graph(BaseModel):
|
|||||||
:raises InvalidEdgeError: the provided edge is invalid.
|
:raises InvalidEdgeError: the provided edge is invalid.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
self._validate_edge(edge)
|
||||||
|
if edge not in self.edges:
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
else:
|
else:
|
||||||
raise InvalidEdgeError()
|
raise InvalidEdgeError()
|
||||||
@ -354,7 +355,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
def _validate_edge(self, edge: Edge):
|
||||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||||
|
|
||||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||||
@ -362,54 +363,53 @@ class Graph(BaseModel):
|
|||||||
from_node = self.get_node(edge.source.node_id)
|
from_node = self.get_node(edge.source.node_id)
|
||||||
to_node = self.get_node(edge.destination.node_id)
|
to_node = self.get_node(edge.destination.node_id)
|
||||||
except NodeNotFoundError:
|
except NodeNotFoundError:
|
||||||
return False
|
raise InvalidEdgeError("One or both nodes don't exist")
|
||||||
|
|
||||||
# Validate that an edge to this node+field doesn't already exist
|
# Validate that an edge to this node+field doesn't already exist
|
||||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||||
return False
|
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||||
|
|
||||||
# Validate that no cycles would be created
|
# Validate that no cycles would be created
|
||||||
g = self.nx_graph_flat()
|
g = self.nx_graph_flat()
|
||||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||||
if not nx.is_directed_acyclic_graph(g):
|
if not nx.is_directed_acyclic_graph(g):
|
||||||
return False
|
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||||
|
|
||||||
# Validate that the field types are compatible
|
# Validate that the field types are compatible
|
||||||
if not are_connections_compatible(
|
if not are_connections_compatible(
|
||||||
from_node, edge.source.field, to_node, edge.destination.field
|
from_node, edge.source.field, to_node, edge.destination.field
|
||||||
):
|
):
|
||||||
return False
|
raise InvalidEdgeError(f'Fields are incompatible')
|
||||||
|
|
||||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(
|
||||||
edge.destination.node_id, new_input=edge.source
|
edge.destination.node_id, new_input=edge.source
|
||||||
):
|
):
|
||||||
return False
|
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||||
|
|
||||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(
|
||||||
edge.source.node_id, new_output=edge.destination
|
edge.source.node_id, new_output=edge.destination
|
||||||
):
|
):
|
||||||
return False
|
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||||
|
|
||||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(
|
||||||
edge.destination.node_id, new_input=edge.source
|
edge.destination.node_id, new_input=edge.source
|
||||||
):
|
):
|
||||||
return False
|
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||||
|
|
||||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(
|
||||||
edge.source.node_id, new_output=edge.destination
|
edge.source.node_id, new_output=edge.destination
|
||||||
):
|
):
|
||||||
return False
|
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def has_node(self, node_path: str) -> bool:
|
def has_node(self, node_path: str) -> bool:
|
||||||
"""Determines whether or not a node exists in the graph."""
|
"""Determines whether or not a node exists in the graph."""
|
||||||
@ -733,7 +733,7 @@ class Graph(BaseModel):
|
|||||||
for sgn in (
|
for sgn in (
|
||||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||||
):
|
):
|
||||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||||
|
|
||||||
# TODO: figure out if iteration nodes need to be expanded
|
# TODO: figure out if iteration nodes need to be expanded
|
||||||
|
|
||||||
@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
def is_complete(self) -> bool:
|
def is_complete(self) -> bool:
|
||||||
"""Returns true if the graph is complete"""
|
"""Returns true if the graph is complete"""
|
||||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||||
|
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||||
|
|
||||||
def has_error(self) -> bool:
|
def has_error(self) -> bool:
|
||||||
"""Returns true if the graph has any errors"""
|
"""Returns true if the graph has any errors"""
|
||||||
@ -946,11 +947,11 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
def _iterator_graph(self) -> nx.DiGraph:
|
def _iterator_graph(self) -> nx.DiGraph:
|
||||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||||
g = self.graph.nx_graph()
|
g = self.graph.nx_graph_flat()
|
||||||
collectors = (
|
collectors = (
|
||||||
n
|
n
|
||||||
for n in self.graph.nodes
|
for n in self.graph.nodes
|
||||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||||
)
|
)
|
||||||
for c in collectors:
|
for c in collectors:
|
||||||
g.remove_edges_from(list(g.in_edges(c)))
|
g.remove_edges_from(list(g.in_edges(c)))
|
||||||
@ -962,7 +963,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
iterators = [
|
iterators = [
|
||||||
n
|
n
|
||||||
for n in nx.ancestors(g, node_id)
|
for n in nx.ancestors(g, node_id)
|
||||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||||
]
|
]
|
||||||
return iterators
|
return iterators
|
||||||
|
|
||||||
@ -1098,7 +1099,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||||
if not self._is_edge_valid(edge):
|
try:
|
||||||
|
self.graph._validate_edge(edge)
|
||||||
|
except InvalidEdgeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Invalid if destination has already been prepared or executed
|
# Invalid if destination has already been prepared or executed
|
||||||
@ -1144,4 +1147,52 @@ class GraphExecutionState(BaseModel):
|
|||||||
self.graph.delete_edge(edge)
|
self.graph.delete_edge(edge)
|
||||||
|
|
||||||
|
|
||||||
|
class ExposedNodeInput(BaseModel):
|
||||||
|
node_path: str = Field(description="The node path to the node with the input")
|
||||||
|
field: str = Field(description="The field name of the input")
|
||||||
|
alias: str = Field(description="The alias of the input")
|
||||||
|
|
||||||
|
|
||||||
|
class ExposedNodeOutput(BaseModel):
|
||||||
|
node_path: str = Field(description="The node path to the node with the output")
|
||||||
|
field: str = Field(description="The field name of the output")
|
||||||
|
alias: str = Field(description="The alias of the output")
|
||||||
|
|
||||||
|
class LibraryGraph(BaseModel):
|
||||||
|
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||||
|
graph: Graph = Field(description="The graph")
|
||||||
|
name: str = Field(description="The name of the graph")
|
||||||
|
description: str = Field(description="The description of the graph")
|
||||||
|
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||||
|
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||||
|
|
||||||
|
@validator('exposed_inputs', 'exposed_outputs')
|
||||||
|
def validate_exposed_aliases(cls, v):
|
||||||
|
if len(v) != len(set(i.alias for i in v)):
|
||||||
|
raise ValueError("Duplicate exposed alias")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_exposed_nodes(cls, values):
|
||||||
|
graph = values['graph']
|
||||||
|
|
||||||
|
# Validate exposed inputs
|
||||||
|
for exposed_input in values['exposed_inputs']:
|
||||||
|
if not graph.has_node(exposed_input.node_path):
|
||||||
|
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||||
|
node = graph.get_node(exposed_input.node_path)
|
||||||
|
if get_input_field(node, exposed_input.field) is None:
|
||||||
|
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||||
|
|
||||||
|
# Validate exposed outputs
|
||||||
|
for exposed_output in values['exposed_outputs']:
|
||||||
|
if not graph.has_node(exposed_output.node_path):
|
||||||
|
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||||
|
node = graph.get_node(exposed_output.node_path)
|
||||||
|
if get_output_field(node, exposed_output.field) is None:
|
||||||
|
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
GraphInvocation.update_forward_refs()
|
GraphInvocation.update_forward_refs()
|
||||||
|
@ -19,6 +19,7 @@ class InvocationServices:
|
|||||||
restoration: RestorationServices
|
restoration: RestorationServices
|
||||||
|
|
||||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||||
|
graph_library: ItemStorageABC["LibraryGraph"]
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ class InvocationServices:
|
|||||||
latents: LatentsStorageBase,
|
latents: LatentsStorageBase,
|
||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
|
graph_library: ItemStorageABC["LibraryGraph"],
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
restoration: RestorationServices,
|
restoration: RestorationServices,
|
||||||
@ -38,6 +40,7 @@ class InvocationServices:
|
|||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.images = images
|
self.images = images
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
self.graph_library = graph_library
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
self.restoration = restoration
|
self.restoration = restoration
|
||||||
|
@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._create_table()
|
self._create_table()
|
||||||
|
|
||||||
def _create_table(self):
|
def _create_table(self):
|
||||||
try:
|
with self._lock:
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||||
item TEXT,
|
item TEXT,
|
||||||
@ -45,34 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||||
)
|
)
|
||||||
finally:
|
self._conn.commit()
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
return parse_raw_as(item_type, item)
|
return parse_raw_as(item_type, item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
with self._lock:
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||||
(item.json(),),
|
(item.json(),),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
self._on_changed(item)
|
self._on_changed(item)
|
||||||
|
|
||||||
def get(self, id: str) -> Union[T, None]:
|
def get(self, id: str) -> Union[T, None]:
|
||||||
try:
|
with self._lock:
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
)
|
)
|
||||||
result = self._cursor.fetchone()
|
result = self._cursor.fetchone()
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
@ -80,19 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
return self._parse_item(result[0])
|
return self._parse_item(result[0])
|
||||||
|
|
||||||
def delete(self, id: str):
|
def delete(self, id: str):
|
||||||
try:
|
with self._lock:
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
self._on_deleted(id)
|
self._on_deleted(id)
|
||||||
|
|
||||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||||
try:
|
with self._lock:
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||||
(per_page, page * per_page),
|
(per_page, page * per_page),
|
||||||
@ -103,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
@ -115,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def search(
|
def search(
|
||||||
self, query: str, page: int = 0, per_page: int = 10
|
self, query: str, page: int = 0, per_page: int = 10
|
||||||
) -> PaginatedResults[T]:
|
) -> PaginatedResults[T]:
|
||||||
try:
|
with self._lock:
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||||
(f"%{query}%", per_page, page * per_page),
|
(f"%{query}%", per_page, page * per_page),
|
||||||
@ -130,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
(f"%{query}%",),
|
(f"%{query}%",),
|
||||||
)
|
)
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ dependencies = [
|
|||||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel==1.0.5",
|
"compel==1.0.5",
|
||||||
"datasets",
|
"datasets",
|
||||||
"diffusers[torch]~=0.14",
|
"diffusers[torch]==0.14",
|
||||||
"dnspython==2.2.1",
|
"dnspython==2.2.1",
|
||||||
"einops",
|
"einops",
|
||||||
"eventlet",
|
"eventlet",
|
||||||
|
@ -7,7 +7,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
|
|||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -28,6 +28,9 @@ def mock_services():
|
|||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
latents = None, # type: ignore
|
latents = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=sqlite_memory, table_name="graphs"
|
||||||
|
),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None, # type: ignore
|
restoration = None, # type: ignore
|
||||||
|
@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +26,9 @@ def mock_services() -> InvocationServices:
|
|||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
latents = None, # type: ignore
|
latents = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=sqlite_memory, table_name="graphs"
|
||||||
|
),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None, # type: ignore
|
restoration = None, # type: ignore
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
from invokeai.app.invocations.image import *
|
|
||||||
|
|
||||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
||||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||||
|
from invokeai.app.invocations.image import *
|
||||||
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
|
from invokeai.app.invocations.params import ParamIntInvocation
|
||||||
|
from invokeai.app.services.default_graphs import create_text_to_image
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -417,6 +419,66 @@ def test_graph_gets_subgraph_node():
|
|||||||
assert result.id == '1'
|
assert result.id == '1'
|
||||||
assert result == n1_1
|
assert result == n1_1
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_expands_subgraph():
|
||||||
|
g = Graph()
|
||||||
|
n1 = GraphInvocation(id = "1")
|
||||||
|
n1.graph = Graph()
|
||||||
|
|
||||||
|
n1_1 = AddInvocation(id = "1", a = 1, b = 2)
|
||||||
|
n1_2 = SubtractInvocation(id = "2", b = 3)
|
||||||
|
n1.graph.add_node(n1_1)
|
||||||
|
n1.graph.add_node(n1_2)
|
||||||
|
n1.graph.add_edge(create_edge("1","a","2","a"))
|
||||||
|
|
||||||
|
g.add_node(n1)
|
||||||
|
|
||||||
|
n2 = AddInvocation(id = "2", b = 5)
|
||||||
|
g.add_node(n2)
|
||||||
|
g.add_edge(create_edge("1.2","a","2","a"))
|
||||||
|
|
||||||
|
dg = g.nx_graph_flat()
|
||||||
|
assert set(dg.nodes) == set(['1.1', '1.2', '2'])
|
||||||
|
assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')])
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_subgraph_t2i():
|
||||||
|
g = Graph()
|
||||||
|
n1 = GraphInvocation(id = "1")
|
||||||
|
|
||||||
|
# Get text to image default graph
|
||||||
|
lg = create_text_to_image()
|
||||||
|
n1.graph = lg.graph
|
||||||
|
|
||||||
|
g.add_node(n1)
|
||||||
|
|
||||||
|
n2 = ParamIntInvocation(id = "2", a = 512)
|
||||||
|
n3 = ParamIntInvocation(id = "3", a = 256)
|
||||||
|
|
||||||
|
g.add_node(n2)
|
||||||
|
g.add_node(n3)
|
||||||
|
|
||||||
|
g.add_edge(create_edge("2","a","1.width","a"))
|
||||||
|
g.add_edge(create_edge("3","a","1.height","a"))
|
||||||
|
|
||||||
|
n4 = ShowImageInvocation(id = "4")
|
||||||
|
g.add_node(n4)
|
||||||
|
g.add_edge(create_edge("1.5","image","4","image"))
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
dg = g.nx_graph_flat()
|
||||||
|
assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
|
||||||
|
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
|
||||||
|
expected_edges.extend([
|
||||||
|
('2','1.width'),
|
||||||
|
('3','1.height'),
|
||||||
|
('1.5','4')
|
||||||
|
])
|
||||||
|
print(expected_edges)
|
||||||
|
print(list(dg.edges))
|
||||||
|
assert set(dg.edges) == set(expected_edges)
|
||||||
|
|
||||||
|
|
||||||
def test_graph_fails_to_get_missing_subgraph_node():
|
def test_graph_fails_to_get_missing_subgraph_node():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = GraphInvocation(id = "1")
|
n1 = GraphInvocation(id = "1")
|
||||||
|
Loading…
Reference in New Issue
Block a user