InvokeAI/ldm/invoke/app/cli_app.py
Kyle Schouviller 357601e2d6
parent 9eed1919c2
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800
committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800

Adding base node architecture

Fix type annotation errors

Runs and generates, but breaks in saving session

Fix default model value setting. Fix deprecation warning.

Fixed node api

Adding markdown docs

Simplifying Generate construction in apps

[nodes] A few minor changes (#2510)

* Pin api-related requirements

* Remove confusing extra CORS origins list

* Adds response models for HTTP 200

[nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest.

Minor typing fixes

[nodes] Fix some small output query hookups

[node] Fixing some additional typing issues

[nodes] Move and expand graph code. Add base item storage and sqlite implementation.

Update startup to match new code

[nodes] Add callbacks to item storage

[nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility

[nodes] New execution model that handles iteration

[nodes] Fixing the CLI

[nodes] Adding a note to the CLI

[nodes] Split processing thread into separate service

[node] Add error message on node processing failure

Removing old files and duplicated packages

Adding python-multipart
2023-02-26 21:28:00 +01:00

307 lines
11 KiB
Python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse
import shlex
import os
import time
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel
from pydantic.fields import Field
from .services.processor import DefaultInvocationProcessor
from .services.graph import EdgeConnection, GraphExecutionState
from .services.sqlite import SqliteItemStorage
from .invocations.image import ImageField
from .services.generate_initializer import get_generate
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .invocations.baseinvocation import BaseInvocation
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker, InvokerServices
from .invocations import *
from ..args import Args
from .services.events import EventServiceBase
class InvocationCommand(BaseModel):
    """Pydantic wrapper that turns a dict of parsed CLI arguments into an invocation.

    The ``invocation`` field is a union discriminated on the ``type`` key, so
    the ``type`` value of the supplied data selects which invocation model is
    instantiated.
    """
    # The union members are whatever BaseInvocation.get_invocations() returns
    # (presumably every known invocation class; it is defined elsewhere - confirm).
    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
class InvalidArgs(Exception):
    """Signals that a CLI command line could not be parsed into a command."""
def get_invocation_parser() -> argparse.ArgumentParser:
    """Build the argument parser used by the interactive invocation CLI.

    The parser has one sub-command per special command (``history``,
    ``default``, ``reset_default``) and one per class returned by
    ``BaseInvocation.get_all_subclasses()``. The chosen sub-command name is
    stored under ``type``. Every invocation sub-command accepts ``--link`` /
    ``--link_node`` plus one ``--<field>`` option per model field (except
    ``id`` and ``type``).

    Returns:
        The configured top-level parser. Its ``exit`` is replaced so that a
        parse failure raises ``InvalidArgs`` instead of terminating the
        process; the sub-parsers keep argparse's normal exit behaviour.

    Raises:
        InvalidArgs: raised later, from ``parse_args``, when the top-level
            parser would otherwise exit.
    """
    def _raise_invalid_args(*args, **kwargs):
        # Stand-in for ArgumentParser.exit: keep the interactive loop alive.
        raise InvalidArgs

    def _parse_bool(raw):
        # bool("False") is True, so bool itself is unusable as an argparse type.
        lowered = str(raw).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {raw!r}")

    def _converter_for(value_type):
        # Pick the callable argparse should use to convert a raw string.
        return _parse_bool if value_type is bool else value_type

    def _literal_converter(allowed_values, value_types):
        # Converter for Literal fields whose allowed values span several types:
        # try each type in turn and keep the first result that is an allowed value.
        def convert(raw):
            for value_type in value_types:
                try:
                    candidate = _converter_for(value_type)(raw)
                except (TypeError, ValueError, argparse.ArgumentTypeError):
                    continue
                if candidate in allowed_values:
                    return candidate
            raise argparse.ArgumentTypeError(
                f"invalid value {raw!r} (choose from {', '.join(map(repr, allowed_values))})"
            )
        return convert

    # Create invocation parser
    parser = argparse.ArgumentParser()
    parser.exit = _raise_invalid_args

    subparsers = parser.add_subparsers(dest='type')

    # Add history parser
    history_parser = subparsers.add_parser('history', help="Shows the invocation history")
    history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")

    # Add default parser
    default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
    default_parser.add_argument('input', type=str, help="The input field")
    default_parser.add_argument('value', help="The default value")

    # Add reset_default parser
    reset_parser = subparsers.add_parser('reset_default', help="Resets a default value")
    reset_parser.add_argument('input', type=str, help="The input field")

    # Create subparsers for each invocation
    for invocation in BaseInvocation.get_all_subclasses():
        hints = get_type_hints(invocation)
        # The command name is the single value of the Literal annotating `type`.
        cmd_name = get_args(hints['type'])[0]
        command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)

        # Add linking capability
        command_parser.add_argument('--link', '-l', action='append', nargs=3,
            help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)")

        command_parser.add_argument('--link_node', '-ln', action='append',
            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")

        # Convert all fields to arguments
        for name, field in invocation.__fields__.items():
            if name in ('id', 'type'):
                continue

            if get_origin(field.type_) == Literal:
                allowed_values = get_args(field.type_)
                # Distinct value types, in first-seen order so behaviour is deterministic.
                value_types = list(dict.fromkeys(type(val) for val in allowed_values))
                if len(value_types) == 1:
                    field_type = _converter_for(value_types[0])
                else:
                    field_type = _literal_converter(allowed_values, value_types)
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field_type,
                    default=field.default,
                    choices=allowed_values,
                    help=field.field_info.description
                )
            else:
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=_converter_for(field.type_),
                    default=field.default,
                    help=field.field_info.description
                )

    return parser
def get_invocation_command(invocation) -> str:
    """Render an invocation as the CLI command line that would recreate it.

    Only fields whose current value differs from the field default are
    emitted. ``id`` and ``type`` are never emitted as options (``type`` is the
    leading command word), and fields typed as ``ImageField`` are skipped.

    Args:
        invocation: a model instance exposing ``type``, ``__fields__`` and one
            attribute per field.

    Returns:
        The command word followed by ``--name value`` pairs, space separated.
        String-typed values are wrapped in double quotes, with embedded
        backslashes and double quotes backslash-escaped so the result can be
        split again with ``shlex``.
    """
    type_hints = get_type_hints(type(invocation))
    command = [invocation.type]
    for name, field in invocation.__fields__.items():
        if name in ('id', 'type'):
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name)
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        if field_value == field.default:
            continue

        if type_hint is str or str in get_args(type_hint):
            # Escape backslash first so the escapes added for quotes stay intact.
            escaped = str(field_value).replace('\\', '\\\\').replace('"', '\\"')
            command.append(f'--{name} "{escaped}"')
        else:
            command.append(f'--{name} {field_value}')

    return ' '.join(command)
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
    """Lazily yield executed node ids, most recently executed first.

    Only ids that are present in the state's graph nodes are produced.
    """
    def _is_graph_node(node_id) -> bool:
        # Looked up on every call so the membership test stays lazy.
        return node_id in graph_execution_state.graph.nodes

    newest_first = reversed(graph_execution_state.executed_history)
    return filter(_is_graph_node, newest_first)
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
    """Build an edge for every field name shared by a's output and b's inputs.

    Field names are taken from the type hints of ``a``'s output type and of
    ``b``'s own type; ``type`` and ``id`` are never linked. Each edge is a
    ``(connection on a, connection on b)`` pair using the same field name on
    both sides.
    """
    source_type = type(a)
    target_type = type(b)

    output_hints = get_type_hints(source_type.get_output_type())
    input_hints = get_type_hints(target_type)

    # Names present on both sides, minus the bookkeeping fields
    shared_names = set(output_hints).intersection(input_hints)
    shared_names = shared_names.difference({'type', 'id'})

    edges = []
    for field_name in shared_names:
        source = EdgeConnection(node_id = a.id, field = field_name)
        destination = EdgeConnection(node_id = b.id, field = field_name)
        edges.append((source, destination))
    return edges
