mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
304 lines
11 KiB
Python
304 lines
11 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import argparse
|
|
import shlex
|
|
import os
|
|
import time
|
|
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
|
from pydantic import BaseModel
|
|
from pydantic.fields import Field
|
|
|
|
from .services.processor import DefaultInvocationProcessor
|
|
|
|
from .services.graph import EdgeConnection, GraphExecutionState
|
|
|
|
from .services.sqlite import SqliteItemStorage
|
|
|
|
from .invocations.image import ImageField
|
|
from .services.generate_initializer import get_generate
|
|
from .services.image_storage import DiskImageStorage
|
|
from .services.invocation_queue import MemoryInvocationQueue
|
|
from .invocations.baseinvocation import BaseInvocation
|
|
from .services.invocation_services import InvocationServices
|
|
from .services.invoker import Invoker
|
|
from .invocations import *
|
|
from ..args import Args
|
|
from .services.events import EventServiceBase
|
|
|
|
|
|
class InvocationCommand(BaseModel):
    """Pydantic model wrapping a single invocation parsed from the command line.

    invoke_cli builds it as ``InvocationCommand(invocation=<parsed args dict>)``
    and then reads the validated object back from ``.invocation``.
    """

    # The candidate types are whatever BaseInvocation.get_invocations() returns
    # when this class body is evaluated; the "type" key of the incoming data is
    # the discriminator that selects which of them to validate against.
    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
|
|
|
|
|
|
class InvalidArgs(Exception):
    """Signals that the command parser rejected the text it was given.

    The replacement ``exit`` installed on the invocation parser raises this
    instead of terminating the process, and the CLI loop catches it to print
    a hint and prompt again.
    """
|
|
|
|
|
|
def _parse_cli_bool(raw: str) -> bool:
    """Convert a command-line token to a bool.

    ``bool("False")`` is True, so ``bool`` itself cannot be used as an
    argparse converter; this accepts the usual spellings and rejects the rest.
    """
    text = raw.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {raw!r}")


def _cli_converter(field_type):
    """Return the callable argparse should use to convert values of field_type."""
    return _parse_cli_bool if field_type is bool else field_type


def _literal_converter(allowed_values):
    """Build an argparse converter for a Literal whose values have several types.

    Each distinct value type is tried in first-seen order; the first conversion
    that yields one of the allowed values wins.
    """
    value_types = list(dict.fromkeys(type(value) for value in allowed_values))

    def convert(raw):
        for value_type in value_types:
            try:
                candidate = _cli_converter(value_type)(raw)
            except (TypeError, ValueError, argparse.ArgumentTypeError):
                continue
            if candidate in allowed_values:
                return candidate
        raise argparse.ArgumentTypeError(
            f"invalid choice: {raw!r} (choose from {', '.join(map(repr, allowed_values))})"
        )

    return convert


def get_invocation_parser() -> argparse.ArgumentParser:
    """Build the parser that interprets one command typed at the CLI prompt.

    The parser has a sub-command for each special command (``history``,
    ``default``, ``reset_default``) and one for each class returned by
    ``BaseInvocation.get_all_subclasses()``, named after the single value of
    that class's ``type`` Literal hint. Every model field except ``id`` and
    ``type`` becomes a ``--<field>`` option; Literal fields also get
    ``choices``. Boolean fields accept true/false style words rather than
    treating every non-empty string as True.

    Returns:
        The configured parser. Its ``exit`` is replaced on this top-level
        instance so that it raises InvalidArgs instead of ending the process.
    """
    parser = argparse.ArgumentParser()

    def raise_invalid_args(*args, **kwargs):
        raise InvalidArgs

    # Keep the interactive session alive when argparse wants to exit.
    parser.exit = raise_invalid_args

    subparsers = parser.add_subparsers(dest='type')

    # Add history parser
    history_parser = subparsers.add_parser('history', help="Shows the invocation history")
    history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")

    # Add default parser
    default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
    default_parser.add_argument('input', type=str, help="The input field")
    default_parser.add_argument('value', help="The default value")

    reset_default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
    reset_default_parser.add_argument('input', type=str, help="The input field")

    # Create subparsers for each invocation
    for invocation in BaseInvocation.get_all_subclasses():
        hints = get_type_hints(invocation)
        cmd_name = get_args(hints['type'])[0]
        command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)

        # Add linking capability. The value order in the help text matches how
        # the CLI loop builds the edge: the first value is a field on the
        # source node, the third is the field on the node being created.
        command_parser.add_argument(
            '--link', '-l', action='append', nargs=3,
            help="A link in the format 'source_field source_node dest_field'. source_node can be relative to history (e.g. -1)",
        )
        command_parser.add_argument(
            '--link_node', '-ln', action='append',
            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
        )

        # Convert all fields to arguments
        for name, field in invocation.__fields__.items():
            if name in ('id', 'type'):
                continue

            if get_origin(field.type_) == Literal:
                allowed_values = get_args(field.type_)
                value_types = list(dict.fromkeys(type(value) for value in allowed_values))
                if len(value_types) == 1:
                    converter = _cli_converter(value_types[0])
                else:
                    # A typing.Union cannot be called to convert a string, so
                    # mixed-type literals get a dedicated converter.
                    converter = _literal_converter(allowed_values)

                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=converter,
                    default=field.default,
                    choices=allowed_values,
                    help=field.field_info.description,
                )
            else:
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=_cli_converter(field.type_),
                    default=field.default,
                    help=field.field_info.description,
                )

    return parser
