Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@ -13,6 +13,7 @@ from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
@ -20,7 +21,7 @@ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before other invokeai initialization messages
if config.version:
print(f'InvokeAI version {__version__}')
print(f"InvokeAI version {__version__}")
sys.exit(0)
from invokeai.app.services.board_image_record_storage import (
@ -36,18 +37,21 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph,
are_connection_types_compatible)
from .services.graph import (
Edge,
EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@ -58,6 +62,7 @@ from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
@ -69,6 +74,7 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception):
pass
def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
@ -113,7 +119,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
return parser
class NodeField():
class NodeField:
alias: str
node_path: str
field: str
@ -126,15 +132,20 @@ class NodeField():
self.field_type = field_type
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
return NodeField(
alias=exposed_input.alias,
node_path=f"{node_id}.{exposed_input.node_path}",
field=exposed_input.field,
field_type=get_type_hints(node_type)[exposed_input.field],
)
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
@ -142,7 +153,12 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
return NodeField(
alias=exposed_output.alias,
node_path=f"{node_id}.{exposed_output.node_path}",
field=exposed_output.field,
field_type=get_type_hints(node_output_type)[exposed_output.field],
)
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
@ -165,9 +181,7 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
"""Generates all possible edges between two invocations"""
afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context)
@ -179,12 +193,14 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
matching_fields = [
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
]
edges = [
Edge(
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
)
for alias in matching_fields
]
@ -193,6 +209,7 @@ def generate_matching_edges(
class SessionError(Exception):
"""Raised when a session error has occurred"""
pass
@ -209,22 +226,23 @@ def invoke_all(context: CliContext):
context.invoker.services.logger.error(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
raise SessionError()
def invoke_cli():
logger.info(f'InvokeAI version {__version__}')
logger.info(f"InvokeAI version {__version__}")
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')
parser.add_argument("commands", nargs="*")
invocation_commands = parser.parse_args().commands
# get the optional file to read commands from.
# Simplest is to use it for STDIN
if infile := config.from_file:
sys.stdin = open(infile,"r")
model_manager = ModelManagerService(config,logger)
sys.stdin = open(infile, "r")
model_manager = ModelManagerService(config, logger)
events = EventServiceBase()
output_folder = config.output_path
@ -234,13 +252,13 @@ def invoke_cli():
db_location = ":memory:"
else:
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
db_location.parent.mkdir(parents=True, exist_ok=True)
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
@ -281,24 +299,21 @@ def invoke_cli():
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
@ -308,7 +323,7 @@ def invoke_cli():
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser(services)
re_negid = re.compile('^-[0-9]+$')
re_negid = re.compile("^-[0-9]+$")
# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
@ -318,7 +333,7 @@ def invoke_cli():
command_line_args_exist = len(invocation_commands) > 0
done = False
while not done:
try:
if command_line_args_exist:
@ -332,7 +347,7 @@ def invoke_cli():
try:
# Refresh the state of the session
#history = list(get_graph_execution_history(context.session))
# history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))
# Split the command for piping
@ -353,17 +368,17 @@ def invoke_cli():
args[field_name] = field_default
# Parse invocation
command: CliCommand = None # type:ignore
command: CliCommand = None # type:ignore
system_graph: Optional[LibraryGraph] = None
if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
if args["type"] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command = invocation)
command = CliCommand(command=invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
@ -385,17 +400,13 @@ def invoke_cli():
# Pipe previous command output (if there was a previous command)
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
from_id = (
history[0] if current_id == start_id else str(current_id - 1)
)
from_id = history[0] if current_id == start_id else str(current_id - 1)
from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id
else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.command, context
)
matching_edges = generate_matching_edges(from_node, command.command, context)
edges.extend(matching_edges)
# Parse provided links
@ -406,16 +417,18 @@ def invoke_cli():
node_id = str(current_id + int(node_id))
link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges(
link_node, command.command, context
)
matching_edges = generate_matching_edges(link_node, command.command, context)
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges)
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
edges = [
e
for e in edges
if e.destination.node_id != command.command.id or e.destination.field != link[2]
]
node_id = link[0]
if re_negid.match(node_id):
@ -428,7 +441,7 @@ def invoke_cli():
edges.append(
Edge(
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
)
)