2022-12-01 05:33:20 +00:00
|
|
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import os
|
2023-03-03 06:02:00 +00:00
|
|
|
import shlex
|
2022-12-01 05:33:20 +00:00
|
|
|
import time
|
2023-03-03 06:02:00 +00:00
|
|
|
from typing import (
|
|
|
|
Union,
|
|
|
|
get_type_hints,
|
|
|
|
)
|
|
|
|
|
2022-12-01 05:33:20 +00:00
|
|
|
from pydantic import BaseModel
|
|
|
|
from pydantic.fields import Field
|
|
|
|
|
2023-03-04 01:19:37 +00:00
|
|
|
from ..backend import Args
|
2023-03-04 22:46:02 +00:00
|
|
|
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
2023-03-03 06:02:00 +00:00
|
|
|
from .invocations import *
|
|
|
|
from .invocations.baseinvocation import BaseInvocation
|
|
|
|
from .services.events import EventServiceBase
|
2023-03-11 16:32:57 +00:00
|
|
|
from .services.model_manager_initializer import get_model_manager
|
2023-03-11 22:00:00 +00:00
|
|
|
from .services.restoration_services import RestorationServices
|
2023-03-15 06:09:30 +00:00
|
|
|
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
2022-12-01 05:33:20 +00:00
|
|
|
from .services.image_storage import DiskImageStorage
|
|
|
|
from .services.invocation_queue import MemoryInvocationQueue
|
|
|
|
from .services.invocation_services import InvocationServices
|
2023-03-03 06:02:00 +00:00
|
|
|
from .services.invoker import Invoker
|
|
|
|
from .services.processor import DefaultInvocationProcessor
|
|
|
|
from .services.sqlite import SqliteItemStorage
|
2022-12-01 05:33:20 +00:00
|
|
|
|
|
|
|
|
2023-03-04 22:46:02 +00:00
|
|
|
class CliCommand(BaseModel):
    """Pydantic wrapper that turns one parsed CLI segment into a concrete command.

    ``command`` is a union over the classes returned by
    ``BaseCommand.get_commands()`` and ``BaseInvocation.get_invocations()``;
    pydantic selects the concrete class from the ``type`` key of the input.
    """

    command: Union[tuple([*BaseCommand.get_commands(), *BaseInvocation.get_invocations()])] = Field(discriminator="type")  # type: ignore
|
2022-12-01 05:33:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
class InvalidArgs(Exception):
    """Signals that a line typed at the CLI could not be turned into a command.

    Raised in place of argparse's process exit on the top-level parser, and
    for empty segments in a ``|``-piped command line.
    """