|
|
|
|
|
|
def get_invocation_command(invocation) -> str:
    """Render an invocation as command-line text.

    The result starts with ``invocation.type`` and is followed by a
    ``--<field> <value>`` pair for every field whose current value differs
    from the field's default. The ``id`` and ``type`` fields and any field
    whose type hint is (or contains) ImageField are left out.

    String-typed values are wrapped in double quotes. Embedded backslashes and
    double quotes are backslash-escaped so that the text can be split back into
    the same tokens with ``shlex.split``; values without those characters are
    rendered exactly as before.

    Args:
        invocation: object exposing ``type``, ``__fields__`` and one attribute
            per field.

    Returns:
        The space-joined command string.
    """
    type_hints = get_type_hints(type(invocation))
    parts = [invocation.type]

    for name, field in invocation.__fields__.items():
        if name in ('id', 'type'):
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name)
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        if field_value != field.default:
            if type_hint is str or str in get_args(type_hint):
                # Escape the backslash first so the escapes added for quotes
                # are not themselves doubled.
                escaped = str(field_value).replace('\\', '\\\\').replace('"', '\\"')
                parts.append(f'--{name} "{escaped}"')
            else:
                parts.append(f'--{name} {field_value}')

    return ' '.join(parts)
|
|
|
|
|
|
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
    """Iterate the executed history backwards, keeping only ids present in the graph.

    Walks ``executed_history`` from its last entry to its first and lazily
    yields each entry that is also a key/member of ``graph.nodes``.
    """
    in_reverse = reversed(graph_execution_state.executed_history)
    return (
        node_id
        for node_id in in_reverse
        if node_id in graph_execution_state.graph.nodes
    )
|
|
|
|
|
|
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
    """Connect every field name shared by ``a``'s output type and ``b``.

    The candidate names are the type-hinted fields of ``type(a).get_output_type()``
    that are also type-hinted on ``type(b)``, excluding ``type`` and ``id``.
    Each returned pair is (connection on ``a``, connection on ``b``) for the
    same field name.
    """
    source_cls = type(a)
    target_cls = type(b)

    output_hints = get_type_hints(source_cls.get_output_type())
    input_hints = get_type_hints(target_cls)

    # Names present on both sides, minus the bookkeeping fields.
    shared_names = set(output_hints.keys()).intersection(input_hints.keys())
    shared_names = shared_names.difference({'type', 'id'})

    connections = []
    for field_name in shared_names:
        from_output = EdgeConnection(node_id = a.id, field = field_name)
        to_input = EdgeConnection(node_id = b.id, field = field_name)
        connections.append((from_output, to_input))
    return connections
|
|
|
|
|
|
def _explicit_option_names(tokens):
    """Return the long option names (without ``--``) typed literally in tokens.

    Only full ``--name`` / ``--name=value`` spellings are recognised;
    abbreviated option names are not.
    """
    return {
        token[2:].split('=', 1)[0]
        for token in tokens
        if token.startswith('--') and len(token) > 2
    }


