mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -14,8 +14,14 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
def add_field_argument(command_parser, name: str, field, default_override=None):
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
@ -47,8 +53,8 @@ def add_parsers(
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
|
||||
):
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
@ -61,7 +67,7 @@ def add_parsers(
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = command.__fields__ # type: ignore
|
||||
fields = command.__fields__ # type: ignore
|
||||
for name, field in fields.items():
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
@ -70,13 +76,11 @@ def add_parsers(
|
||||
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
@ -128,6 +132,7 @@ class CliContext:
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -155,7 +160,7 @@ class BaseCommand(ABC, BaseModel):
|
||||
@classmethod
|
||||
def get_commands_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
|
||||
|
||||
@abstractmethod
|
||||
def run(self, context: CliContext) -> None:
|
||||
@ -165,7 +170,8 @@ class BaseCommand(ABC, BaseModel):
|
||||
|
||||
class ExitCommand(BaseCommand):
|
||||
"""Exits the CLI"""
|
||||
type: Literal['exit'] = 'exit'
|
||||
|
||||
type: Literal["exit"] = "exit"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
raise ExitCli()
|
||||
@ -173,7 +179,8 @@ class ExitCommand(BaseCommand):
|
||||
|
||||
class HelpCommand(BaseCommand):
|
||||
"""Shows help"""
|
||||
type: Literal['help'] = 'help'
|
||||
|
||||
type: Literal["help"] = "help"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
context.parser.print_help()
|
||||
@ -183,11 +190,7 @@ def get_graph_execution_history(
|
||||
graph_execution_state: GraphExecutionState,
|
||||
) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (
|
||||
n
|
||||
for n in reversed(graph_execution_state.executed_history)
|
||||
if n in graph_execution_state.graph.nodes
|
||||
)
|
||||
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
@ -218,7 +221,8 @@ def get_invocation_command(invocation) -> str:
|
||||
|
||||
class HistoryCommand(BaseCommand):
|
||||
"""Shows the invocation history"""
|
||||
type: Literal['history'] = 'history'
|
||||
|
||||
type: Literal["history"] = "history"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
@ -235,7 +239,8 @@ class HistoryCommand(BaseCommand):
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
"""Sets a default value for a field"""
|
||||
type: Literal['default'] = 'default'
|
||||
|
||||
type: Literal["default"] = "default"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
@ -253,7 +258,8 @@ class SetDefaultCommand(BaseCommand):
|
||||
|
||||
class DrawGraphCommand(BaseCommand):
|
||||
"""Debugs a graph"""
|
||||
type: Literal['draw_graph'] = 'draw_graph'
|
||||
|
||||
type: Literal["draw_graph"] = "draw_graph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
@ -271,7 +277,8 @@ class DrawGraphCommand(BaseCommand):
|
||||
|
||||
class DrawExecutionGraphCommand(BaseCommand):
|
||||
"""Debugs an execution graph"""
|
||||
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||
|
||||
type: Literal["draw_xgraph"] = "draw_xgraph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
@ -286,6 +293,7 @@ class DrawExecutionGraphCommand(BaseCommand):
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
|
@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
|
||||
|
||||
class Completer(object):
|
||||
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
@ -43,7 +43,7 @@ class Completer(object):
|
||||
except IndexError:
|
||||
pass
|
||||
options = options or list(self.parse_commands().keys())
|
||||
|
||||
|
||||
if not text: # first time
|
||||
self.matches = options
|
||||
else:
|
||||
@ -56,17 +56,17 @@ class Completer(object):
|
||||
return match
|
||||
|
||||
@classmethod
|
||||
def get_commands(self)->List[object]:
|
||||
def get_commands(self) -> List[object]:
|
||||
"""
|
||||
Return a list of all the client commands and invocations.
|
||||
"""
|
||||
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||
|
||||
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||
def get_current_command(self, buffer: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse the readline buffer to find the most recent command and its switch.
|
||||
"""
|
||||
if len(buffer)==0:
|
||||
if len(buffer) == 0:
|
||||
return None, None
|
||||
tokens = shlex.split(buffer)
|
||||
command = None
|
||||
@ -78,11 +78,11 @@ class Completer(object):
|
||||
else:
|
||||
switch = t
|
||||
# don't try to autocomplete switches that are already complete
|
||||
if switch and buffer.endswith(' '):
|
||||
switch=None
|
||||
return command or '', switch or ''
|
||||
if switch and buffer.endswith(" "):
|
||||
switch = None
|
||||
return command or "", switch or ""
|
||||
|
||||
def parse_commands(self)->Dict[str, List[str]]:
|
||||
def parse_commands(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict in which the keys are the command name
|
||||
and the values are the parameters the command takes.
|
||||
@ -90,11 +90,11 @@ class Completer(object):
|
||||
result = dict()
|
||||
for command in self.commands:
|
||||
hints = get_type_hints(command)
|
||||
name = get_args(hints['type'])[0]
|
||||
result.update({name:hints})
|
||||
name = get_args(hints["type"])[0]
|
||||
result.update({name: hints})
|
||||
return result
|
||||
|
||||
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||
def get_command_options(self, command: str, switch: str) -> List[str]:
|
||||
"""
|
||||
Return all the parameters that can be passed to the command as
|
||||
command-line switches. Returns None if the command is unrecognized.
|
||||
@ -102,42 +102,46 @@ class Completer(object):
|
||||
parsed_commands = self.parse_commands()
|
||||
if command not in parsed_commands:
|
||||
return None
|
||||
|
||||
|
||||
# handle switches in the format "-foo=bar"
|
||||
argument = None
|
||||
if switch and '=' in switch:
|
||||
switch, argument = switch.split('=')
|
||||
|
||||
parameter = switch.strip('-')
|
||||
if switch and "=" in switch:
|
||||
switch, argument = switch.split("=")
|
||||
|
||||
parameter = switch.strip("-")
|
||||
if parameter in parsed_commands[command]:
|
||||
if argument is None:
|
||||
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
else:
|
||||
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||
return [
|
||||
f"--{parameter}={x}"
|
||||
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
]
|
||||
else:
|
||||
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||
|
||||
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
|
||||
"""
|
||||
Given a parameter type (such as Literal), offers autocompletions.
|
||||
"""
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == 'model':
|
||||
if parameter == "model":
|
||||
return self.manager.model_names()
|
||||
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
|
||||
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
@ -162,8 +166,6 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
Reference in New Issue
Block a user