|
|
|
|
|
|
|
|
|
2023-03-04 22:46:02 +00:00
|
|
|
def add_invocation_args(command_parser):
    """Add the graph-linking options shared by every invocation subcommand.

    ``--link``/``-l`` may be repeated; each occurrence takes exactly three
    values and is collected into a list of 3-item lists. ``--link_node``/``-ln``
    may be repeated; each occurrence takes one node id and is collected into
    a list of strings.

    Args:
        command_parser: the argparse (sub)parser for a single invocation type.
    """
    # Add linking capability
    command_parser.add_argument(
        "--link",
        "-l",
        action="append",
        nargs=3,
        # The value order must match how the CLI loop consumes each link:
        # link[0] is the source field, link[1] the source node id and
        # link[2] the destination field on the new node.
        help="A link in the format 'source_field source_node dest_field'. source_node can be relative to history (e.g. -1)",
    )

    command_parser.add_argument(
        "--link_node",
        "-ln",
        action="append",
        help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def get_command_parser() -> argparse.ArgumentParser:
    """Build the parser used for each ``|``-separated segment of CLI input.

    Every class from ``BaseInvocation.get_all_subclasses()`` and
    ``BaseCommand.get_all_subclasses()`` becomes a subcommand; the chosen
    subcommand name is stored under ``type``. Invocation subcommands also get
    the linking options from ``add_invocation_args``; command subcommands are
    built with their ``type`` field excluded.

    Returns:
        The configured top-level ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser()

    def _raise_invalid(*args, **kwargs):
        raise InvalidArgs

    # argparse calls parser.exit() on errors/--help; on the top-level parser
    # turn that into an exception the CLI loop can recover from.
    root.exit = _raise_invalid

    subparsers = root.add_subparsers(dest="type")

    # One subcommand per invocation class, each with linking options
    add_parsers(
        subparsers,
        BaseInvocation.get_all_subclasses(),
        add_arguments=add_invocation_args,
    )

    # One subcommand per CLI command class
    add_parsers(
        subparsers,
        BaseCommand.get_all_subclasses(),
        exclude_fields=["type"],
    )

    return root
|
|
|
|
|
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
def generate_matching_edges(
    a: BaseInvocation, b: BaseInvocation
) -> list[Edge]:
    """Generates all possible edges between two invocations.

    One edge is produced for every field name that appears both in the type
    hints of ``a``'s output type and in the type hints of ``b``'s class,
    excluding ``type`` and ``id``. Each edge connects field ``<name>`` of node
    ``a.id`` to field ``<name>`` of node ``b.id``. Only names are compared;
    field types are not checked here.

    Args:
        a: the upstream invocation whose outputs feed the edges.
        b: the downstream invocation whose inputs receive them.

    Returns:
        The edges, ordered by field name.
    """
    output_fields = get_type_hints(type(a).get_output_type())
    input_fields = get_type_hints(type(b))

    # Remove invalid fields
    excluded = {"type", "id"}
    shared = (output_fields.keys() & input_fields.keys()) - excluded

    # Sort so the edge order (and anything printed from it) does not depend
    # on string-hash randomization of set iteration.
    return [
        Edge(
            source=EdgeConnection(node_id=a.id, field=field),
            destination=EdgeConnection(node_id=b.id, field=field),
        )
        for field in sorted(shared)
    ]
|
|
|
|
|
|
|
|
|
2023-03-09 03:25:03 +00:00
|
|
|
class SessionError(Exception):
    """Raised when a session error has occurred.

    ``invoke_all`` raises it after printing the errors of a session that
    finished with failed nodes.
    """
|
|
|
|
|
|
|
|
|
|
|
|
def invoke_all(context: CliContext):
    """Run every pending node of the context's session and wait for it.

    Starts execution through ``context.invoker`` with ``invoke_all=True``,
    then polls ``context.get_session()`` every 0.1 s until it reports
    completion. If the session has errors afterwards, each failing node is
    printed with its source node and ``SessionError`` is raised.

    Args:
        context: CLI state holding the invoker and the current session.

    Raises:
        SessionError: if the finished session reports any error.
    """
    context.invoker.invoke(context.session, invoke_all=True)

    # Block until execution finishes.
    # NOTE(review): the checks below read ``context.session``; this assumes
    # ``get_session()`` refreshes that attribute — confirm in CliContext.
    while True:
        if context.get_session().is_complete():
            break
        time.sleep(0.1)

    if not context.session.has_error():
        return

    failures = context.session.errors
    for node_id in failures:
        source_id = context.session.prepared_source_mapping[node_id]
        print(
            f"Error in node {node_id} (source node {source_id}): {failures[node_id]}"
        )

    raise SessionError()
|
|
|
|
|
|
|
|
|
2022-12-01 05:33:20 +00:00
|
|
|
def invoke_cli():
    """Run the interactive node-graph command line.

    Parses backend configuration, wires up the invocation services, then
    reads commands from stdin in a loop. A line may chain several commands
    with ``|``. Each invocation in the chain is linked to the previous node
    by matching field names (``generate_matching_edges``), plus any explicit
    ``--link`` / ``--link_node`` edges. CLI commands (``BaseCommand``
    instances) run immediately, after all pending nodes have been invoked.

    The loop ends on Ctrl-C, end of input (Ctrl-D / closed stdin), or when a
    command raises ``ExitCli``. The invoker is stopped on every exit path.
    """
    # Local import keeps the module's import block unchanged; pydantic is
    # already a dependency of this file.
    from pydantic import ValidationError

    config = Args()
    config.parse_args()
    model_manager = get_model_manager(config)

    events = EventServiceBase()

    output_folder = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../outputs")
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, "invokeai.db")

    services = InvocationServices(
        model_manager=model_manager,
        events=events,
        images=DiskImageStorage(output_folder),
        queue=MemoryInvocationQueue(),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
        ),
        processor=DefaultInvocationProcessor(),
        restoration=RestorationServices(config),
    )

    invoker = Invoker(services)
    session: GraphExecutionState = invoker.create_execution_state()
    parser = get_command_parser()

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    context = CliContext(invoker, session, parser)

    try:
        while True:
            try:
                cmd_input = input("> ")
            except (KeyboardInterrupt, EOFError):
                # Ctrl-C, or end of input (Ctrl-D / closed stdin), exits
                break

            try:
                # Refresh the state of the session
                history = list(get_graph_execution_history(context.session))

                # Split the command for piping
                cmds = cmd_input.split("|")
                start_id = len(history)
                current_id = start_id
                new_invocations = list()
                for cmd in cmds:
                    if cmd is None or cmd.strip() == "":
                        raise InvalidArgs("Empty command")

                    # Parse args to create invocation
                    args = vars(context.parser.parse_args(shlex.split(cmd.strip())))

                    # Override defaults
                    # NOTE(review): this replaces the parsed value whenever the
                    # key exists, including values typed explicitly on the
                    # command line — confirm that is intended.
                    for field_name, field_default in context.defaults.items():
                        if field_name in args:
                            args[field_name] = field_default

                    # Parse invocation
                    args["id"] = current_id
                    command = CliCommand(command=args)

                    # Run any CLI commands immediately
                    if isinstance(command.command, BaseCommand):
                        # Invoke all current nodes to preserve operation order
                        invoke_all(context)

                        # Run the command
                        command.command.run(context)
                        continue

                    # Pipe previous command output (if there was a previous command)
                    edges: list[Edge] = list()
                    if len(history) > 0 or current_id != start_id:
                        # First segment links from the newest history entry;
                        # later segments link from the node created just before.
                        from_id = (
                            history[0] if current_id == start_id else str(current_id - 1)
                        )
                        from_node = (
                            next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
                            if current_id != start_id
                            else context.session.graph.get_node(from_id)
                        )
                        edges.extend(generate_matching_edges(from_node, command.command))

                    # Parse provided links
                    # NOTE(review): node ids from --link/--link_node are used
                    # verbatim; nothing here resolves relative ids such as -1
                    # that the option help mentions — confirm get_node does.
                    if "link_node" in args and args["link_node"]:
                        for link in args["link_node"]:
                            link_node = context.session.graph.get_node(link)
                            matching_edges = generate_matching_edges(
                                link_node, command.command
                            )
                            # Explicit node links replace piped edges that
                            # target the same destination.
                            matching_destinations = [e.destination for e in matching_edges]
                            edges = [
                                e for e in edges
                                if e.destination not in matching_destinations
                            ]
                            edges.extend(matching_edges)

                    if "link" in args and args["link"]:
                        for link in args["link"]:
                            # Drop only an edge that already targets this
                            # node's link[2] field, so the explicit link
                            # replaces it; edges into other fields are kept.
                            edges = [
                                e for e in edges
                                if e.destination.node_id != command.command.id
                                or e.destination.field != link[2]
                            ]
                            edges.append(
                                Edge(
                                    source=EdgeConnection(node_id=link[1], field=link[0]),
                                    destination=EdgeConnection(
                                        node_id=command.command.id, field=link[2]
                                    ),
                                )
                            )

                    new_invocations.append((command.command, edges))

                    current_id += 1

                    # Add the node to the session
                    context.session.add_node(command.command)
                    for edge in edges:
                        print(edge)
                        context.session.add_edge(edge)

                # Execute all remaining nodes
                invoke_all(context)

            except InvalidArgs:
                print('Invalid command, use "help" to list commands')
                continue

            except ValidationError:
                # Arguments parsed but did not validate against the command model
                print('Invalid command arguments, run "<command> --help" for summary')
                continue

            except SessionError:
                # Start a new session
                print("Session error: creating a new session")
                context.session = context.invoker.create_execution_state()

            except ExitCli:
                break

            except SystemExit:
                # Raised by subcommand parsers (e.g. on --help or a bad option)
                continue
    finally:
        # Stop the invoker even if an unexpected exception escapes the loop
        invoker.stop()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when this module is executed directly as a script.
    invoke_cli()
|