mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
c238a7f18b
Upgrade pydantic and fastapi to latest. - pydantic~=2.4.2 - fastapi~=103.2 - fastapi-events~=0.9.1 **Big Changes** There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes. **Invocations** The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie. Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`. With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation. This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner. **Invocation Fields** In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations. **Invocation Decorators** With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper. A similar technique is used for `invocation_output()`. **Minor Changes** There are a number of minor changes around the pydantic v2 models API. **Protected `model_` Namespace** All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. 
This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_". Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple. ```py class IPAdapterModelField(BaseModel): model_name: str = Field(description="Name of the IP-Adapter model") base_model: BaseModelType = Field(description="Base model") model_config = ConfigDict(protected_namespaces=()) ``` **Model Serialization** Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`. **Model Deserialization** Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model. ```py adapter_graph = TypeAdapter(Graph) deserialized_graph_from_json = adapter_graph.validate_json(graph_json) deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict) ``` **Field Customisation** Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field. **Schema Customisation** FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes: - Our schema customization logic has been revised - Schema parsing to build node templates has been revised The specifics aren't important, but this does present additional surface area for bugs. **Performance Improvements** Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). 
We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
313 lines
9.9 KiB
Python
313 lines
9.9 KiB
Python
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import argparse
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
|
|
|
import matplotlib.pyplot as plt
|
|
import networkx as nx
|
|
from pydantic import BaseModel, Field
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
|
|
from ..invocations.baseinvocation import BaseInvocation
|
|
from ..invocations.image import ImageField
|
|
from ..services.graph import Edge, GraphExecutionState, LibraryGraph
|
|
from ..services.invoker import Invoker
|
|
|
|
|
|
def add_field_argument(command_parser, name: str, field, default_override=None):
    """Register a ``--<name>`` option on ``command_parser`` for one model field.

    Args:
        command_parser: The argparse parser (or sub-parser) to add the option to.
        name: Option name; used both for the ``--<name>`` flag and as ``dest``.
        field: Field description object. Its ``default``, ``default_factory``,
            ``annotation`` and ``description`` attributes are read.
        default_override: When not ``None``, used as the option default instead
            of the field's own default.

    For ``Literal[...]`` annotations the literal values become the argparse
    ``choices``. If all values share one type, that type converts the raw
    command-line text. If the values have different types, a small converter
    maps the text onto one of the literal values instead.
    """
    # Precedence: explicit override, then the plain default, then the factory.
    if default_override is not None:
        default = default_override
    elif field.default_factory is None:
        default = field.default
    else:
        default = field.default_factory()

    if get_origin(field.annotation) == Literal:
        allowed_values = get_args(field.annotation)
        # De-duplicate the value types while keeping declaration order.
        allowed_types = list(dict.fromkeys(type(val) for val in allowed_values))

        if len(allowed_types) == 1:
            field_type = allowed_types[0]
        else:

            def field_type(raw):
                """Convert raw CLI text to one of the allowed literal values."""
                # Prefer an exact textual match so that e.g. "1" maps to the int 1
                # and "auto" stays a str, regardless of type order.
                for candidate in allowed_values:
                    if str(candidate) == raw:
                        return candidate
                for candidate_type in allowed_types:
                    if candidate_type is bool:
                        # bool("anything") is True, so it cannot discriminate.
                        continue
                    try:
                        converted = candidate_type(raw)
                    except (TypeError, ValueError):
                        continue
                    if converted in allowed_values:
                        return converted
                # Unknown text: hand it back so argparse reports an invalid choice.
                return raw

        command_parser.add_argument(
            f"--{name}",
            dest=name,
            type=field_type,
            default=default,
            choices=allowed_values,
            help=field.description,
        )
    else:
        command_parser.add_argument(
            f"--{name}",
            dest=name,
            type=field.annotation,
            default=default,
            help=field.description,
        )
|
|
|
|
|
|
def add_parsers(
    subparsers,
    commands: list[type],
    command_field: str = "type",
    exclude_fields: Union[list[str], None] = None,
    add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
):
    """Adds parsers for each command to the subparsers

    Args:
        subparsers: The argparse sub-parsers action that receives one parser per command.
        commands: Command classes. Each must annotate ``command_field`` with a
            ``Literal`` whose first value is the command name.
        command_field: Name of the annotated attribute holding the command name.
        exclude_fields: Field names that must not become CLI options.
            Defaults to ``["id", "type"]`` when omitted.
        add_arguments: Optional hook called with each new command parser before
            the field options are added.
    """
    # Resolve the default here rather than in the signature, so no mutable
    # object is shared between calls.
    excluded = ["id", "type"] if exclude_fields is None else exclude_fields

    # Create subparsers for each command
    for command in commands:
        hints = get_type_hints(command)
        cmd_name = get_args(hints[command_field])[0]
        command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)

        if add_arguments is not None:
            add_arguments(command_parser)

        # Convert all fields to arguments. Prefer pydantic v2's ``model_fields``
        # and fall back to the deprecated ``__fields__`` for classes without it.
        fields = getattr(command, "model_fields", None)
        if fields is None:
            fields = command.__fields__  # type: ignore
        for name, field in fields.items():
            if name in excluded:
                continue

            add_field_argument(command_parser, name, field)
