mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
357601e2d6
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800 committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800 Adding base node architecture Fix type annotation errors Runs and generates, but breaks in saving session Fix default model value setting. Fix deprecation warning. Fixed node api Adding markdown docs Simplifying Generate construction in apps [nodes] A few minor changes (#2510) * Pin api-related requirements * Remove confusing extra CORS origins list * Adds response models for HTTP 200 [nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest. Minor typing fixes [nodes] Fix some small output query hookups [node] Fixing some additional typing issues [nodes] Move and expand graph code. Add base item storage and sqlite implementation. Update startup to match new code [nodes] Add callbacks to item storage [nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility [nodes] New execution model that handles iteration [nodes] Fixing the CLI [nodes] Adding a note to the CLI [nodes] Split processing thread into separate service [node] Add error message on node processing failure Removing old files and duplicated packages Adding python-multipart
307 lines
11 KiB
Python
307 lines
11 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import argparse
|
|
import shlex
|
|
import os
|
|
import time
|
|
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
|
from pydantic import BaseModel
|
|
from pydantic.fields import Field
|
|
|
|
from .services.processor import DefaultInvocationProcessor
|
|
|
|
from .services.graph import EdgeConnection, GraphExecutionState
|
|
|
|
from .services.sqlite import SqliteItemStorage
|
|
|
|
from .invocations.image import ImageField
|
|
from .services.generate_initializer import get_generate
|
|
from .services.image_storage import DiskImageStorage
|
|
from .services.invocation_queue import MemoryInvocationQueue
|
|
from .invocations.baseinvocation import BaseInvocation
|
|
from .services.invocation_services import InvocationServices
|
|
from .services.invoker import Invoker, InvokerServices
|
|
from .invocations import *
|
|
from ..args import Args
|
|
from .services.events import EventServiceBase
|
|
|
|
|
|
# Pydantic model wrapping a single invocation. invoke_cli() builds it from the
# dict of parsed CLI arguments so the dict is validated into an invocation object.
class InvocationCommand(BaseModel):
    # Union over whatever BaseInvocation.get_invocations() returns; the
    # "type" field is declared as the discriminator for that union.
    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
|
|
|
|
|
|
class InvalidArgs(Exception):
    """Signals that a CLI command line could not be parsed.

    Raised by the replacement ``exit`` hook installed on the invocation
    parser, and caught by the interactive loop to report an invalid command.
    """
|
|
|
|
|
|
def _literal_value_parser(allowed_values):
    """Return an argparse ``type`` callable for a Literal with mixed value types.

    The callable maps the raw command-line text to the allowed value whose
    ``str()`` matches it. Text that matches nothing is returned unchanged so
    that argparse's ``choices`` check reports it as an invalid choice.
    """
    values_by_text = {}
    for value in allowed_values:
        # Keep the first value if two literals share the same text form
        values_by_text.setdefault(str(value), value)

    def parse(raw):
        return values_by_text.get(raw, raw)

    return parse


def get_invocation_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the interactive CLI.

    The parser stores the chosen sub-command under ``type``. It offers the
    special commands ``history``, ``default`` and ``reset_default``, plus one
    sub-command per class returned by ``BaseInvocation.get_all_subclasses()``.
    For each invocation, every field except ``id`` and ``type`` becomes a
    ``--<field>`` option, and ``--link`` / ``--link_node`` options are added.

    Returns:
        The configured parser. Its ``exit`` method is replaced so that the
        top-level parser raises ``InvalidArgs`` instead of exiting the process.
    """
    # Create invocation parser
    parser = argparse.ArgumentParser()

    def raise_invalid_args(*args, **kwargs):
        raise InvalidArgs

    # argparse terminates through parser.exit() (e.g. from parser.error());
    # raising instead lets the interactive loop recover from a bad command.
    parser.exit = raise_invalid_args

    subparsers = parser.add_subparsers(dest='type')

    # Add history parser
    history_parser = subparsers.add_parser('history', help="Shows the invocation history")
    history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")

    # Add default parser
    default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
    default_parser.add_argument('input', type=str, help="The input field")
    default_parser.add_argument('value', help="The default value")

    reset_default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
    reset_default_parser.add_argument('input', type=str, help="The input field")

    # Create subparsers for each invocation
    invocations = BaseInvocation.get_all_subclasses()
    for invocation in invocations:
        hints = get_type_hints(invocation)
        # The command name is the first argument of the annotation on `type`
        cmd_name = get_args(hints['type'])[0]
        command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)

        # Add linking capability
        command_parser.add_argument('--link', '-l', action='append', nargs=3,
            help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)")

        command_parser.add_argument('--link_node', '-ln', action='append',
            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")

        # Convert all fields to arguments
        fields = invocation.__fields__
        for name, field in fields.items():
            if name in ['id', 'type']:
                continue

            if get_origin(field.type_) == Literal:
                allowed_values = get_args(field.type_)
                allowed_types = {type(value) for value in allowed_values}
                if len(allowed_types) == 1:
                    field_type = next(iter(allowed_types))
                else:
                    # Union[...] of a list raises TypeError and is not a usable
                    # converter; translate the raw text to the literal instead.
                    field_type = _literal_value_parser(allowed_values)

                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field_type,
                    default=field.default,
                    choices=allowed_values,
                    help=field.field_info.description
                )
            else:
                # NOTE(review): the field type is used directly as the argparse
                # converter; if any invocation declares a bool field, any
                # non-empty text (including "False") converts to True.
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field.type_,
                    default=field.default,
                    help=field.field_info.description
                )

    return parser
|
|
|
|
|
|
def get_invocation_command(invocation) -> str:
    """Render an invocation as the CLI command text that describes it.

    The result starts with ``invocation.type`` followed by a ``--name value``
    pair for every field whose current value differs from the field default.
    The ``id`` and ``type`` fields and any field annotated with ``ImageField``
    are left out. Values of fields annotated as ``str`` are wrapped in double
    quotes, with embedded backslashes and double quotes escaped so the text
    can be split again with ``shlex``.

    Args:
        invocation: Object exposing ``type``, ``__fields__`` and one attribute
            per field name.

    Returns:
        The command text, parts joined by single spaces.
    """
    type_hints = get_type_hints(type(invocation))
    command = [invocation.type]
    for name, field in invocation.__fields__.items():
        if name in ['id', 'type']:
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name)
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        if field_value != field.default:
            if type_hint is str or str in get_args(type_hint):
                # Escape backslash first so the escapes added for quotes survive
                escaped = str(field_value).replace('\\', '\\\\').replace('"', '\\"')
                command.append(f'--{name} "{escaped}"')
            else:
                command.append(f'--{name} {field_value}')

    return ' '.join(command)
|
|
|
|
|
|
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
    """Yield executed node ids, in reverse of ``executed_history`` order.

    Only ids that are also present in ``graph_execution_state.graph.nodes``
    are produced; the result is a lazy generator.
    """
    state = graph_execution_state
    reversed_history = reversed(state.executed_history)
    return (node_id for node_id in reversed_history if node_id in state.graph.nodes)
|
|
|
|
|
|
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
    """Build one edge per field name shared by ``a``'s output type and ``b``.

    Field names come from the type hints of ``type(a).get_output_type()`` and
    of ``type(b)``. The names 'type' and 'id' are never linked. Each edge is a
    ``(source, destination)`` pair of ``EdgeConnection`` objects that use the
    same field name on both ends.
    """
    output_hints = get_type_hints(type(a).get_output_type())
    input_hints = get_type_hints(type(b))

    # Names present on both sides, minus the fields that must not be linked
    shared_names = set(output_hints.keys()).intersection(input_hints.keys())
    linkable_names = shared_names.difference(set(['type', 'id']))

    edges = []
    for field_name in linkable_names:
        source = EdgeConnection(node_id = a.id, field = field_name)
        destination = EdgeConnection(node_id = b.id, field = field_name)
        edges.append((source, destination))
    return edges
|
|
|
|
|
|
def invoke_cli():
    """Run the interactive node CLI until the user exits.

    Builds the generator, event, image-storage, queue, graph-execution storage
    and processor services, creates an execution state (the session), and then
    reads command lines from standard input in a loop:

    * ``exit`` / ``q``, Ctrl-C or end-of-input leave the loop.
    * ``help`` / ``--help`` / ``h`` / ``?`` print the parser help.
    * Blank lines are ignored.
    * Any other line is split on ``|``. Each part is parsed as either a
      special command (``history``, ``default``, ``reset_default``) or an
      invocation. Each invocation is linked to the previous invocation on the
      same line, or to ``history[0]`` for the first one, through every
      matching field, plus any ``--link`` / ``--link_node`` links.

    After a line is parsed, the new nodes and edges are added to the session,
    the session is invoked, and the loop polls storage until the session
    reports completion. ``invoker.stop()`` is called once the loop ends.
    """
    args = Args()
    config = args.parse_args()

    generate = get_generate(args, config)

    # NOTE: load model on first use, uncomment to load at startup
    # TODO: Make this a config option?
    #generate.load_model()

    events = EventServiceBase()

    output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))

    services = InvocationServices(
        generate = generate,
        events = events,
        images = DiskImageStorage(output_folder)
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, 'invokeai.db')

    invoker_services = InvokerServices(
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor()
    )

    invoker = Invoker(services, invoker_services)
    session = invoker.create_execution_state()

    parser = get_invocation_parser()

    # Uncomment to print out previous sessions at startup
    # print(invoker_services.session_manager.list())

    # Defaults storage
    defaults: Dict[str, Any] = dict()

    while True:
        try:
            cmd_input = input("> ")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-c or end of input (Ctrl-d / closed stdin) exits
            break

        if cmd_input in ['exit','q']:
            break

        if cmd_input in ['--help','help','h','?']:
            parser.print_help()
            continue

        if not cmd_input.strip():
            # An empty line parses to type=None, which is not a valid invocation
            continue

        try:
            # Refresh the state of the session
            session = invoker.invoker_services.graph_execution_manager.get(session.id)
            history = list(get_graph_execution_history(session))

            # Split the command for piping
            cmds = cmd_input.split('|')
            start_id = len(history)
            current_id = start_id
            new_invocations = list()
            for cmd in cmds:
                # Parse args to create invocation
                cmd_args = vars(parser.parse_args(shlex.split(cmd.strip())))

                # Check for special commands
                # TODO: These might be better as Pydantic models, similar to the invocations
                if cmd_args['type'] == 'history':
                    history_count = cmd_args['count'] or 5
                    for i in range(min(history_count, len(history))):
                        # Walk the history list backwards from its last element
                        entry_id = history[-1 - i]
                        # get_node() results are used directly as invocations
                        # everywhere else in this function, so pass the node itself
                        entry = session.graph.get_node(entry_id)
                        print(f'{entry_id}: {get_invocation_command(entry)}')
                    continue

                if cmd_args['type'] == 'reset_default':
                    if cmd_args['input'] in defaults:
                        del defaults[cmd_args['input']]
                    continue

                if cmd_args['type'] == 'default':
                    field = cmd_args['input']
                    field_value = cmd_args['value']
                    defaults[field] = field_value
                    continue

                # Override defaults
                # NOTE(review): a stored default replaces the parsed value even
                # when the option was given explicitly on this command line -
                # confirm that is intended.
                for field_name, field_default in defaults.items():
                    if field_name in cmd_args:
                        cmd_args[field_name] = field_default

                # Parse invocation
                cmd_args['id'] = current_id
                command = InvocationCommand(invocation = cmd_args)

                # Pipe previous command output (if there was a previous command)
                edges = []
                if current_id != start_id:
                    # An invocation was already created from this input line
                    from_id = str(current_id - 1)
                    from_node = next(inv for inv, _ in new_invocations if inv.id == from_id)
                    edges.extend(generate_matching_edges(from_node, command.invocation))
                elif len(history) > 0:
                    # First invocation on the line: link from history[0]
                    from_id = history[0]
                    from_node = session.graph.get_node(from_id)
                    edges.extend(generate_matching_edges(from_node, command.invocation))

                # Parse provided links
                if 'link_node' in cmd_args and cmd_args['link_node']:
                    for link in cmd_args['link_node']:
                        link_node = session.graph.get_node(link)
                        matching_edges = generate_matching_edges(link_node, command.invocation)
                        edges.extend(matching_edges)

                if 'link' in cmd_args and cmd_args['link']:
                    for link in cmd_args['link']:
                        # Documented format: 'dest_field source_node source_field'
                        dest_field, source_node, source_field = link
                        edges.append((
                            EdgeConnection(node_id = source_node, field = source_field),
                            EdgeConnection(node_id = command.invocation.id, field = dest_field)
                        ))

                new_invocations.append((command.invocation, edges))

                current_id = current_id + 1

            # Command line was parsed successfully
            # Add the invocations to the session
            for new_invocation, new_edges in new_invocations:
                session.add_node(new_invocation)
                for edge in new_edges:
                    session.add_edge(edge)

            # Execute all available invocations
            invoker.invoke(session, invoke_all = True)
            while not session.is_complete():
                # Wait some time
                session = invoker.invoker_services.graph_execution_manager.get(session.id)
                time.sleep(0.1)

        except InvalidArgs:
            print('Invalid command, use "help" to list commands')
            continue

        except SystemExit:
            # Raised by argparse sub-parsers, which keep the default exit()
            continue

    invoker.stop()
|
|
|
|
|
|
# Start the interactive CLI when this module is executed as a script
if __name__ == "__main__":
    invoke_cli()
|