diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 5698d25758..cd5d8a61b2 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -3,12 +3,14 @@ import os from argparse import Namespace +from ..services.default_graphs import create_system_graphs + from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from ...backend import Globals from ..services.model_manager_initializer import get_model_manager from ..services.restoration_services import RestorationServices -from ..services.graph import GraphExecutionState +from ..services.graph import GraphExecutionState, LibraryGraph from ..services.image_storage import DiskImageStorage from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_services import InvocationServices @@ -69,6 +71,9 @@ class ApiDependencies: latents=latents, images=images, queue=MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=db_location, table_name="graphs" + ), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" ), @@ -76,6 +81,8 @@ class ApiDependencies: restoration=RestorationServices(config), ) + create_system_graphs(services.graph_library) + ApiDependencies.invoker = Invoker(services) @staticmethod diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index 4e9c9aa581..5ad4827eb0 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -7,11 +7,40 @@ from pydantic import BaseModel, Field import networkx as nx import matplotlib.pyplot as plt -from ..models.image import ImageField -from ..services.graph import GraphExecutionState +from ..invocations.baseinvocation import BaseInvocation +from ..invocations.image import ImageField +from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge from ..services.invoker import Invoker +def add_field_argument(command_parser, name: str, field, default_override = None): + default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() + if get_origin(field.type_) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore + + command_parser.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=default, + choices=allowed_values, + help=field.field_info.description, + ) + else: + command_parser.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=default, + help=field.field_info.description, + ) + + def add_parsers( subparsers, commands: list[type], @@ -36,30 +65,26 @@ def add_parsers( if name in exclude_fields: continue - if get_origin(field.type_) == Literal: - allowed_values = get_args(field.type_) - allowed_types = set() - for val in allowed_values: - allowed_types.add(type(val)) - allowed_types_list = list(allowed_types) - field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore + add_field_argument(command_parser, name, field) - command_parser.add_argument( - f"--{name}", - dest=name, - type=field_type, - default=field.default if field.default_factory is None else field.default_factory(), - choices=allowed_values, - help=field.field_info.description, - ) - else: - command_parser.add_argument( - f"--{name}", - dest=name, - type=field.type_, - default=field.default if field.default_factory is None else field.default_factory(), - help=field.field_info.description, - ) + +def add_graph_parsers( + subparsers, + graphs: list[LibraryGraph], + add_arguments: Callable[[argparse.ArgumentParser], None]|None = None +): + for graph in graphs: + command_parser = subparsers.add_parser(graph.name, help=graph.description) + + if add_arguments is not None: + add_arguments(command_parser) + + # Add arguments for inputs + for exposed_input in graph.exposed_inputs: + node = graph.graph.get_node(exposed_input.node_path) + field = node.__fields__[exposed_input.field] + default_override = getattr(node, exposed_input.field) + add_field_argument(command_parser, exposed_input.alias, field, default_override) class CliContext: @@ -67,17 +92,38 @@ class CliContext: session: GraphExecutionState parser: argparse.ArgumentParser defaults: dict[str, Any] + graph_nodes: dict[str, str] + nodes_added: list[str] def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser): self.invoker = invoker self.session = session self.parser = parser self.defaults = dict() + self.graph_nodes = dict() + self.nodes_added = list() def get_session(self): self.session = self.invoker.services.graph_execution_manager.get(self.session.id) return self.session + def reset(self): + self.session = self.invoker.create_execution_state() + self.graph_nodes = dict() + self.nodes_added = list() + # Leave defaults unchanged + + def add_node(self, node: BaseInvocation): + self.get_session() + self.session.graph.add_node(node) + self.nodes_added.append(node.id) + self.invoker.services.graph_execution_manager.set(self.session) + + def add_edge(self, edge: Edge): + self.get_session() + self.session.add_edge(edge) + self.invoker.services.graph_execution_manager.set(self.session) + class ExitCli(Exception): """Exception to exit the CLI""" diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index a257825dcc..86fd18ca60 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -13,17 +13,20 @@ from typing import ( from pydantic import BaseModel from pydantic.fields import Field +from .services.default_graphs import create_system_graphs + from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from ..backend import Args -from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history +from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history from .cli.completer import set_autocompleter from .invocations import * from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase from .services.model_manager_initializer import get_model_manager from .services.restoration_services import RestorationServices -from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible +from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible +from .services.default_graphs import default_text_to_image_graph_id from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices @@ -58,7 +61,7 @@ def add_invocation_args(command_parser): ) -def get_command_parser() -> argparse.ArgumentParser: +def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser: # Create invocation parser parser = argparse.ArgumentParser() @@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser: commands = BaseCommand.get_all_subclasses() add_parsers(subparsers, commands, exclude_fields=["type"]) + # Create subparsers for exposed CLI graphs + # TODO: add a way to identify these graphs + text_to_image = services.graph_library.get(default_text_to_image_graph_id) + add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args) + return parser +class NodeField(): + alias: str + node_path: str + field: str + field_type: type + + def __init__(self, alias: str, node_path: str, field: str, field_type: type): + self.alias = alias + self.node_path = node_path + self.field = field + self.field_type = field_type + + +def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]: + return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()} + + +def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: + """Gets the node field for the specified field alias""" + exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias) + node_type = type(graph.graph.get_node(exposed_input.node_path)) + return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field]) + + +def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: + """Gets the node field for the specified field alias""" + exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias) + node_type = type(graph.graph.get_node(exposed_output.node_path)) + node_output_type = node_type.get_output_type() + return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field]) + + +def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]: + """Gets the inputs for the specified invocation from the context""" + node_type = type(invocation) + if node_type is not GraphInvocation: + return fields_from_type_hints(get_type_hints(node_type), invocation.id) + else: + graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id]) + return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs} + + +def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]: + """Gets the outputs for the specified invocation from the context""" + node_type = type(invocation) + if node_type is not GraphInvocation: + return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id) + else: + graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id]) + return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs} + + def generate_matching_edges( - a: BaseInvocation, b: BaseInvocation + a: BaseInvocation, b: BaseInvocation, context: CliContext ) -> list[Edge]: """Generates all possible edges between two invocations""" - atype = type(a) - btype = type(b) - - aoutputtype = atype.get_output_type() - - afields = get_type_hints(aoutputtype) - bfields = get_type_hints(btype) + afields = get_node_outputs(a, context) + bfields = get_node_inputs(b, context) matching_fields = set(afields.keys()).intersection(bfields.keys()) @@ -98,14 +153,14 @@ def generate_matching_edges( matching_fields = matching_fields.difference(invalid_fields) # Validate types - matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])] + matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)] edges = [ Edge( - source=EdgeConnection(node_id=a.id, field=field), - destination=EdgeConnection(node_id=b.id, field=field) + source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field), + destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field) ) - for field in matching_fields + for alias in matching_fields ] return edges @@ -158,6 +213,9 @@ def invoke_cli(): latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), images=DiskImageStorage(f'{output_folder}/images'), queue=MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=db_location, table_name="graphs" + ), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" ), @@ -165,9 +223,12 @@ def invoke_cli(): restoration=RestorationServices(config), ) + system_graphs = create_system_graphs(services.graph_library) + system_graph_names = set([g.name for g in system_graphs]) + invoker = Invoker(services) session: GraphExecutionState = invoker.create_execution_state() - parser = get_command_parser() + parser = get_command_parser(services) re_negid = re.compile('^-[0-9]+$') @@ -185,11 +246,12 @@ def invoke_cli(): try: # Refresh the state of the session - history = list(get_graph_execution_history(context.session)) + #history = list(get_graph_execution_history(context.session)) + history = list(reversed(context.nodes_added)) # Split the command for piping cmds = cmd_input.split("|") - start_id = len(history) + start_id = len(context.nodes_added) current_id = start_id new_invocations = list() for cmd in cmds: @@ -205,8 +267,24 @@ def invoke_cli(): args[field_name] = field_default # Parse invocation - args["id"] = current_id - command = CliCommand(command=args) + command: CliCommand = None # type:ignore + system_graph: LibraryGraph|None = None + if args['type'] in system_graph_names: + system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) + invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) + for exposed_input in system_graph.exposed_inputs: + if exposed_input.alias in args: + node = invocation.graph.get_node(exposed_input.node_path) + field = exposed_input.field + setattr(node, field, args[exposed_input.alias]) + command = CliCommand(command = invocation) + context.graph_nodes[invocation.id] = system_graph.id + else: + args["id"] = current_id + command = CliCommand(command=args) + + if command is None: + continue # Run any CLI commands immediately if isinstance(command.command, BaseCommand): @@ -217,6 +295,7 @@ def invoke_cli(): command.command.run(context) continue + # TODO: handle linking with library graphs # Pipe previous command output (if there was a previous command) edges: list[Edge] = list() if len(history) > 0 or current_id != start_id: @@ -229,7 +308,7 @@ def invoke_cli(): else context.session.graph.get_node(from_id) ) matching_edges = generate_matching_edges( - from_node, command.command + from_node, command.command, context ) edges.extend(matching_edges) @@ -242,7 +321,7 @@ def invoke_cli(): link_node = context.session.graph.get_node(node_id) matching_edges = generate_matching_edges( - link_node, command.command + link_node, command.command, context ) matching_destinations = [e.destination for e in matching_edges] edges = [e for e in edges if e.destination not in matching_destinations] @@ -256,12 +335,14 @@ def invoke_cli(): if re_negid.match(node_id): node_id = str(current_id + int(node_id)) + # TODO: handle missing input/output + node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]] + node_input = get_node_inputs(command.command, context)[link[2]] + edges.append( Edge( - source=EdgeConnection(node_id=node_id, field=link[1]), - destination=EdgeConnection( - node_id=command.command.id, field=link[2] - ) + source=EdgeConnection(node_id=node_output.node_path, field=node_output.field), + destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field) ) ) @@ -270,10 +351,10 @@ def invoke_cli(): current_id = current_id + 1 # Add the node to the session - context.session.add_node(command.command) + context.add_node(command.command) for edge in edges: print(edge) - context.session.add_edge(edge) + context.add_edge(edge) # Execute all remaining nodes invoke_all(context) @@ -285,7 +366,7 @@ def invoke_cli(): except SessionError: # Start a new session print("Session error: creating a new session") - context.session = context.invoker.create_execution_state() + context.reset() except ExitCli: break diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 2da6e451a9..ef17962f89 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +import random from typing import Literal, Optional from pydantic import BaseModel, Field import torch @@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c return x +def random_seed(): + return random.randint(0, np.iinfo(np.uint32).max) + + class NoiseInvocation(BaseInvocation): """Generates latent noise.""" type: Literal["noise"] = "noise" # Inputs - seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", ) + seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed) width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", ) height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", ) @@ -313,6 +318,56 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) +class LatentsToLatentsInvocation(TextToLatentsInvocation): + """Generates latents using latents as base image.""" + + type: Literal["l2l"] = "l2l" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to use as a base image") + strength: float = Field(default=0.5, description="The strength of the latents to use") + + def invoke(self, context: InvocationContext) -> LatentsOutput: + noise = context.services.latents.get(self.noise.latents_name) + latent = context.services.latents.get(self.latents.latents_name) + + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state) + + model = self.get_model(context.services.model_manager) + conditioning_data = self.get_conditioning_data(model) + + # TODO: Verify the noise is the right size + + initial_latents = latent if self.strength < 1.0 else torch.zeros_like( + latent, device=model.device, dtype=latent.dtype + ) + + timesteps, _ = model.get_img2img_timesteps( + self.steps, + self.strength, + device=model.device, + ) + + result_latents, result_attention_map_saver = model.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + callback=step_callback + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, result_latents) + return LatentsOutput( + latents=LatentsField(latents_name=name) + ) + + # Latent to image class LatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py new file mode 100644 index 0000000000..fcc7f1737a --- /dev/null +++ b/invokeai/app/invocations/params.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal +from pydantic import Field +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from .math import IntOutput + +# Pass-through parameter nodes - used by subgraphs + +class ParamIntInvocation(BaseInvocation): + """An integer parameter""" + #fmt: off + type: Literal["param_int"] = "param_int" + a: int = Field(default=0, description="The integer value") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a) diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py new file mode 100644 index 0000000000..637d906e75 --- /dev/null +++ b/invokeai/app/services/default_graphs.py @@ -0,0 +1,56 @@ +from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation +from ..invocations.params import ParamIntInvocation +from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph +from .item_storage import ItemStorageABC + + +default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74' + + +def create_text_to_image() -> LibraryGraph: + return LibraryGraph( + id=default_text_to_image_graph_id, + name='t2i', + description='Converts text to an image', + graph=Graph( + nodes={ + 'width': ParamIntInvocation(id='width', a=512), + 'height': ParamIntInvocation(id='height', a=512), + '3': NoiseInvocation(id='3'), + '4': TextToLatentsInvocation(id='4'), + '5': LatentsToImageInvocation(id='5') + }, + edges=[ + Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')), + Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')), + Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')), + Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')), + Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')), + Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')), + ] + ), + exposed_inputs=[ + ExposedNodeInput(node_path='4', field='prompt', alias='prompt'), + ExposedNodeInput(node_path='width', field='a', alias='width'), + ExposedNodeInput(node_path='height', field='a', alias='height') + ], + exposed_outputs=[ + ExposedNodeOutput(node_path='5', field='image', alias='image') + ]) + + +def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]: + """Creates the default system graphs, or adds new versions if the old ones don't match""" + + graphs: list[LibraryGraph] = list() + + text_to_image = graph_library.get(default_text_to_image_graph_id) + + # TODO: Check if the graph is the same as the default one, and if not, update it + #if text_to_image is None: + text_to_image = create_text_to_image() + graph_library.set(text_to_image) + + graphs.append(text_to_image) + + return graphs diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index e286569bcc..44f6a3d69e 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -17,7 +17,7 @@ from typing import ( ) import networkx as nx -from pydantic import BaseModel, validator +from pydantic import BaseModel, root_validator, validator from pydantic.fields import Field from ..invocations import * @@ -283,7 +283,8 @@ class Graph(BaseModel): :raises InvalidEdgeError: the provided edge is invalid. """ - if self._is_edge_valid(edge) and edge not in self.edges: + self._validate_edge(edge) + if edge not in self.edges: self.edges.append(edge) else: raise InvalidEdgeError() @@ -354,7 +355,7 @@ class Graph(BaseModel): return True - def _is_edge_valid(self, edge: Edge) -> bool: + def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) @@ -362,54 +363,53 @@ class Graph(BaseModel): from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) except NodeNotFoundError: - return False + raise InvalidEdgeError("One or both nodes don't exist") # Validate that an edge to this node+field doesn't already exist input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): - return False + raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists') # Validate that no cycles would be created g = self.nx_graph_flat() g.add_edge(edge.source.node_id, edge.destination.node_id) if not nx.is_directed_acyclic_graph(g): - return False + raise InvalidEdgeError(f'Edge creates a cycle in the graph') # Validate that the field types are compatible if not are_connections_compatible( from_node, edge.source.field, to_node, edge.destination.field ): - return False + raise InvalidEdgeError(f'Fields are incompatible') # Validate if iterator output type matches iterator input type (if this edge results in both being set) if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if not self._is_iterator_connection_valid( edge.destination.node_id, new_input=edge.source ): - return False + raise InvalidEdgeError(f'Iterator input type does not match iterator output type') # Validate if iterator input type matches output type (if this edge results in both being set) if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if not self._is_iterator_connection_valid( edge.source.node_id, new_output=edge.destination ): - return False + raise InvalidEdgeError(f'Iterator output type does not match iterator input type') # Validate if collector input type matches output type (if this edge results in both being set) if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if not self._is_collector_connection_valid( edge.destination.node_id, new_input=edge.source ): - return False + raise InvalidEdgeError(f'Collector output type does not match collector input type') # Validate if collector output type matches input type (if this edge results in both being set) if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if not self._is_collector_connection_valid( edge.source.node_id, new_output=edge.destination ): - return False + raise InvalidEdgeError(f'Collector input type does not match collector output type') - return True def has_node(self, node_path: str) -> bool: """Determines whether or not a node exists in the graph.""" @@ -733,7 +733,7 @@ class Graph(BaseModel): for sgn in ( gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation) ): - sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) # TODO: figure out if iteration nodes need to be expanded @@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel): def is_complete(self) -> bool: """Returns true if the graph is complete""" - return self.has_error() or all((k in self.executed for k in self.graph.nodes)) + node_ids = set(self.graph.nx_graph_flat().nodes) + return self.has_error() or all((k in self.executed for k in node_ids)) def has_error(self) -> bool: """Returns true if the graph has any errors""" @@ -946,11 +947,11 @@ class GraphExecutionState(BaseModel): def _iterator_graph(self) -> nx.DiGraph: """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" - g = self.graph.nx_graph() + g = self.graph.nx_graph_flat() collectors = ( n for n in self.graph.nodes - if isinstance(self.graph.nodes[n], CollectInvocation) + if isinstance(self.graph.get_node(n), CollectInvocation) ) for c in collectors: g.remove_edges_from(list(g.in_edges(c))) @@ -962,7 +963,7 @@ class GraphExecutionState(BaseModel): iterators = [ n for n in nx.ancestors(g, node_id) - if isinstance(self.graph.nodes[n], IterateInvocation) + if isinstance(self.graph.get_node(n), IterateInvocation) ] return iterators @@ -1098,7 +1099,9 @@ class GraphExecutionState(BaseModel): # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state def _is_edge_valid(self, edge: Edge) -> bool: - if not self._is_edge_valid(edge): + try: + self.graph._validate_edge(edge) + except InvalidEdgeError: return False # Invalid if destination has already been prepared or executed @@ -1144,4 +1147,52 @@ class GraphExecutionState(BaseModel): self.graph.delete_edge(edge) +class ExposedNodeInput(BaseModel): + node_path: str = Field(description="The node path to the node with the input") + field: str = Field(description="The field name of the input") + alias: str = Field(description="The alias of the input") + + +class ExposedNodeOutput(BaseModel): + node_path: str = Field(description="The node path to the node with the output") + field: str = Field(description="The field name of the output") + alias: str = Field(description="The alias of the output") + +class LibraryGraph(BaseModel): + id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4) + graph: Graph = Field(description="The graph") + name: str = Field(description="The name of the graph") + description: str = Field(description="The description of the graph") + exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list) + exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list) + + @validator('exposed_inputs', 'exposed_outputs') + def validate_exposed_aliases(cls, v): + if len(v) != len(set(i.alias for i in v)): + raise ValueError("Duplicate exposed alias") + return v + + @root_validator + def validate_exposed_nodes(cls, values): + graph = values['graph'] + + # Validate exposed inputs + for exposed_input in values['exposed_inputs']: + if not graph.has_node(exposed_input.node_path): + raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist") + node = graph.get_node(exposed_input.node_path) + if get_input_field(node, exposed_input.field) is None: + raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}") + + # Validate exposed outputs + for exposed_output in values['exposed_outputs']: + if not graph.has_node(exposed_output.node_path): + raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist") + node = graph.get_node(exposed_output.node_path) + if get_output_field(node, exposed_output.field) is None: + raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}") + + return values + + GraphInvocation.update_forward_refs() diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 2cd0f55fd9..c3c6bbce7e 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -19,6 +19,7 @@ class InvocationServices: restoration: RestorationServices # NOTE: we must forward-declare any types that include invocations, since invocations can use services + graph_library: ItemStorageABC["LibraryGraph"] graph_execution_manager: ItemStorageABC["GraphExecutionState"] processor: "InvocationProcessorABC" @@ -29,6 +30,7 @@ class InvocationServices: latents: LatentsStorageBase, images: ImageStorageBase, queue: InvocationQueueABC, + graph_library: ItemStorageABC["LibraryGraph"], graph_execution_manager: ItemStorageABC["GraphExecutionState"], processor: "InvocationProcessorABC", restoration: RestorationServices, @@ -38,6 +40,7 @@ class InvocationServices: self.latents = latents self.images = images self.queue = queue + self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager self.processor = processor self.restoration = restoration diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index fd089014bb..e06ca8c1ac 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._create_table() def _create_table(self): - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""CREATE TABLE IF NOT EXISTS {self._table_name} ( item TEXT, @@ -45,34 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._cursor.execute( f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);""" ) - finally: - self._lock.release() + self._conn.commit() def _parse_item(self, item: str) -> T: item_type = get_args(self.__orig_class__)[0] return parse_raw_as(item_type, item) def set(self, item: T): - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", (item.json(),), ) self._conn.commit() - finally: - self._lock.release() self._on_changed(item) def get(self, id: str) -> Union[T, None]: - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),) ) result = self._cursor.fetchone() - finally: - self._lock.release() if not result: return None @@ -80,19 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): return self._parse_item(result[0]) def delete(self, id: str): - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) ) self._conn.commit() - finally: - self._lock.release() self._on_deleted(id) def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""", (per_page, page * per_page), @@ -103,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""") count = self._cursor.fetchone()[0] - finally: - self._lock.release() pageCount = int(count / per_page) + 1 @@ -115,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def search( self, query: str, page: int = 0, per_page: int = 10 ) -> PaginatedResults[T]: - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""", (f"%{query}%", per_page, page * per_page), @@ -130,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): (f"%{query}%",), ) count = self._cursor.fetchone()[0] - finally: - self._lock.release() pageCount = int(count / per_page) + 1 diff --git a/pyproject.toml b/pyproject.toml index 3d72483237..ec6aabfb8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==1.0.5", "datasets", - "diffusers[torch]~=0.14", + "diffusers[torch]==0.14", "dnspython==2.2.1", "einops", "eventlet", diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 506b8653f8..f65129797e 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -7,7 +7,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState import pytest @@ -28,6 +28,9 @@ def mock_services(): images = None, # type: ignore latents = None, # type: ignore queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), restoration = None, # type: ignore diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 68df708bdd..46d532b9f7 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invoker import Invoker from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState import pytest @@ -26,6 +26,9 @@ def mock_services() -> InvocationServices: images = None, # type: ignore latents = None, # type: ignore queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), restoration = None, # type: ignore diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index b864e1e47a..c7693b59c9 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,9 +1,11 @@ -from invokeai.app.invocations.image import * - from .test_nodes import ListPassThroughInvocation, PromptTestInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation from invokeai.app.invocations.upscale import UpscaleInvocation +from invokeai.app.invocations.image import * +from invokeai.app.invocations.math import AddInvocation, SubtractInvocation +from invokeai.app.invocations.params import ParamIntInvocation +from invokeai.app.services.default_graphs import create_text_to_image import pytest @@ -417,6 +419,66 @@ def test_graph_gets_subgraph_node(): assert result.id == '1' assert result == n1_1 + +def test_graph_expands_subgraph(): + g = Graph() + n1 = GraphInvocation(id = "1") + n1.graph = Graph() + + n1_1 = AddInvocation(id = "1", a = 1, b = 2) + n1_2 = SubtractInvocation(id = "2", b = 3) + n1.graph.add_node(n1_1) + n1.graph.add_node(n1_2) + n1.graph.add_edge(create_edge("1","a","2","a")) + + g.add_node(n1) + + n2 = AddInvocation(id = "2", b = 5) + g.add_node(n2) + g.add_edge(create_edge("1.2","a","2","a")) + + dg = g.nx_graph_flat() + assert set(dg.nodes) == set(['1.1', '1.2', '2']) + assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')]) + + +def test_graph_subgraph_t2i(): + g = Graph() + n1 = GraphInvocation(id = "1") + + # Get text to image default graph + lg = create_text_to_image() + n1.graph = lg.graph + + g.add_node(n1) + + n2 = ParamIntInvocation(id = "2", a = 512) + n3 = ParamIntInvocation(id = "3", a = 256) + + g.add_node(n2) + g.add_node(n3) + + g.add_edge(create_edge("2","a","1.width","a")) + g.add_edge(create_edge("3","a","1.height","a")) + + n4 = ShowImageInvocation(id = "4") + g.add_node(n4) + g.add_edge(create_edge("1.5","image","4","image")) + + # Validate + dg = g.nx_graph_flat() + assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4']) + expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges] + expected_edges.extend([ + ('2','1.width'), + ('3','1.height'), + ('1.5','4') + ]) + print(expected_edges) + print(list(dg.edges)) + assert set(dg.edges) == set(expected_edges) + + def test_graph_fails_to_get_missing_subgraph_node(): g = Graph() n1 = GraphInvocation(id = "1")