|
|
|
|
|
|
def add_graph_parsers(
    subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
):
    """Add one sub-parser per library graph, with an option for every exposed input.

    Args:
        subparsers: The argparse sub-parsers action that receives the new parsers.
        graphs: Library graphs; each one's ``name`` becomes the command name and
            its ``description`` the help text.
        add_arguments: Optional hook called with each new command parser before
            the exposed inputs are added.
    """
    for graph in graphs:
        command_parser = subparsers.add_parser(graph.name, help=graph.description)

        if add_arguments is not None:
            add_arguments(command_parser)

        # Add arguments for inputs
        for exposed_input in graph.exposed_inputs:
            node = graph.graph.get_node(exposed_input.node_path)
            # Read field metadata from the class via pydantic v2's ``model_fields``.
            # Fall back to the deprecated ``__fields__`` when it is not available.
            fields = getattr(type(node), "model_fields", None)
            if fields is None:
                fields = node.__fields__
            field = fields[exposed_input.field]
            # The value currently set on the node wins over the field's own default.
            default_override = getattr(node, exposed_input.field)
            add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
|
|
|
|
|
class CliContext:
    """State shared between CLI commands: invoker, current session, parser and bookkeeping."""

    invoker: Invoker
    session: GraphExecutionState
    parser: argparse.ArgumentParser
    defaults: dict[str, Any]
    graph_nodes: dict[str, str]
    nodes_added: list[str]

    def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
        self.invoker = invoker
        self.session = session
        self.parser = parser
        self.defaults = {}
        self.graph_nodes = {}
        self.nodes_added = []

    def get_session(self):
        """Reload the session from the graph execution manager, store it and return it."""
        manager = self.invoker.services.graph_execution_manager
        refreshed = manager.get(self.session.id)
        self.session = refreshed
        return refreshed

    def reset(self):
        """Start a fresh execution state and clear per-session bookkeeping."""
        self.session = self.invoker.create_execution_state()
        self.graph_nodes = {}
        self.nodes_added = []
        # Leave defaults unchanged

    def add_node(self, node: BaseInvocation):
        """Add ``node`` to the current session graph and persist the session."""
        current = self.get_session()
        current.graph.add_node(node)
        self.nodes_added.append(node.id)
        self.invoker.services.graph_execution_manager.set(current)

    def add_edge(self, edge: Edge):
        """Add ``edge`` to the current session and persist the session."""
        current = self.get_session()
        current.add_edge(edge)
        self.invoker.services.graph_execution_manager.set(current)
|
|
|
|
|
|
class ExitCli(Exception):
    """Raised by a command to request that the CLI exit."""
|
|
|
|
|
|
class BaseCommand(ABC, BaseModel):
    """A CLI command"""

    # Every concrete command declares a Literal-typed ``type`` field whose value
    # is the command name, for example: type: Literal["exit"] = "exit".
    # get_commands_map() reads that annotation to build its lookup table.

    @classmethod
    def get_all_subclasses(cls):
        """Return all direct and indirect subclasses of ``cls`` in breadth-first order."""
        found = []
        generation = [cls]
        # Walk the subclass tree one generation at a time.
        while generation:
            following = []
            for klass in generation:
                children = klass.__subclasses__()
                found.extend(children)
                following.extend(children)
            generation = following
        return found

    @classmethod
    def get_commands(cls):
        """Return every BaseCommand subclass as a tuple."""
        return tuple(BaseCommand.get_all_subclasses())

    @classmethod
    def get_commands_map(cls):
        """Return a dict mapping each command's type string to its class."""
        # Get the type strings out of the literals and into a dictionary
        return {
            get_args(get_type_hints(command_cls)["type"])[0]: command_cls
            for command_cls in BaseCommand.get_all_subclasses()
        }

    @abstractmethod
    def run(self, context: CliContext) -> None:
        """Run the command. Raise ExitCli to exit."""
|
|
|
|
|
|
class ExitCommand(BaseCommand):
    """Exits the CLI"""

    type: Literal["exit"] = "exit"

    def run(self, context: CliContext) -> None:
        # The context is not used; raising ExitCli is the whole command.
        raise ExitCli
|
|
|
|
|
|
class HelpCommand(BaseCommand):
    """Shows help"""

    type: Literal["help"] = "help"

    def run(self, context: CliContext) -> None:
        # Print the help of the parser stored on the context.
        cli_parser = context.parser
        cli_parser.print_help()
