mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(nodes): move all field things to fields.py
Unfortunately, this is necessary to prevent circular imports at runtime.
This commit is contained in:
@ -12,10 +12,11 @@ from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo, _Unset
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.invocations.fields import FieldKind, Input
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
@ -52,393 +53,6 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Prototype = "prototype"
|
||||
|
||||
|
||||
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
|
||||
class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||
|
||||
- Model Fields
|
||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||
the field is an SDXL main model field.
|
||||
|
||||
- Any Field
|
||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||
|
||||
- Scheduler Field
|
||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||
|
||||
- Internal Fields
|
||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||
should not be used by node authors.
|
||||
|
||||
- DEPRECATED Fields
|
||||
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
Boolean = "DEPRECATED_Boolean"
|
||||
Color = "DEPRECATED_Color"
|
||||
Conditioning = "DEPRECATED_Conditioning"
|
||||
Control = "DEPRECATED_Control"
|
||||
Float = "DEPRECATED_Float"
|
||||
Image = "DEPRECATED_Image"
|
||||
Integer = "DEPRECATED_Integer"
|
||||
Latents = "DEPRECATED_Latents"
|
||||
String = "DEPRECATED_String"
|
||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||
ColorCollection = "DEPRECATED_ColorCollection"
|
||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||
ControlCollection = "DEPRECATED_ControlCollection"
|
||||
FloatCollection = "DEPRECATED_FloatCollection"
|
||||
ImageCollection = "DEPRECATED_ImageCollection"
|
||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||
StringCollection = "DEPRECATED_StringCollection"
|
||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
Collection = "DEPRECATED_Collection"
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||
and by the workflow editor during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
||||
during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
|
||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||
any number of fields may be optional, because they may be provided by connections.
|
||||
|
||||
On calling of `invoke()`, however, those fields may be required.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
connection from an ancestor node, which outputs an image.
|
||||
|
||||
This means we want to type the `image` field as optional for the node class definition, but required
|
||||
for the `invoke()` function.
|
||||
|
||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||
any static type analysis tools will complain.
|
||||
|
||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||
"""
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
"default": default,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
"strict": strict,
|
||||
"gt": gt,
|
||||
"ge": ge,
|
||||
"lt": lt,
|
||||
"le": le,
|
||||
"multiple_of": multiple_of,
|
||||
"allow_inf_nan": allow_inf_nan,
|
||||
"max_digits": max_digits,
|
||||
"decimal_places": decimal_places,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||
|
||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if input is Input.Any or input is Input.Connection:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||
json_schema_extra_.default = default
|
||||
json_schema_extra_.orig_default = default
|
||||
elif default is not PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.orig_default = default_
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
@ -460,33 +74,6 @@ class UIConfigBase(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
workflow: Optional[WorkflowWithoutID]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
workflow: Optional[WorkflowWithoutID],
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
self.workflow = workflow
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
Base class for all invocation outputs.
|
||||
@ -926,37 +513,3 @@ def invocation_output(
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The metadata")
|
||||
|
||||
|
||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.metadata,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Connection,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
|
Reference in New Issue
Block a user