mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest. - pydantic~=2.4.2 - fastapi~=103.2 - fastapi-events~=0.9.1 **Big Changes** There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes. **Invocations** The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie. Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`. With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation. This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner. **Invocation Fields** In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations. **Invocation Decorators** With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper. A similar technique is used for `invocation_output()`. **Minor Changes** There are a number of minor changes around the pydantic v2 models API. **Protected `model_` Namespace** All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_". Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple. ```py class IPAdapterModelField(BaseModel): model_name: str = Field(description="Name of the IP-Adapter model") base_model: BaseModelType = Field(description="Base model") model_config = ConfigDict(protected_namespaces=()) ``` **Model Serialization** Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`. **Model Deserialization** Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model. ```py adapter_graph = TypeAdapter(Graph) deserialized_graph_from_json = adapter_graph.validate_json(graph_json) deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict) ``` **Field Customisation** Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field. **Schema Customisation** FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes: - Our schema customization logic has been revised - Schema parsing to build node templates has been revised The specific aren't important, but this does present additional surface area for bugs. **Performance Improvements** Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
This commit is contained in:
@ -7,28 +7,16 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_type_hints,
|
||||
)
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic.fields import ModelField, Undefined
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -211,6 +199,11 @@ class _InputField(BaseModel):
|
||||
ui_choice_labels: Optional[dict[str, str]]
|
||||
item_default: Optional[Any]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
"""
|
||||
@ -224,34 +217,36 @@ class _OutputField(BaseModel):
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def get_type(klass: BaseModel) -> str:
|
||||
"""Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
|
||||
return klass.model_fields["type"].default
|
||||
|
||||
|
||||
def InputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
@ -259,7 +254,6 @@ def InputField(
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
item_default: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
@ -289,18 +283,26 @@ def InputField(
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
|
||||
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||
Ignored for non-collection fields..
|
||||
Ignored for non-collection fields.
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
|
||||
json_schema_extra_: dict[str, Any] = dict(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
)
|
||||
|
||||
field_args = dict(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
@ -309,57 +311,92 @@ def InputField(
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
"""
|
||||
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
||||
This typing is often more specific than the actual invocation definition requires, because
|
||||
fields may have values provided only by connections.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
an ancestor node that outputs the image.
|
||||
|
||||
So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
|
||||
we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
|
||||
value or not. This is very tedious.
|
||||
|
||||
Ideally, the invocation definition would be able to specify that the field is required during
|
||||
invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
|
||||
but when calling the `invoke()` function, we raise an error if the field is not present.
|
||||
|
||||
To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
|
||||
extra validation when calling `invoke()`.
|
||||
|
||||
There is some additional logic here to cleaning create the pydantic field via the wrapper.
|
||||
"""
|
||||
|
||||
# Filter out field args not provided
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
|
||||
raise ValueError("Cannot specify both default and default_factory")
|
||||
|
||||
# because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
||||
json_schema_extra_.update(dict(orig_required=True))
|
||||
else:
|
||||
json_schema_extra_.update(dict(orig_required=False))
|
||||
|
||||
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update(dict(default=default_))
|
||||
if default is not PydanticUndefined:
|
||||
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
||||
json_schema_extra_.update(dict(default=default))
|
||||
json_schema_extra_.update(dict(orig_default=default))
|
||||
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update(dict(default=default_))
|
||||
json_schema_extra_.update(dict(orig_default=default_))
|
||||
elif default_factory is not PydanticUndefined:
|
||||
provided_args.update(dict(default_factory=default_factory))
|
||||
# TODO: cannot serialize default_factory...
|
||||
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_,
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
@ -379,15 +416,12 @@ def OutputField(
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
@ -396,19 +430,13 @@ def OutputField(
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
**kwargs,
|
||||
json_schema_extra=dict(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -422,7 +450,13 @@ class UIConfigBase(BaseModel):
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: Optional[str] = Field(
|
||||
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
|
||||
default=None,
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
@ -457,23 +491,38 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses_tuple(cls):
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
return tuple(subclasses)
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
cls._output_classes.add(output)
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_outputs_union(cls) -> UnionType:
|
||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||
return outputs_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
)
|
||||
|
||||
|
||||
class RequiredConnectionException(Exception):
|
||||
@ -498,104 +547,91 @@ class BaseInvocation(ABC, BaseModel):
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
cls._invocation_classes.add(invocation)
|
||||
|
||||
@classmethod
|
||||
def get_invocations_union(cls) -> UnionType:
|
||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||
return invocations_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
allowed_invocations = []
|
||||
for sc in subclasses:
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = get_type(sc)
|
||||
is_in_allowlist = (
|
||||
sc.__fields__.get("type").default in app_config.allow_nodes
|
||||
if isinstance(app_config.allow_nodes, list)
|
||||
else True
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
|
||||
is_in_denylist = (
|
||||
sc.__fields__.get("type").default in app_config.deny_nodes
|
||||
if isinstance(app_config.deny_nodes, list)
|
||||
else False
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.append(sc)
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls):
|
||||
return tuple(BaseInvocation.get_all_subclasses())
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls):
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(
|
||||
map(
|
||||
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
|
||||
BaseInvocation.get_all_subclasses(),
|
||||
lambda i: (get_type(i), i),
|
||||
BaseInvocation.get_invocations(),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls):
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls) -> BaseInvocationOutput:
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
validate_all = True
|
||||
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
def __init__(self, **data):
|
||||
# nodes may have required fields, that can accept input from connections
|
||||
# on instantiation of the model, we need to exclude these from validation
|
||||
restore = dict()
|
||||
try:
|
||||
field_names = list(self.__fields__.keys())
|
||||
for field_name in field_names:
|
||||
# if the field is required and may get its value from a connection, exclude it from validation
|
||||
field = self.__fields__[field_name]
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if _input in [Input.Connection, Input.Any] and field.required:
|
||||
if field_name not in data:
|
||||
restore[field_name] = self.__fields__.pop(field_name)
|
||||
# instantiate the node, which will validate the data
|
||||
super().__init__(**data)
|
||||
finally:
|
||||
# restore the removed fields
|
||||
for field_name, field in restore.items():
|
||||
self.__fields__[field_name] = field
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
for field_name, field in self.__fields__.items():
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if field.required and not hasattr(self, field_name):
|
||||
if _input == Input.Connection:
|
||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||
elif _input == Input.Any:
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
for field_name, field in self.model_fields.items():
|
||||
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||
# something has gone terribly awry, we should always have this and it should be a dict
|
||||
continue
|
||||
|
||||
# Here we handle the case where the field is optional in the pydantic class, but required
|
||||
# in the `invoke()` method.
|
||||
|
||||
orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
|
||||
orig_required = field.json_schema_extra.get("orig_required", True)
|
||||
input_ = field.json_schema_extra.get("input", None)
|
||||
if orig_default is not PydanticUndefined and not hasattr(self, field_name):
|
||||
setattr(self, field_name, orig_default)
|
||||
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
|
||||
if input_ == Input.Connection:
|
||||
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
|
||||
elif input_ == Input.Any:
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
@ -618,23 +654,31 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return self.invoke(context)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return self.__fields__["type"].default
|
||||
return self.model_fields["type"].default
|
||||
|
||||
id: str = Field(
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
||||
default_factory=uuid_string,
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||
)
|
||||
is_intermediate: bool = InputField(
|
||||
default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
|
||||
is_intermediate: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||
)
|
||||
workflow: Optional[str] = InputField(
|
||||
workflow: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow to save with the image",
|
||||
ui_type=UIType.WorkflowField,
|
||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
||||
)
|
||||
use_cache: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Whether or not to use the cache",
|
||||
)
|
||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||
|
||||
@validator("workflow", pre=True)
|
||||
@field_validator("workflow", mode="before")
|
||||
@classmethod
|
||||
def validate_workflow_is_json(cls, v):
|
||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
||||
if v is None:
|
||||
return None
|
||||
try:
|
||||
@ -645,8 +689,14 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
def invocation(
|
||||
@ -656,7 +706,7 @@ def invocation(
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
|
||||
@ -668,12 +718,15 @@ def invocation(
|
||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
|
||||
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||
# Validate invocation types on creation of invocation classes
|
||||
# TODO: ensure unique?
|
||||
if re.compile(r"^\S+$").match(invocation_type) is None:
|
||||
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
||||
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
@ -691,59 +744,83 @@ def invocation(
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
if use_cache is not None:
|
||||
cls.__fields__["use_cache"].default = use_cache
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
# Add the invocation type to the model.
|
||||
|
||||
# You'd be tempted to just add the type field and rebuild the model, like this:
|
||||
# cls.model_fields.update(type=FieldInfo.from_annotated_attribute(Literal[invocation_type], invocation_type))
|
||||
# cls.model_rebuild() or cls.model_rebuild(force=True)
|
||||
|
||||
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
|
||||
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
|
||||
|
||||
# Add the invocation type to the pydantic model of the invocation
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = ModelField.infer(
|
||||
name="type",
|
||||
value=invocation_type,
|
||||
annotation=invocation_type_annotation,
|
||||
class_validators=None,
|
||||
config=cls.__config__,
|
||||
invocation_type_field = Field(
|
||||
title="type",
|
||||
default=invocation_type,
|
||||
)
|
||||
cls.__fields__.update({"type": invocation_type_field})
|
||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
||||
if annotations := cls.__dict__.get("__annotations__", None):
|
||||
annotations.update({"type": invocation_type_annotation})
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(invocation_type_annotation, invocation_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||
BaseInvocation.register_invocation(cls) # type: ignore
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
|
||||
TBaseInvocationOutput = TypeVar("TBaseInvocationOutput", bound=BaseInvocationOutput)
|
||||
|
||||
|
||||
def invocation_output(
|
||||
output_type: str,
|
||||
) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
|
||||
) -> Callable[[Type[TBaseInvocationOutput]], Type[TBaseInvocationOutput]]:
|
||||
"""
|
||||
Adds metadata to an invocation output.
|
||||
|
||||
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
|
||||
def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
|
||||
# Validate output types on creation of invocation output classes
|
||||
# TODO: ensure unique?
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
# Add the output type to the pydantic model of the invocation output
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = ModelField.infer(
|
||||
name="type",
|
||||
value=output_type,
|
||||
annotation=output_type_annotation,
|
||||
class_validators=None,
|
||||
config=cls.__config__,
|
||||
)
|
||||
cls.__fields__.update({"type": output_type_field})
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
||||
if annotations := cls.__dict__.get("__annotations__", None):
|
||||
annotations.update({"type": output_type_annotation})
|
||||
# Add the output type to the model.
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(
|
||||
title="type",
|
||||
default=output_type,
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(output_type_annotation, output_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
|
Reference in New Issue
Block a user