def invoke_cli():
    """Run the interactive invocation prompt until the user leaves it.

    Sets up the services and an Invoker, creates one execution session, then
    reads lines from ``input("> ")``. Each line is split on ``|`` into piped
    commands; every command is parsed with the invocation parser and is either
    a special command (``history``, ``default``, ``reset_default``) or becomes
    a new node in the session graph. A new node is linked to the previous
    command in the same line, or, for the first command, to the first entry of
    the session history if there is one. After a line is parsed the session is
    invoked and polled every 0.1s until it reports completion.

    The loop ends on ``exit``/``q``, Ctrl-C or end-of-input at the prompt.
    ``invoker.stop()`` is called on the way out, including when an unexpected
    exception propagates. Blank lines are ignored; unparsable or empty
    commands print a hint instead of raising.
    """
    args = Args()
    config = args.parse_args()

    generate = get_generate(args, config)

    # NOTE: load model on first use, uncomment to load at startup
    # TODO: Make this a config option?
    #generate.load_model()

    events = EventServiceBase()

    output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, 'invokeai.db')

    services = InvocationServices(
        generate = generate,
        events = events,
        images = DiskImageStorage(output_folder),
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor()
    )

    invoker = Invoker(services)
    session = invoker.create_execution_state()

    parser = get_invocation_parser()

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    # Defaults storage
    defaults: Dict[str, Any] = dict()

    try:
        while True:
            try:
                cmd_input = input("> ")
            except (KeyboardInterrupt, EOFError):
                # Ctrl-c or end of input (e.g. Ctrl-d) exits
                break

            if cmd_input in ['exit','q']:
                break

            if cmd_input in ['--help','help','h','?']:
                parser.print_help()
                continue

            if not cmd_input.strip():
                # Nothing typed: prompt again instead of building an
                # invocation from an empty argument list.
                continue

            try:
                # Refresh the state of the session
                session = invoker.services.graph_execution_manager.get(session.id)
                history = list(get_graph_execution_history(session))

                # Split the command for piping
                cmds = cmd_input.split('|')
                start_id = len(history)
                current_id = start_id
                new_invocations = list()
                for cmd in cmds:
                    # Parse args to create invocation
                    try:
                        tokens = shlex.split(cmd.strip())
                    except ValueError as err:
                        # shlex rejects e.g. an unbalanced quote
                        raise InvalidArgs from err
                    cmd_args = vars(parser.parse_args(tokens))

                    if cmd_args['type'] is None:
                        # No sub-command given (empty pipe segment)
                        raise InvalidArgs

                    # Check for special commands
                    # TODO: These might be better as Pydantic models, similar to the invocations
                    if cmd_args['type'] == 'history':
                        history_count = cmd_args['count'] or 5
                        for i in range(min(history_count, len(history))):
                            entry_id = history[-1 - i]
                            entry = session.graph.get_node(entry_id)
                            # NOTE(review): elsewhere in this function the
                            # result of get_node() is used directly as the
                            # invocation, so fall back to the node itself when
                            # it has no ``invocation`` attribute.
                            invocation = getattr(entry, 'invocation', entry)
                            print(f'{entry_id}: {get_invocation_command(invocation)}')
                        continue

                    if cmd_args['type'] == 'reset_default':
                        defaults.pop(cmd_args['input'], None)
                        continue

                    if cmd_args['type'] == 'default':
                        defaults[cmd_args['input']] = cmd_args['value']
                        continue

                    # Apply stored defaults, but never over a value that was
                    # typed explicitly on this command.
                    explicit = _explicit_option_names(tokens)
                    for field_name, field_default in defaults.items():
                        if field_name in cmd_args and field_name not in explicit:
                            cmd_args[field_name] = field_default

                    # Parse invocation
                    cmd_args['id'] = current_id
                    command = InvocationCommand(invocation = cmd_args)

                    # Pipe previous command output (if there was a previous command)
                    edges = []
                    if current_id != start_id:
                        # Previous command on this same line
                        from_id = str(current_id - 1)
                        from_node = next(inv for inv, _ in new_invocations if inv.id == from_id)
                        edges.extend(generate_matching_edges(from_node, command.invocation))
                    elif len(history) > 0:
                        # First command on the line: link from history[0]
                        from_node = session.graph.get_node(history[0])
                        edges.extend(generate_matching_edges(from_node, command.invocation))

                    # Parse provided links
                    for link in cmd_args.get('link_node') or []:
                        link_node = session.graph.get_node(link)
                        edges.extend(generate_matching_edges(link_node, command.invocation))

                    for link in cmd_args.get('link') or []:
                        # link[1] is the other node and link[0] its field;
                        # link[2] is the field on the invocation being created.
                        edges.append((
                            EdgeConnection(node_id = link[1], field = link[0]),
                            EdgeConnection(node_id = command.invocation.id, field = link[2]),
                        ))

                    new_invocations.append((command.invocation, edges))

                    current_id = current_id + 1

                # Command line was parsed successfully
                # Add the invocations to the session
                for invocation, invocation_edges in new_invocations:
                    session.add_node(invocation)
                    for edge in invocation_edges:
                        session.add_edge(edge)

                # Execute all available invocations
                invoker.invoke(session, invoke_all = True)
                while not session.is_complete():
                    # Wait some time
                    session = invoker.services.graph_execution_manager.get(session.id)
                    time.sleep(0.1)

            except InvalidArgs:
                print('Invalid command, use "help" to list commands')
                continue

            except SystemExit:
                # Raised by argparse parsers whose exit was not replaced
                continue
    finally:
        invoker.stop()
|
|
|
|
|
|
# Start the interactive prompt only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    invoke_cli()
|