|
|
|
|
|
|
def get_graph_execution_history(
    graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
    """Yield executed entries that are also nodes of the state's graph, latest first.

    ``executed_history`` is walked in reverse. Entries that are not present in
    ``graph_execution_state.graph.nodes`` are skipped.
    """
    latest_first = reversed(graph_execution_state.executed_history)
    return (node_id for node_id in latest_first if node_id in graph_execution_state.graph.nodes)
|
|
|
|
|
|
def get_invocation_command(invocation) -> str:
    """Render an invocation as a CLI command string.

    The result starts with ``invocation.type`` and is followed by one
    ``--name value`` pair for each field whose current value differs from the
    field's default. The ``id`` and ``type`` fields and image-typed fields are
    omitted. Values of string-typed fields are wrapped in double quotes.

    Args:
        invocation: The invocation instance to describe.

    Returns:
        The command parts joined by single spaces.
    """
    invocation_cls = type(invocation)
    # Prefer pydantic v2's class-level ``model_fields`` and fall back to the
    # deprecated ``__fields__`` when it is not available.
    fields = getattr(invocation_cls, "model_fields", None)
    if fields is None:
        fields = invocation.__fields__
    type_hints = get_type_hints(invocation_cls)
    command = [invocation.type]
    for name, field in fields.items():
        if name in ("id", "type"):
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name)
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        if field_value != field.default:
            # ``str in get_args(...)`` also catches unions that include str.
            if type_hint is str or str in get_args(type_hint):
                command.append(f'--{name} "{field_value}"')
            else:
                command.append(f"--{name} {field_value}")

    return " ".join(command)
|
|
|
|
|
|
class HistoryCommand(BaseCommand):
    """Shows the invocation history"""

    type: Literal["history"] = "history"

    # Inputs
    # fmt: off
    count: int = Field(default=5, gt=0, description="The number of history entries to show")
    # fmt: on

    def run(self, context: CliContext) -> None:
        history = list(get_graph_execution_history(context.get_session()))
        shown = min(self.count, len(history))
        # Entries are read from the end of the list, at most ``count`` of them.
        for offset in range(1, shown + 1):
            entry_id = history[-offset]
            # The session is fetched again for every entry.
            entry = context.get_session().graph.get_node(entry_id)
            logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
|
|
|
|
|
class SetDefaultCommand(BaseCommand):
    """Sets a default value for a field"""

    type: Literal["default"] = "default"

    # Inputs
    # fmt: off
    field: str = Field(description="The field to set the default for")
    value: str = Field(description="The value to set the default to, or None to clear the default")
    # fmt: on

    def run(self, context: CliContext) -> None:
        # NOTE(review): ``value`` is annotated as ``str``, so a validated instance
        # presumably never holds ``None`` and the clearing path below may be
        # unreachable in practice - confirm against how commands are constructed.
        if self.value is not None:
            context.defaults[self.field] = self.value
            return

        # Clearing a default that was never set is a no-op.
        context.defaults.pop(self.field, None)
|
|
|
|
|
|
class DrawGraphCommand(BaseCommand):
    """Debugs a graph"""

    type: Literal["draw_graph"] = "draw_graph"

    def run(self, context: CliContext) -> None:
        # Fetch the stored session by id and flatten its graph for networkx.
        manager = context.invoker.services.graph_execution_manager
        stored_session: GraphExecutionState = manager.get(context.session.id)
        flat_graph = stored_session.graph.nx_graph_flat()

        # Draw the networkx graph
        plt.figure(figsize=(20, 20))
        layout = nx.spectral_layout(flat_graph)
        nx.draw_networkx_nodes(flat_graph, layout, node_size=1000)
        nx.draw_networkx_edges(flat_graph, layout, width=2)
        nx.draw_networkx_labels(flat_graph, layout, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()
|
|
|
|
|
|
class DrawExecutionGraphCommand(BaseCommand):
    """Debugs an execution graph"""

    type: Literal["draw_xgraph"] = "draw_xgraph"

    def run(self, context: CliContext) -> None:
        # Fetch the stored session by id and flatten its execution graph for networkx.
        manager = context.invoker.services.graph_execution_manager
        stored_session: GraphExecutionState = manager.get(context.session.id)
        flat_graph = stored_session.execution_graph.nx_graph_flat()

        # Draw the networkx graph
        plt.figure(figsize=(20, 20))
        layout = nx.spectral_layout(flat_graph)
        nx.draw_networkx_nodes(flat_graph, layout, node_size=1000)
        nx.draw_networkx_edges(flat_graph, layout, width=2)
        nx.draw_networkx_labels(flat_graph, layout, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()
|
|
|
|
|
|
class SortedHelpFormatter(argparse.HelpFormatter):
    """Help formatter that lists the entries of a sub-parsers action sorted by ``dest``."""

    def _iter_indented_subactions(self, action):
        # Actions without a ``_get_subactions`` attribute have nothing to list.
        get_subactions = getattr(action, "_get_subactions", None)
        if get_subactions is None:
            return

        self._indent()
        # Only sub-parser entries are sorted; other actions keep their own order.
        if isinstance(action, argparse._SubParsersAction):
            subactions = sorted(get_subactions(), key=lambda entry: entry.dest)
        else:
            subactions = get_subactions()
        for subaction in subactions:
            yield subaction
        self._dedent()
|