mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into enhance/invokeai-logs
- resolve conflicts - remove unused code identified by pyflakes
This commit is contained in:
@ -13,18 +13,21 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as log
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@ -59,7 +62,7 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@ -77,20 +80,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@ -99,14 +154,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
)
|
||||
for field in matching_fields
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@ -142,10 +197,12 @@ def invoke_cli():
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
completer = set_autocompleter(model_manager)
|
||||
set_autocompleter(model_manager)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
@ -157,8 +214,12 @@ def invoke_cli():
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
metadata=metadata,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@ -166,9 +227,12 @@ def invoke_cli():
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
@ -186,11 +250,12 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(history)
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@ -206,8 +271,24 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@ -218,6 +299,7 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@ -230,7 +312,7 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
from_node, command.command, context
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
@ -243,7 +325,7 @@ def invoke_cli():
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@ -257,12 +339,14 @@ def invoke_cli():
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
)
|
||||
)
|
||||
|
||||
@ -271,10 +355,10 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.session.add_edge(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
@ -286,7 +370,7 @@ def invoke_cli():
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
log.warning("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
Reference in New Issue
Block a user