def _run_cli_loop(invoker) -> None:
    """Read commands from stdin and execute them until the user exits.

    Each input line may hold several commands separated by ``|``; every
    invocation after the first in a line is linked to the previous one, and
    the first is linked to the most recent history entry, if any.
    """
    session = invoker.create_execution_state()
    parser = get_invocation_parser()

    # Uncomment to print out previous sessions at startup
    # print(invoker_services.session_manager.list())

    # Defaults storage
    defaults: Dict[str, Any] = dict()

    while True:
        try:
            cmd_input = input("> ")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-c or end of input (Ctrl-d / closed stdin) exits
            break

        if cmd_input in ['exit','q']:
            break

        if cmd_input in ['--help','help','h','?']:
            parser.print_help()
            continue

        # Nothing to do for a blank line
        if not cmd_input.strip():
            continue

        try:
            # Refresh the state of the session
            session = invoker.invoker_services.graph_execution_manager.get(session.id)
            history = list(get_graph_execution_history(session))

            # Split the command for piping
            cmds = cmd_input.split('|')
            start_id = len(history)
            current_id = start_id
            new_invocations = list()
            for cmd in cmds:
                # Parse args to create invocation
                try:
                    tokens = shlex.split(cmd.strip())
                except ValueError as err:
                    # e.g. unbalanced quotes
                    raise InvalidArgs from err
                parsed = vars(parser.parse_args(tokens))

                # An empty pipe segment selects no sub-command at all
                if parsed['type'] is None:
                    raise InvalidArgs

                # Check for special commands
                # TODO: These might be better as Pydantic models, similar to the invocations
                if parsed['type'] == 'history':
                    history_count = parsed['count'] or 5
                    for i in range(min(history_count, len(history))):
                        entry_id = history[-1 - i]
                        entry = session.graph.get_node(entry_id)
                        print(f'{entry_id}: {get_invocation_command(entry.invocation)}')
                    continue

                if parsed['type'] == 'reset_default':
                    if parsed['input'] in defaults:
                        del defaults[parsed['input']]
                    continue

                if parsed['type'] == 'default':
                    field = parsed['input']
                    field_value = parsed['value']
                    defaults[field] = field_value
                    continue

                # Override defaults
                for field_name,field_default in defaults.items():
                    if field_name in parsed:
                        parsed[field_name] = field_default

                # Parse invocation
                parsed['id'] = current_id
                command = InvocationCommand(invocation = parsed)

                # Pipe previous command output (if there was a previous command)
                edges = []
                if len(history) > 0 or current_id != start_id:
                    from_id = history[0] if current_id == start_id else str(current_id - 1)
                    from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id)
                    matching_edges = generate_matching_edges(from_node, command.invocation)
                    edges.extend(matching_edges)

                # Parse provided links
                if parsed.get('link_node'):
                    for link in parsed['link_node']:
                        link_node = session.graph.get_node(link)
                        matching_edges = generate_matching_edges(link_node, command.invocation)
                        edges.extend(matching_edges)

                if parsed.get('link'):
                    for link in parsed['link']:
                        # NOTE(review): the --link help text reads 'dest_field source_node
                        # source_field', yet link[0] is used as the field on the source node
                        # and link[2] as the field on this invocation - confirm which is intended.
                        edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2])))

                new_invocations.append((command.invocation, edges))

                current_id = current_id + 1

            # Command line was parsed successfully
            # Add the invocations to the session
            for invocation in new_invocations:
                session.add_node(invocation[0])
                for edge in invocation[1]:
                    session.add_edge(edge)

            # Execute all available invocations
            invoker.invoke(session, invoke_all = True)
            while not session.is_complete():
                # Wait some time
                session = invoker.invoker_services.graph_execution_manager.get(session.id)
                time.sleep(0.1)

        except InvalidArgs:
            print('Invalid command, use "help" to list commands')
            continue

        except SystemExit:
            # Raised by argparse sub-parsers (e.g. on --help or a bad option)
            continue


def invoke_cli():
    """Start the interactive node CLI.

    Builds the generation, event, image-storage, queue, persistence and
    processor services, creates an ``Invoker`` from them, then runs the
    command loop. The invoker is stopped when the loop ends, including when
    it ends because of an unexpected exception.
    """
    args = Args()
    config = args.parse_args()

    generate = get_generate(args, config)

    # NOTE: load model on first use, uncomment to load at startup
    # TODO: Make this a config option?
    #generate.load_model()

    events = EventServiceBase()

    output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))

    services = InvocationServices(
        generate = generate,
        events = events,
        images = DiskImageStorage(output_folder)
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, 'invokeai.db')

    invoker_services = InvokerServices(
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor()
    )

    invoker = Invoker(services, invoker_services)
    try:
        _run_cli_loop(invoker)
    finally:
        # Always shut the invoker down, even if the loop raised
        invoker.stop()
# Start the interactive CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    invoke_cli()