feat(ui): add support for custom field types

Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported.

Two notes:
1. Your field type's class name must be unique.

Suggest prefixing fields with something related to the node pack as a kind of namespace.

2. Custom field types function as connection-only fields.

For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type.

This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection.

feat(ui): fix tooltips for custom types

We need to hold onto the original type of the field so they don't all just show up as "Unknown".

fix(ui): fix ts error with custom fields

feat(ui): custom field types connection validation

In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.

*Actually, it was `"Unknown"`, but I changed it to custom for clarity.

Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.

To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.

This ended up needing a bit of fanagling:

- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.

While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.

(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)

- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.

- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.

Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.

This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.

fix(ui): typo

feat(ui): add CustomCollection and CustomPolymorphic field types

feat(ui): add validation for CustomCollection & CustomPolymorphic types

- Update connection validation for custom types
- Use simple string parsing to determine if a field is a collection or polymorphic type.
- No longer need to keep a list of collection and polymorphic types.
- Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing

chore(ui): remove errant console.log

fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType'

This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type.

fix(ui): fix ts error

feat(nodes): add runtime check for custom field names

"Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names.

chore(ui): add TODO for revising field type names

wip refactor fieldtype structured

wip refactor field types

wip refactor types

wip refactor types

fix node layout

refactor field types

chore: mypy

organisation

organisation

organisation

fix(nodes): fix field orig_required, field_kind and input statuses

feat(nodes): remove broken implementation of default_factory on InputField

Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args.

Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used.

Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`.

fix(nodes): fix InputField name validation

workflow validation

validation

chore: ruff

feat(nodes): fix up baseinvocation comments

fix(ui): improve typing & logic of buildFieldInputTemplate

improved error handling in parseFieldType

fix: back compat for deprecated default_factory and UIType

feat(nodes): do not show node packs loaded log if none loaded

chore(ui): typegen
This commit is contained in:
psychedelicious 2023-11-17 11:32:35 +11:00
parent 0d52430481
commit 86a74e929a
186 changed files with 5713 additions and 5704 deletions

View File

@ -1,11 +1,8 @@
import sys
from typing import Any
from fastapi.responses import HTMLResponse
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from invokeai.version.invokeai_version import __version__
from .services.config import InvokeAIAppConfig
@ -22,6 +19,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import socket
from inspect import signature
from pathlib import Path
from typing import Any
import uvicorn
from fastapi import FastAPI
@ -29,7 +27,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
@ -58,9 +56,9 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
UIConfigBase,
_InputField,
_OutputField,
)
if is_mps_available():
@ -157,7 +155,11 @@ def custom_openapi() -> dict[str, Any]:
# Add Node Editor UI helper schemas
ui_config_schemas = models_json_schema(
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
@ -165,7 +167,7 @@ def custom_openapi() -> dict[str, Any]:
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]

View File

@ -17,11 +17,15 @@ from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.logging import InvokeAILogger
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
logger = InvokeAILogger.get_logger()
class InvalidVersionError(ValueError):
pass
@ -31,7 +35,7 @@ class InvalidFieldError(TypeError):
pass
class Input(str, Enum):
class Input(str, Enum, metaclass=MetaEnum):
"""
The type of input a field accepts.
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
@ -45,86 +49,120 @@ class Input(str, Enum):
Any = "any"
class UIType(str, Enum):
class FieldKind(str, Enum, metaclass=MetaEnum):
"""
Type hints for the UI.
If a field should be provided a data type that does not exactly match the python type of the field, \
use this to provide the type that should be used instead. See the node development docs for detail \
on adding a new field type, which involves client-side changes.
The kind of field.
- `Input`: An input field on a node.
- `Output`: An output field on a node.
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
allowing "metadata" for that field.
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
but which are used to store information about the node. For example, the `id` and `type` fields are node
attributes.
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
startup, and when generating the OpenAPI schema for the workflow editor.
"""
# region Primitives
Boolean = "boolean"
Color = "ColorField"
Conditioning = "ConditioningField"
Control = "ControlField"
Float = "float"
Image = "ImageField"
Integer = "integer"
Latents = "LatentsField"
String = "string"
# endregion
Input = "input"
Output = "output"
Internal = "internal"
NodeAttribute = "node_attribute"
# region Collection Primitives
BooleanCollection = "BooleanCollection"
ColorCollection = "ColorCollection"
ConditioningCollection = "ConditioningCollection"
ControlCollection = "ControlCollection"
FloatCollection = "FloatCollection"
ImageCollection = "ImageCollection"
IntegerCollection = "IntegerCollection"
LatentsCollection = "LatentsCollection"
StringCollection = "StringCollection"
# endregion
# region Polymorphic Primitives
BooleanPolymorphic = "BooleanPolymorphic"
ColorPolymorphic = "ColorPolymorphic"
ConditioningPolymorphic = "ConditioningPolymorphic"
ControlPolymorphic = "ControlPolymorphic"
FloatPolymorphic = "FloatPolymorphic"
ImagePolymorphic = "ImagePolymorphic"
IntegerPolymorphic = "IntegerPolymorphic"
LatentsPolymorphic = "LatentsPolymorphic"
StringPolymorphic = "StringPolymorphic"
# endregion
class UIType(str, Enum, metaclass=MetaEnum):
"""
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
# region Models
MainModel = "MainModelField"
- Model Fields
The most common node-author-facing use will be for model fields. Internally, there is no difference
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
the field is an SDXL main model field.
- Any Field
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
- Scheduler Field
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
- Internal Fields
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
should not be used by node authors.
"""
# region Model Field Types
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VaeModel = "VaeModelField"
VaeModel = "VAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
# endregion
# region Iterate/Collect
Collection = "Collection"
CollectionItem = "CollectionItem"
# region Misc Field Types
Scheduler = "SchedulerField"
Any = "AnyField"
# endregion
# region Misc
Enum = "enum"
Scheduler = "Scheduler"
WorkflowField = "WorkflowField"
IsIntermediate = "IsIntermediate"
BoardField = "BoardField"
Any = "Any"
MetadataItem = "MetadataItem"
MetadataItemCollection = "MetadataItemCollection"
MetadataItemPolymorphic = "MetadataItemPolymorphic"
MetadataDict = "MetadataDict"
# region Internal Field Types
_Collection = "CollectionField"
_CollectionItem = "CollectionItemField"
# endregion
# region DEPRECATED
Boolean = "DEPRECATED_Boolean"
Color = "DEPRECATED_Color"
Conditioning = "DEPRECATED_Conditioning"
Control = "DEPRECATED_Control"
Float = "DEPRECATED_Float"
Image = "DEPRECATED_Image"
Integer = "DEPRECATED_Integer"
Latents = "DEPRECATED_Latents"
String = "DEPRECATED_String"
BooleanCollection = "DEPRECATED_BooleanCollection"
ColorCollection = "DEPRECATED_ColorCollection"
ConditioningCollection = "DEPRECATED_ConditioningCollection"
ControlCollection = "DEPRECATED_ControlCollection"
FloatCollection = "DEPRECATED_FloatCollection"
ImageCollection = "DEPRECATED_ImageCollection"
IntegerCollection = "DEPRECATED_IntegerCollection"
LatentsCollection = "DEPRECATED_LatentsCollection"
StringCollection = "DEPRECATED_StringCollection"
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
StringPolymorphic = "DEPRECATED_StringPolymorphic"
MainModel = "DEPRECATED_MainModel"
UNet = "DEPRECATED_UNet"
Vae = "DEPRECATED_Vae"
CLIP = "DEPRECATED_CLIP"
Collection = "DEPRECATED_Collection"
CollectionItem = "DEPRECATED_CollectionItem"
Enum = "DEPRECATED_Enum"
WorkflowField = "DEPRECATED_WorkflowField"
IsIntermediate = "DEPRECATED_IsIntermediate"
BoardField = "DEPRECATED_BoardField"
MetadataItem = "DEPRECATED_MetadataItem"
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
MetadataDict = "DEPRECATED_MetadataDict"
# endregion
class UIComponent(str, Enum):
class UIComponent(str, Enum, metaclass=MetaEnum):
"""
The type of UI component to use for a field, used to override the default components, which are \
The type of UI component to use for a field, used to override the default components, which are
inferred from the field type.
"""
@ -133,7 +171,7 @@ class UIComponent(str, Enum):
Slider = "slider"
class _InputField(BaseModel):
class InputFieldJSONSchemaExtra(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
@ -142,12 +180,15 @@ class _InputField(BaseModel):
"""
input: Input
ui_hidden: bool
ui_type: Optional[UIType]
ui_component: Optional[UIComponent]
ui_order: Optional[int]
ui_choice_labels: Optional[dict[str, str]]
item_default: Optional[Any]
orig_required: bool
field_kind: FieldKind
default: Optional[Any] = None
orig_default: Optional[Any] = None
ui_hidden: bool = False
ui_type: Optional[UIType] = None
ui_component: Optional[UIComponent] = None
ui_order: Optional[int] = None
ui_choice_labels: Optional[dict[str, str]] = None
model_config = ConfigDict(
validate_assignment=True,
@ -155,7 +196,7 @@ class _InputField(BaseModel):
)
class _OutputField(BaseModel):
class OutputFieldJSONSchemaExtra(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
@ -163,6 +204,7 @@ class _OutputField(BaseModel):
purpose in the backend.
"""
field_kind: FieldKind
ui_hidden: bool
ui_type: Optional[UIType]
ui_order: Optional[int]
@ -180,6 +222,7 @@ def get_type(klass: BaseModel) -> str:
def InputField(
# copied from pydantic's Field
# TODO: Can we support default_factory?
default: Any = _Unset,
default_factory: Callable[[], Any] | None = _Unset,
title: str | None = _Unset,
@ -203,12 +246,11 @@ def InputField(
ui_hidden: bool = False,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
item_default: Optional[Any] = None,
) -> Any:
"""
Creates an input field for an invocation.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
that adds a few extra parameters to support graph execution and the node editor UI.
:param Input input: [Input.Any] The kind of input this field requires. \
@ -230,26 +272,57 @@ def InputField(
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
Ignored for non-collection fields.
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
"""
json_schema_extra_: dict[str, Any] = {
"input": input,
"ui_type": ui_type,
"ui_component": ui_component,
"ui_hidden": ui_hidden,
"ui_order": ui_order,
"item_default": item_default,
"ui_choice_labels": ui_choice_labels,
"_field_kind": "input",
}
json_schema_extra_ = InputFieldJSONSchemaExtra(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
ui_choice_labels=ui_choice_labels,
field_kind=FieldKind.Input,
orig_required=True,
)
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
`invoke()` function.
On instantiation of a node, the invocation definition is used to create the python class. At this time,
any number of fields may be optional, because they may be provided by connections.
On calling of `invoke()`, however, those fields may be required.
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
`image` is required during the call to `invoke()`, but when the python class is instantiated,
the field may not be present. This is fine, because that image field will be provided by a
connection from an ancestor node, which outputs an image.
This means we want to type the `image` field as optional for the node class definition, but required
for the `invoke()` function.
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
any static type analysis tools will complain.
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
"""
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
del default_factory
logger.warn('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish pass to the pydantic `Field()` function
field_args = {
"default": default,
"default_factory": default_factory,
"title": title,
"description": description,
"pattern": pattern,
@ -266,70 +339,34 @@ def InputField(
"max_length": max_length,
}
"""
Invocation definitions have their fields typed correctly for their `invoke()` functions.
This typing is often more specific than the actual invocation definition requires, because
fields may have values provided only by connections.
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
`image` is required during the call to `invoke()`, but when the python class is instantiated,
the field may not be present. This is fine, because that image field will be provided by a
an ancestor node that outputs the image.
So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
value or not. This is very tedious.
Ideally, the invocation definition would be able to specify that the field is required during
invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
but when calling the `invoke()` function, we raise an error if the field is not present.
To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
extra validation when calling `invoke()`.
There is some additional logic here to cleaning create the pydantic field via the wrapper.
"""
# Filter out field args not provided
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
raise ValueError("Cannot specify both default and default_factory")
# Because we are manually making fields optional, we need to store the original required bool for reference later
json_schema_extra_.orig_required = default is PydanticUndefined
# because we are manually making fields optional, we need to store the original required bool for reference later
if default is PydanticUndefined and default_factory is PydanticUndefined:
json_schema_extra_.update({"orig_required": True})
else:
json_schema_extra_.update({"orig_required": False})
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if input is Input.Any or input is Input.Connection:
default_ = None if default is PydanticUndefined else default
provided_args.update({"default": default_})
if default is not PydanticUndefined:
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
json_schema_extra_.update({"default": default})
json_schema_extra_.update({"orig_default": default})
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
json_schema_extra_.default = default
json_schema_extra_.orig_default = default
elif default is not PydanticUndefined:
default_ = default
provided_args.update({"default": default_})
json_schema_extra_.update({"orig_default": default_})
elif default_factory is not PydanticUndefined:
provided_args.update({"default_factory": default_factory})
# TODO: cannot serialize default_factory...
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
json_schema_extra_.orig_default = default_
return Field(
**provided_args,
json_schema_extra=json_schema_extra_,
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
)
def OutputField(
# copied from pydantic's Field
default: Any = _Unset,
default_factory: Callable[[], Any] | None = _Unset,
title: str | None = _Unset,
description: str | None = _Unset,
pattern: str | None = _Unset,
@ -368,7 +405,6 @@ def OutputField(
"""
return Field(
default=default,
default_factory=default_factory,
title=title,
description=description,
pattern=pattern,
@ -383,12 +419,12 @@ def OutputField(
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
json_schema_extra={
"ui_type": ui_type,
"ui_hidden": ui_hidden,
"ui_order": ui_order,
"_field_kind": "output",
},
json_schema_extra=OutputFieldJSONSchemaExtra(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
field_kind=FieldKind.Output,
).model_dump(exclude_none=True),
)
@ -538,7 +574,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"):
@ -604,15 +640,17 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(
default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
json_schema_extra={"_field_kind": "internal"},
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
)
use_cache: bool = Field(
default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
default=True,
description="Whether or not to use the cache",
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
UIConfig: ClassVar[Type[UIConfigBase]]
@ -629,12 +667,15 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
RESERVED_INPUT_FIELD_NAMES = {
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
"is_intermediate",
"use_cache",
"type",
"workflow",
}
RESERVED_INPUT_FIELD_NAMES = {
"metadata",
}
@ -653,39 +694,56 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
"""
Validates the fields of an invocation or invocation output:
- must not override any pydantic reserved fields
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
if not field.annotation:
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
if not isinstance(field.json_schema_extra, dict):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
)
field_kind = field.json_schema_extra.get("field_kind", None)
# must have a field_kind
if field_kind is None or field_kind not in {"input", "output", "internal"}:
if not isinstance(field_kind, FieldKind):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
)
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
if field_kind is FieldKind.Input and (
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
):
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
# internal fields *must* be in the reserved list
if (
field_kind == "internal"
and name not in RESERVED_INPUT_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
)
# node attribute fields *must* be in the reserved list
if (
field_kind is FieldKind.NodeAttribute
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
)
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
field.json_schema_extra.pop("ui_type")
return None
@ -749,7 +807,7 @@ def invocation(
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
docstring = cls.__doc__
@ -795,7 +853,9 @@ def invocation_output(
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
docstring = cls.__doc__
cls = create_model(
@ -827,7 +887,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
@ -845,5 +905,11 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = Field(
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
default=None,
description=FieldDescriptions.metadata,
json_schema_extra=InputFieldJSONSchemaExtra(
field_kind=FieldKind.Internal,
input=Input.Connection,
orig_required=False,
).model_dump(exclude_none=True),
)

View File

@ -5,7 +5,7 @@ import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.misc import SEED_MAX
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@ -55,7 +55,7 @@ class RangeOfSizeInvocation(BaseInvocation):
title="Random Range",
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.0",
version="1.0.1",
use_cache=False,
)
class RandomRangeInvocation(BaseInvocation):
@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation):
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = InputField(default=1, description="The number of values to generate")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:

View File

@ -39,6 +39,8 @@ for d in Path(__file__).parent.iterdir():
logger.warn(f"Could not load {init}")
continue
logger.info(f"Loading node pack {spec.name}")
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
@ -47,5 +49,5 @@ for d in Path(__file__).parent.iterdir():
del init, module_name
logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}")
if loaded_count > 0:
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")

View File

@ -8,7 +8,7 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
@ -154,17 +154,17 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1")
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
@ -67,7 +66,7 @@ class IPAdapterInvocation(BaseInvocation):
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight"
)
begin_step_percent: float = InputField(

View File

@ -274,7 +274,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
ui_order=7,
)
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
ui_order=4,
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,

View File

@ -14,7 +14,6 @@ from .baseinvocation import (
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
@ -395,7 +394,6 @@ class VaeLoaderInvocation(BaseInvocation):
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
ui_type=UIType.VaeModel,
title="VAE",
)

View File

@ -6,7 +6,7 @@ from pydantic import field_validator
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.misc import SEED_MAX
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
title="Noise",
tags=["latents", "noise"],
category="latents",
version="1.0.0",
version="1.0.1",
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
default_factory=get_random_seed,
)
width: int = InputField(
default=512,

View File

@ -62,12 +62,12 @@ class BooleanInvocation(BaseInvocation):
title="Boolean Collection Primitive",
tags=["primitives", "boolean", "collection"],
category="primitives",
version="1.0.0",
version="1.0.1",
)
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection)
@ -111,12 +111,12 @@ class IntegerInvocation(BaseInvocation):
title="Integer Collection Primitive",
tags=["primitives", "integer", "collection"],
category="primitives",
version="1.0.0",
version="1.0.1",
)
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
collection: list[int] = InputField(default=[], description="The collection of integer values")
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection)
@ -158,12 +158,12 @@ class FloatInvocation(BaseInvocation):
title="Float Collection Primitive",
tags=["primitives", "float", "collection"],
category="primitives",
version="1.0.0",
version="1.0.1",
)
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
collection: list[float] = InputField(default=[], description="The collection of float values")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection)
@ -205,12 +205,12 @@ class StringInvocation(BaseInvocation):
title="String Collection Primitive",
tags=["primitives", "string", "collection"],
category="primitives",
version="1.0.0",
version="1.0.1",
)
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
collection: list[str] = InputField(default=[], description="The collection of string values")
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection)
@ -467,13 +467,13 @@ class ConditioningInvocation(BaseInvocation):
title="Conditioning Collection Primitive",
tags=["primitives", "conditioning", "collection"],
category="primitives",
version="1.0.0",
version="1.0.1",
)
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""
collection: list[ConditioningField] = InputField(
default_factory=list,
default=[],
description="The collection of conditioning tensors",
)

View File

@ -9,7 +9,6 @@ from invokeai.app.invocations.baseinvocation import (
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
@ -59,7 +58,7 @@ class T2IAdapterInvocation(BaseInvocation):
ui_order=-1,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"

View File

@ -205,7 +205,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
)
@ -215,7 +215,7 @@ class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
collection: list[Any] = InputField(
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
description="The list of items to iterate over", default=[], ui_type=UIType._Collection
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
@ -227,7 +227,7 @@ class IterateInvocation(BaseInvocation):
@invocation_output("collect_output")
class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type=UIType.Collection
description="The collection of input items", title="Collection", ui_type=UIType._Collection
)
@ -238,12 +238,12 @@ class CollectInvocation(BaseInvocation):
item: Optional[Any] = InputField(
default=None,
description="The item to collect (all inputs must be of the same type)",
ui_type=UIType.CollectionItem,
ui_type=UIType._CollectionItem,
title="Collection Item",
input=Input.Connection,
)
collection: list[Any] = InputField(
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
description="The collection, will be provided on execution", default=[], ui_hidden=True
)
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:

View File

@ -805,6 +805,8 @@
"clipField": "Clip",
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",
"polymorphicFieldType": "{{name}} Polymorphic",
"collectionDescription": "TODO",
"collectionItem": "Collection Item",
"collectionItemDescription": "TODO",
@ -891,10 +893,15 @@
"mainModelField": "Model",
"mainModelFieldDescription": "TODO",
"maybeIncompatible": "May be Incompatible With Installed",
"mismatchedVersion": "Has Mismatched Version",
"mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)",
"missingCanvaInitImage": "Missing canvas init image",
"missingCanvaInitMaskImages": "Missing canvas init and mask images",
"missingTemplate": "Missing Template",
"missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)",
"sourceNodeDoesNotExist": "Invalid edge: source/output node {{node}} does not exist",
"targetNodeDoesNotExist": "Invalid edge: target/input node {{node}} does not exist",
"sourceNodeFieldDoesNotExist": "Invalid edge: source/output field {{node}}.{{field}} does not exist",
"targetNodeFieldDoesNotExist": "Invalid edge: target/input field {{node}}.{{field}} does not exist",
"deletedInvalidEdge": "Deleted invalid edge {{source}} -> {{target}}",
"noConnectionData": "No connection data",
"noConnectionInProgress": "No connection in progress",
"node": "Node",
@ -954,10 +961,17 @@
"stringDescription": "Strings are text.",
"stringPolymorphic": "String Polymorphic",
"stringPolymorphicDescription": "A collection of strings.",
"unableToLoadWorkflow": "Unable to Validate Workflow",
"unableToLoadWorkflow": "Unable to Load Workflow",
"unableToParseEdge": "Unable to parse edge",
"unableToParseNode": "Unable to parse node",
"unableToUpdateNode": "Unable to update node",
"unableToValidateWorkflow": "Unable to Validate Workflow",
"unknownErrorValidatingWorkflow": "Unknown error validating workflow",
"inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})",
"unableToExtractSchemaNameFromRef": "unable to extract schema name from ref",
"unsupportedArrayItemType": "unsupported array item type \"{{type}}\"",
"unableToParseFieldType": "unable to parse field type",
"uNetField": "UNet",
"uNetFieldDescription": "UNet submodel.",
"unhandledInputProperty": "Unhandled input property",
@ -971,8 +985,9 @@
"unkownInvocation": "Unknown Invocation type",
"unknownOutput": "Unknown output",
"updateNode": "Update Node",
"updateAllNodes": "Update All Nodes",
"updateApp": "Update App",
"updateAllNodes": "Update All Nodes",
"allNodesUpdated": "All Nodes Updated",
"unableToUpdateNodes_one": "Unable to update {{count}} node",
"unableToUpdateNodes_other": "Unable to update {{count}} nodes",
"vaeField": "Vae",
@ -981,6 +996,8 @@
"vaeModelFieldDescription": "TODO",
"validateConnections": "Validate Connections and Graph",
"validateConnectionsHelp": "Prevent invalid connections from being made, and invalid graphs from being invoked",
"unableToGetWorkflowVersion": "Unable to get workflow schema version",
"unrecognizedWorkflowVersion": "Unrecognized workflow schema version {{version}}",
"version": "Version",
"versionUnknown": " Version Unknown",
"workflow": "Workflow",

View File

@ -71,7 +71,7 @@ import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } f
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addWorkflowLoadRequestedListener } from './listeners/workflowLoadRequested';
import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested';
export const listenerMiddleware = createListenerMiddleware();
@ -178,7 +178,7 @@ addBoardIdSelectedListener();
addReceivedOpenAPISchemaListener();
// Workflows
addWorkflowLoadedListener();
addWorkflowLoadRequestedListener();
addUpdateAllNodesRequestedListener();
// DND

View File

@ -12,10 +12,10 @@ import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import { isImageOutput } from 'services/api/guards';
import { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
import { isImageOutput } from 'features/nodes/types/common';
export const addControlNetImageProcessedListener = () => {
startAppListening({

View File

@ -5,19 +5,20 @@ import {
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/types';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addRequestedSingleImageDeletionListener = () => {
startAppListening({
@ -121,7 +122,7 @@ export const addRequestedSingleImageDeletionListener = () => {
forEach(node.data.inputs, (input) => {
if (
input.type === 'ImageField' &&
isImageFieldInputInstance(input) &&
input.value?.image_name === imageDTO.image_name
) {
dispatch(
@ -241,7 +242,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
forEach(node.data.inputs, (input) => {
if (
input.type === 'ImageField' &&
isImageFieldInputInstance(input) &&
input.value?.image_name === imageDTO.image_name
) {
dispatch(

View File

@ -12,12 +12,12 @@ import {
setWidth,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
export const addModelSelectedListener = () => {
startAppListening({
@ -26,7 +26,7 @@ export const addModelSelectedListener = () => {
const log = logger('models');
const state = getState();
const result = zMainOrOnnxModel.safeParse(action.payload);
const result = zParameterModel.safeParse(action.payload);
if (!result.success) {
log.error(

View File

@ -11,9 +11,9 @@ import {
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
zMainOrOnnxModel,
zSDXLRefinerModel,
zVaeModel,
zParameterModel,
zParameterSDXLRefinerModel,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
@ -67,7 +67,7 @@ export const addModelsLoadedListener = () => {
return;
}
const result = zMainOrOnnxModel.safeParse(models[0]);
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error(
@ -119,7 +119,7 @@ export const addModelsLoadedListener = () => {
return;
}
const result = zSDXLRefinerModel.safeParse(models[0]);
const result = zParameterSDXLRefinerModel.safeParse(models[0]);
if (!result.success) {
log.error(
@ -170,7 +170,7 @@ export const addModelsLoadedListener = () => {
return;
}
const result = zVaeModel.safeParse(firstModel);
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error(

View File

@ -15,6 +15,7 @@ export const addReceivedOpenAPISchemaListener = () => {
log.debug({ schemaJSON }, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(
schemaJSON,
nodesAllowlist,

View File

@ -13,13 +13,13 @@ import {
} from 'features/nodes/util/graphBuilders/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { imagesAdapter } from 'services/api/util';
import {
appSocketInvocationComplete,
socketInvocationComplete,
} from 'services/events/actions';
import { startAppListening } from '../..';
import { isImageOutput } from 'features/nodes/types/common';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image'];

View File

@ -1,14 +1,16 @@
import { logger } from 'app/logging/logger';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import {
getNeedsUpdate,
updateNode,
} from 'features/nodes/hooks/useNodeVersion';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import { startAppListening } from '..';
import { logger } from 'app/logging/logger';
} from 'features/nodes/store/util/nodeUpdate';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { startAppListening } from '..';
export const addUpdateAllNodesRequestedListener = () => {
startAppListening({
@ -20,22 +22,31 @@ export const addUpdateAllNodesRequestedListener = () => {
let unableToUpdateCount = 0;
nodes.forEach((node) => {
nodes.filter(isInvocationNode).forEach((node) => {
const template = templates[node.data.type];
const needsUpdate = getNeedsUpdate(node, template);
const updatedNode = updateNode(node, template);
if (!updatedNode) {
if (needsUpdate) {
if (!template) {
unableToUpdateCount++;
}
return;
}
if (!getNeedsUpdate(node, template)) {
// No need to increment the count here, since we're not actually updating
return;
}
try {
const updatedNode = updateNode(node, template);
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
} catch (e) {
if (e instanceof NodeUpdateError) {
unableToUpdateCount++;
}
}
});
if (unableToUpdateCount) {
log.warn(
`Unable to update ${unableToUpdateCount} nodes. Please report this issue.`
t('nodes.unableToUpdateNodes', {
count: unableToUpdateCount,
})
);
dispatch(
addToast(
@ -46,6 +57,15 @@ export const addUpdateAllNodesRequestedListener = () => {
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: t('nodes.allNodesUpdated'),
status: 'success',
})
)
);
}
},
});

View File

@ -0,0 +1,105 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { WorkflowVersionError } from 'features/nodes/types/error';
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { t } from 'i18next';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
import { startAppListening } from '..';
export const addWorkflowLoadRequestedListener = () => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(
workflow,
nodeTemplates
);
dispatch(workflowLoaded(validatedWorkflow));
if (!warnings.length) {
dispatch(
addToast(
makeToast({
title: t('toast.workflowLoaded'),
status: 'success',
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: t('toast.loadedWithWarnings'),
status: 'warning',
})
)
);
warnings.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
}
dispatch(setActiveTab('nodes'));
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
} catch (e) {
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions
log.error({ error: parseify(e) }, e.message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
})
)
);
} else if (e instanceof z.ZodError) {
// There was a problem validating the workflow itself
const { message } = fromZodError(e, {
prefix: t('nodes.workflowValidation'),
});
log.error({ error: parseify(e) }, message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: message,
})
)
);
} else {
// Some other error occurred
console.log(e);
log.error(
{ error: parseify(e) },
t('nodes.unknownErrorValidatingWorkflow')
);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: t('nodes.unknownErrorValidatingWorkflow'),
})
)
);
}
}
},
});
};

View File

@ -1,56 +0,0 @@
import { logger } from 'app/logging/logger';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { startAppListening } from '..';
import { t } from 'i18next';
export const addWorkflowLoadedListener = () => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
const { workflow: validatedWorkflow, errors } = validateWorkflow(
workflow,
nodeTemplates
);
dispatch(workflowLoaded(validatedWorkflow));
if (!errors.length) {
dispatch(
addToast(
makeToast({
title: t('toast.workflowLoaded'),
status: 'success',
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: t('toast.loadedWithWarnings'),
status: 'warning',
})
)
);
errors.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
}
dispatch(setActiveTab('nodes'));
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
},
});
};

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach } from 'lodash-es';

View File

@ -6,9 +6,9 @@ import {
isAnyOf,
} from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { appSocketInvocationError } from 'services/events/actions';
@ -243,9 +243,9 @@ export const controlAdaptersSlice = createSlice({
action: PayloadAction<{
id: string;
model:
| ControlNetModelParam
| T2IAdapterModelParam
| IPAdapterModelParam;
| ParameterControlNetModel
| ParameterT2IAdapterModel
| ParameterIPAdapterModel;
}>
) => {
const { id, model } = action.payload;

View File

@ -1,8 +1,8 @@
import { EntityState } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { isObject } from 'lodash-es';
import { components } from 'services/api/schema';
@ -378,7 +378,7 @@ export type ControlNetConfig = {
type: 'controlnet';
id: string;
isEnabled: boolean;
model: ControlNetModelParam | null;
model: ParameterControlNetModel | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -395,7 +395,7 @@ export type T2IAdapterConfig = {
type: 't2i_adapter';
id: string;
isEnabled: boolean;
model: T2IAdapterModelParam | null;
model: ParameterT2IAdapterModel | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -412,7 +412,7 @@ export type IPAdapterConfig = {
id: string;
isEnabled: boolean;
controlImage: string | null;
model: IPAdapterModelParam | null;
model: ParameterIPAdapterModel | null;
weight: number;
beginStepPct: number;
endStepPct: number;

View File

@ -1,11 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlAdapters } = state;
@ -19,7 +20,8 @@ export const getImageUsage = (state: RootState, image_name: string) => {
return some(
node.data.inputs,
(input) =>
input.type === 'ImageField' && input.value?.image_name === image_name
isImageFieldInputInstance(input) &&
input.value?.image_name === image_name
);
});

View File

@ -11,9 +11,9 @@ import {
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
FieldInputTemplate,
FieldInputInstance,
} from 'features/nodes/types/field';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@ -93,8 +93,8 @@ export type NodeFieldDraggableData = BaseDragData & {
payloadType: 'NODE_FIELD';
payload: {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
field: FieldInputInstance;
fieldTemplate: FieldInputTemplate;
};
};

View File

@ -4,14 +4,14 @@ import {
LoRAMetadataItem,
IPAdapterMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/types';
} from 'features/nodes/types/metadata';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useMemo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
isValidControlNetModel,
isValidLoRAModel,
isValidT2IAdapterModel,
isParameterControlNetModel,
isParameterLoRAModel,
isParameterT2IAdapterModel,
} from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
@ -132,7 +132,7 @@ const ImageMetadataActions = (props: Props) => {
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
isValidControlNetModel(controlnet.control_model)
isParameterControlNetModel(controlnet.control_model)
)
: [];
}, [metadata?.controlnets]);
@ -140,7 +140,7 @@ const ImageMetadataActions = (props: Props) => {
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters
? metadata.ipAdapters.filter((ipAdapter) =>
isValidControlNetModel(ipAdapter.ip_adapter_model)
isParameterControlNetModel(ipAdapter.ip_adapter_model)
)
: [];
}, [metadata?.ipAdapters]);
@ -148,7 +148,7 @@ const ImageMetadataActions = (props: Props) => {
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters
? metadata.t2iAdapters.filter((t2iAdapter) =>
isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model)
isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)
)
: [];
}, [metadata?.t2iAdapters]);
@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
return null;
}
console.log(metadata);
return (
<>
{metadata.created_by && (
@ -275,7 +273,7 @@ const ImageMetadataActions = (props: Props) => {
)}
{metadata.loras &&
metadata.loras.map((lora, index) => {
if (isValidLoRAModel(lora.lora)) {
if (isParameterLoRAModel(lora.lora)) {
return (
<ImageMetadataItem
key={index}

View File

@ -1,8 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
import { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type LoRA = LoRAModelParam & {
export type LoRA = ParameterLoRAModel & {
id: string;
weight: number;
};

View File

@ -24,7 +24,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types';
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
type NodeTemplate = {
@ -57,7 +56,7 @@ const AddNodePopover = () => {
const { t } = useTranslation();
const fieldFilter = useAppSelector(
(state) => state.nodes.currentConnectionFieldType
(state) => state.nodes.connectionStartFieldType
);
const handleFilter = useAppSelector(
(state) => state.nodes.connectionStartParams?.handleType
@ -111,7 +110,7 @@ const AddNodePopover = () => {
data.sort((a, b) => a.label.localeCompare(b.label));
return { data, t };
return { data };
},
defaultSelectorOptions
);
@ -121,7 +120,7 @@ const AddNodePopover = () => {
const inputRef = useRef<HTMLInputElement>(null);
const addNode = useCallback(
(nodeType: AnyInvocationType) => {
(nodeType: string) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const errorMessage = t('nodes.unknownNode', {
@ -145,7 +144,7 @@ const AddNodePopover = () => {
return;
}
addNode(v as AnyInvocationType);
addNode(v);
},
[addNode]
);

View File

@ -2,17 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { memo } from 'react';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { getFieldColor } from '../edges/util/getEdgeColor';
const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
nodes;
const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
const stroke = shouldColorEdges
? getFieldColor(connectionStartFieldType)
: colorTokenToCssVar('base.500');
let className = 'react-flow__custom_connection-path';

View File

@ -0,0 +1,12 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELD_COLORS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/field';
export const getFieldColor = (fieldType: FieldType | null): string => {
if (!fieldType) {
return colorTokenToCssVar('base.500');
}
const color = FIELD_COLORS[fieldType.name];
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
export const makeEdgeSelector = (
source: string,
@ -29,7 +29,7 @@ export const makeEdgeSelector = (
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
? getFieldColor(sourceType)
: colorTokenToCssVar('base.500');
return {

View File

@ -1,7 +1,7 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, Position } from 'reactflow';

View File

@ -2,8 +2,8 @@ import { Flex, Icon, Text, Tooltip } from '@chakra-ui/react';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaInfoCircle } from 'react-icons/fa';
@ -13,7 +13,7 @@ interface Props {
}
const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
const { needsUpdate } = useNodeVersion(nodeId);
const needsUpdate = useNodeNeedsUpdate(nodeId);
return (
<Tooltip

View File

@ -11,7 +11,10 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
import {
NodeExecutionState,
zNodeStatus,
} from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
@ -74,10 +77,10 @@ type TooltipLabelProps = {
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState;
const { t } = useTranslation();
if (status === NodeStatus.PENDING) {
if (status === zNodeStatus.enum.PENDING) {
return <Text>{t('queue.pending')}</Text>;
}
if (status === NodeStatus.IN_PROGRESS) {
if (status === zNodeStatus.enum.IN_PROGRESS) {
if (progressImage) {
return (
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
@ -108,11 +111,11 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
return <Text>{t('nodes.executionStateInProgress')}</Text>;
}
if (status === NodeStatus.COMPLETED) {
if (status === zNodeStatus.enum.COMPLETED) {
return <Text>{t('nodes.executionStateCompleted')}</Text>;
}
if (status === NodeStatus.FAILED) {
if (status === zNodeStatus.enum.FAILED) {
return <Text>{t('nodes.executionStateError')}</Text>;
}
@ -127,7 +130,7 @@ type StatusIconProps = {
const StatusIcon = memo((props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) {
if (status === zNodeStatus.enum.PENDING) {
return (
<Icon
as={FaEllipsisH}
@ -139,7 +142,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
/>
);
}
if (status === NodeStatus.IN_PROGRESS) {
if (status === zNodeStatus.enum.IN_PROGRESS) {
return progress === null ? (
<CircularProgress
isIndeterminate
@ -158,7 +161,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
/>
);
}
if (status === NodeStatus.COMPLETED) {
if (status === zNodeStatus.enum.COMPLETED) {
return (
<Icon
as={FaCheck}
@ -170,7 +173,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
/>
);
}
if (status === NodeStatus.FAILED) {
if (status === zNodeStatus.enum.FAILED) {
return (
<Icon
as={FaExclamation}

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { InvocationNodeData } from 'features/nodes/types/types';
import { InvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import InvocationNode from '../Invocation/InvocationNode';

View File

@ -3,7 +3,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

View File

@ -56,7 +56,7 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
);
const mayExpose = useMemo(
() => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'),
() => input && ['any', 'direct'].includes(input),
[input]
);

View File

@ -1,18 +1,17 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import {
COLLECTION_TYPES,
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import {
InputFieldTemplate,
OutputFieldTemplate,
} from 'features/nodes/types/types';
FieldInputTemplate,
FieldOutputTemplate,
} from 'features/nodes/types/field';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { getFieldColor } from '../../../edges/util/getEdgeColor';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
@ -32,11 +31,11 @@ export const outputHandleStyles: CSSProperties = {
};
type FieldHandleProps = {
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
connectionError: string | null;
connectionError?: string;
};
const FieldHandle = (props: FieldHandleProps) => {
@ -47,23 +46,21 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionStartField,
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color: typeColor, title } = FIELDS[type];
const { name } = fieldTemplate;
const type = fieldTemplate.type;
const fieldTypeName = useFieldTypeName(type);
const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const isModelType = MODEL_TYPES.some((t) => t === type.name);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
type.isCollection || type.isPolymorphic
? colorTokenToCssVar('base.900')
: color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderWidth: type.isCollection || type.isPolymorphic ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
@ -97,18 +94,14 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return title;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? title;
return connectionError;
}
return title;
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
return fieldTypeName;
}, [connectionError, fieldTypeName, isConnectionInProgress]);
return (
<Tooltip

View File

@ -1,15 +1,14 @@
import { Flex, Text } from '@chakra-ui/react';
import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { FIELDS } from 'features/nodes/types/constants';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import {
isInputFieldTemplate,
isInputFieldValue,
} from 'features/nodes/types/types';
isFieldInputInstance,
isFieldInputTemplate,
} from 'features/nodes/types/field';
import { startCase } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
interface Props {
nodeId: string;
fieldName: string;
@ -17,12 +16,13 @@ interface Props {
}
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
const field = useFieldData(nodeId, fieldName);
const field = useFieldInstance(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
const { t } = useTranslation();
const fieldTitle = useMemo(() => {
if (isInputFieldValue(field)) {
if (isFieldInputInstance(field)) {
if (field.label && fieldTemplate?.title) {
return `${field.label} (${fieldTemplate.title})`;
}
@ -49,9 +49,9 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
{fieldTemplate.description}
</Text>
)}
{fieldTemplate && (
{fieldTypeName && (
<Text>
{t('parameters.type')}: {FIELDS[fieldTemplate.type].title}
{t('parameters.type')}: {fieldTypeName}
</Text>
)}
{isInputTemplate && (

View File

@ -77,10 +77,10 @@ const InputField = ({ nodeId, fieldName }: Props) => {
sx={{
display: 'flex',
alignItems: 'center',
h: 'full',
mb: 0,
px: 1,
gap: 2,
h: 'full',
}}
>
<EditableFieldTitle

View File

@ -1,24 +1,60 @@
import { Box, Text } from '@chakra-ui/react';
import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import {
isBoardFieldInputInstance,
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo } from 'react';
import BooleanInputField from './inputs/BooleanInputField';
import ColorInputField from './inputs/ColorInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageInputField from './inputs/ImageInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField';
import NumberInputField from './inputs/NumberInputField';
import RefinerModelInputField from './inputs/RefinerModelInputField';
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
import BoardInputField from './inputs/BoardInputField';
import { useTranslation } from 'react-i18next';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
@ -27,220 +63,227 @@ type InputFieldProps = {
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const { t } = useTranslation();
const field = useFieldData(nodeId, fieldName);
const fieldInstance = useFieldInstance(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
if (fieldTemplate?.fieldKind === 'output') {
return (
<Box p={2}>
{t('nodes.outputFieldInInput')}: {field?.type}
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
</Box>
);
}
if (
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
(field?.type === 'StringPolymorphic' &&
fieldTemplate?.type === 'StringPolymorphic')
isStringFieldInputInstance(fieldInstance) &&
isStringFieldInputTemplate(fieldTemplate)
) {
return (
<StringInputField
<StringFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
(field?.type === 'BooleanPolymorphic' &&
fieldTemplate?.type === 'BooleanPolymorphic')
isBooleanFieldInputInstance(fieldInstance) &&
isBooleanFieldInputTemplate(fieldTemplate)
) {
return (
<BooleanInputField
<BooleanFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(field?.type === 'float' && fieldTemplate?.type === 'float') ||
(field?.type === 'FloatPolymorphic' &&
fieldTemplate?.type === 'FloatPolymorphic') ||
(field?.type === 'IntegerPolymorphic' &&
fieldTemplate?.type === 'IntegerPolymorphic')
(isIntegerFieldInputInstance(fieldInstance) &&
isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) &&
isFloatFieldInputTemplate(fieldTemplate))
) {
return (
<NumberInputField
<NumberFieldInputComponent
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
return (
<EnumInputField
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
(field?.type === 'ImagePolymorphic' &&
fieldTemplate?.type === 'ImagePolymorphic')
isEnumFieldInputInstance(fieldInstance) &&
isEnumFieldInputTemplate(fieldTemplate)
) {
return (
<ImageInputField
<EnumFieldInputComponent
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'BoardField' && fieldTemplate?.type === 'BoardField') {
return (
<BoardInputField
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
isImageFieldInputInstance(fieldInstance) &&
isImageFieldInputTemplate(fieldTemplate)
) {
return (
<MainModelInputField
<ImageFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLRefinerModelField' &&
fieldTemplate?.type === 'SDXLRefinerModelField'
isBoardFieldInputInstance(fieldInstance) &&
isBoardFieldInputTemplate(fieldTemplate)
) {
return (
<RefinerModelInputField
<BoardFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'VaeModelField' &&
fieldTemplate?.type === 'VaeModelField'
isMainModelFieldInputInstance(fieldInstance) &&
isMainModelFieldInputTemplate(fieldTemplate)
) {
return (
<VaeModelInputField
<MainModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LoRAModelField' &&
fieldTemplate?.type === 'LoRAModelField'
isSDXLRefinerModelFieldInputInstance(fieldInstance) &&
isSDXLRefinerModelFieldInputTemplate(fieldTemplate)
) {
return (
<LoRAModelInputField
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlNetModelField' &&
fieldTemplate?.type === 'ControlNetModelField'
isVAEModelFieldInputInstance(fieldInstance) &&
isVAEModelFieldInputTemplate(fieldTemplate)
) {
return (
<ControlNetModelInputField
<VAEModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'IPAdapterModelField' &&
fieldTemplate?.type === 'IPAdapterModelField'
isLoRAModelFieldInputInstance(fieldInstance) &&
isLoRAModelFieldInputTemplate(fieldTemplate)
) {
return (
<IPAdapterModelInputField
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'T2IAdapterModelField' &&
fieldTemplate?.type === 'T2IAdapterModelField'
isControlNetModelFieldInputInstance(fieldInstance) &&
isControlNetModelFieldInputTemplate(fieldTemplate)
) {
return (
<T2IAdapterModelInputField
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
isIPAdapterModelFieldInputInstance(fieldInstance) &&
isIPAdapterModelFieldInputTemplate(fieldTemplate)
) {
return (
<SDXLMainModelInputField
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
if (
isT2IAdapterModelFieldInputInstance(fieldInstance) &&
isT2IAdapterModelFieldInputTemplate(fieldTemplate)
) {
return (
<SchedulerInputField
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
isColorFieldInputInstance(fieldInstance) &&
isColorFieldInputTemplate(fieldTemplate)
) {
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (field && fieldTemplate) {
if (
isSDXLMainModelFieldInputInstance(fieldInstance) &&
isSDXLMainModelFieldInputTemplate(fieldTemplate)
) {
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
isSchedulerFieldInputInstance(fieldInstance) &&
isSchedulerFieldInputTemplate(fieldTemplate)
) {
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (fieldInstance && fieldTemplate) {
// Fallback for when there is no component for the type
return null;
}
@ -255,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
_dark: { color: 'error.300' },
}}
>
{t('nodes.unknownFieldType')}: {field?.type}
{t('nodes.unknownFieldType')}: {fieldInstance?.type.name}
</Text>
</Box>
);

View File

@ -3,15 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
import {
BoardInputFieldTemplate,
BoardInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
BoardFieldInputTemplate,
BoardFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
const BoardInputFieldComponent = (
props: FieldComponentProps<BoardInputFieldValue, BoardInputFieldTemplate>
const BoardFieldInputComponent = (
props: FieldComponentProps<BoardFieldInputInstance, BoardFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
@ -61,4 +61,4 @@ const BoardInputFieldComponent = (
);
};
export default memo(BoardInputFieldComponent);
export default memo(BoardFieldInputComponent);

View File

@ -2,18 +2,16 @@ import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
BooleanPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
BooleanFieldInputInstance,
BooleanFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ChangeEvent, memo, useCallback } from 'react';
const BooleanInputFieldComponent = (
const BooleanFieldInputComponent = (
props: FieldComponentProps<
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
BooleanFieldInputInstance,
BooleanFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -42,4 +40,4 @@ const BooleanInputFieldComponent = (
);
};
export default memo(BooleanInputFieldComponent);
export default memo(BooleanFieldInputComponent);

View File

@ -1,15 +1,15 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice';
import {
ColorInputFieldTemplate,
ColorInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
ColorFieldInputTemplate,
ColorFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback } from 'react';
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
const ColorFieldInputComponent = (
props: FieldComponentProps<ColorFieldInputInstance, ColorFieldInputTemplate>
) => {
const { nodeId, field } = props;
@ -37,4 +37,4 @@ const ColorInputFieldComponent = (
);
};
export default memo(ColorInputFieldComponent);
export default memo(ColorFieldInputComponent);

View File

@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
ControlNetModelFieldInputTemplate,
ControlNetModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
const ControlNetModelInputFieldComponent = (
const ControlNetModelFieldInputComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
ControlNetModelFieldInputInstance,
ControlNetModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -97,4 +97,4 @@ const ControlNetModelInputFieldComponent = (
);
};
export default memo(ControlNetModelInputFieldComponent);
export default memo(ControlNetModelFieldInputComponent);

View File

@ -2,14 +2,14 @@ import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
EnumFieldInputInstance,
EnumFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ChangeEvent, memo, useCallback } from 'react';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
const EnumFieldInputComponent = (
props: FieldComponentProps<EnumFieldInputInstance, EnumFieldInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
@ -45,4 +45,4 @@ const EnumInputFieldComponent = (
);
};
export default memo(EnumInputFieldComponent);
export default memo(EnumFieldInputComponent);

View File

@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
IPAdapterModelInputFieldTemplate,
IPAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
IPAdapterModelFieldInputTemplate,
IPAdapterModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const IPAdapterModelInputFieldComponent = (
const IPAdapterModelFieldInputComponent = (
props: FieldComponentProps<
IPAdapterModelInputFieldValue,
IPAdapterModelInputFieldTemplate
IPAdapterModelFieldInputInstance,
IPAdapterModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -97,4 +97,4 @@ const IPAdapterModelInputFieldComponent = (
);
};
export default memo(IPAdapterModelInputFieldComponent);
export default memo(IPAdapterModelFieldInputComponent);

View File

@ -9,23 +9,18 @@ import {
} from 'features/dnd/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
ImageInputFieldTemplate,
ImageInputFieldValue,
ImagePolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldValue,
} from 'features/nodes/types/types';
ImageFieldInputInstance,
ImageFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
const ImageInputFieldComponent = (
props: FieldComponentProps<
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
>
const ImageFieldInputComponent = (
props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
@ -102,7 +97,7 @@ const ImageInputFieldComponent = (
);
};
export default memo(ImageInputFieldComponent);
export default memo(ImageFieldInputComponent);
const UploadElement = memo(() => {
const { t } = useTranslation();

View File

@ -5,10 +5,10 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
LoRAModelInputFieldTemplate,
LoRAModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
LoRAModelFieldInputTemplate,
LoRAModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
import { forEach } from 'lodash-es';
@ -16,10 +16,10 @@ import { memo, useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
const LoRAModelInputFieldComponent = (
const LoRAModelFieldInputComponent = (
props: FieldComponentProps<
LoRAModelInputFieldValue,
LoRAModelInputFieldTemplate
LoRAModelFieldInputInstance,
LoRAModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -121,4 +121,4 @@ const LoRAModelInputFieldComponent = (
);
};
export default memo(LoRAModelInputFieldComponent);
export default memo(LoRAModelFieldInputComponent);

View File

@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
MainModelInputFieldTemplate,
MainModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
MainModelFieldInputTemplate,
MainModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -21,10 +21,10 @@ import {
} from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
const MainModelInputFieldComponent = (
const MainModelFieldInputComponent = (
props: FieldComponentProps<
MainModelInputFieldValue,
MainModelInputFieldTemplate
MainModelFieldInputInstance,
MainModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -149,4 +149,4 @@ const MainModelInputFieldComponent = (
);
};
export default memo(MainModelInputFieldComponent);
export default memo(MainModelFieldInputComponent);

View File

@ -9,28 +9,18 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
FloatInputFieldTemplate,
FloatInputFieldValue,
FloatPolymorphicInputFieldTemplate,
FloatPolymorphicInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
IntegerPolymorphicInputFieldTemplate,
IntegerPolymorphicInputFieldValue,
} from 'features/nodes/types/types';
FloatFieldInputInstance,
FloatFieldInputTemplate,
IntegerFieldInputInstance,
IntegerFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
const NumberInputFieldComponent = (
const NumberFieldInputComponent = (
props: FieldComponentProps<
| IntegerInputFieldValue
| IntegerPolymorphicInputFieldValue
| FloatInputFieldValue
| FloatPolymorphicInputFieldValue,
| IntegerInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
IntegerFieldInputInstance | FloatFieldInputInstance,
IntegerFieldInputTemplate | FloatFieldInputTemplate
>
) => {
const { nodeId, field, fieldTemplate } = props;
@ -39,7 +29,7 @@ const NumberInputFieldComponent = (
String(field.value)
);
const isIntegerField = useMemo(
() => fieldTemplate.type === 'integer',
() => fieldTemplate.type.name === 'IntegerField',
[fieldTemplate.type]
);
@ -86,4 +76,4 @@ const NumberInputFieldComponent = (
);
};
export default memo(NumberInputFieldComponent);
export default memo(NumberFieldInputComponent);

View File

@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
SDXLRefinerModelInputFieldTemplate,
SDXLRefinerModelInputFieldValue,
} from 'features/nodes/types/types';
SDXLRefinerModelFieldInputTemplate,
SDXLRefinerModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -18,10 +18,10 @@ import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const RefinerModelInputFieldComponent = (
const RefinerModelFieldInputComponent = (
props: FieldComponentProps<
SDXLRefinerModelInputFieldValue,
SDXLRefinerModelInputFieldTemplate
SDXLRefinerModelFieldInputInstance,
SDXLRefinerModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -120,4 +120,4 @@ const RefinerModelInputFieldComponent = (
);
};
export default memo(RefinerModelInputFieldComponent);
export default memo(RefinerModelFieldInputComponent);

View File

@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
SDXLMainModelInputFieldTemplate,
SDXLMainModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
SDXLMainModelFieldInputTemplate,
SDXLMainModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -21,10 +21,10 @@ import {
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
const ModelInputFieldComponent = (
const SDXLMainModelFieldInputComponent = (
props: FieldComponentProps<
SDXLMainModelInputFieldValue,
SDXLMainModelInputFieldTemplate
SDXLMainModelFieldInputInstance,
SDXLMainModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -147,4 +147,4 @@ const ModelInputFieldComponent = (
);
};
export default memo(ModelInputFieldComponent);
export default memo(SDXLMainModelFieldInputComponent);

View File

@ -5,14 +5,12 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
import {
SchedulerInputFieldTemplate,
SchedulerInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
SchedulerFieldInputTemplate,
SchedulerFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
@ -24,7 +22,7 @@ const selector = createSelector(
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
group: enabledSchedulers.includes(name as ParameterScheduler)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
@ -36,10 +34,10 @@ const selector = createSelector(
defaultSelectorOptions
);
const SchedulerInputField = (
const SchedulerFieldInputComponent = (
props: FieldComponentProps<
SchedulerInputFieldValue,
SchedulerInputFieldTemplate
SchedulerFieldInputInstance,
SchedulerFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -55,7 +53,7 @@ const SchedulerInputField = (
fieldSchedulerValueChanged({
nodeId,
fieldName: field.name,
value: value as SchedulerParam,
value: value as ParameterScheduler,
})
);
},
@ -72,4 +70,4 @@ const SchedulerInputField = (
);
};
export default memo(SchedulerInputField);
export default memo(SchedulerFieldInputComponent);

View File

@ -3,19 +3,14 @@ import IAIInput from 'common/components/IAIInput';
import IAITextarea from 'common/components/IAITextarea';
import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice';
import {
StringInputFieldTemplate,
StringInputFieldValue,
FieldComponentProps,
StringPolymorphicInputFieldValue,
StringPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
StringFieldInputInstance,
StringFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ChangeEvent, memo, useCallback } from 'react';
const StringInputFieldComponent = (
props: FieldComponentProps<
StringInputFieldValue | StringPolymorphicInputFieldValue,
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
>
const StringFieldInputComponent = (
props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();
@ -48,4 +43,4 @@ const StringInputFieldComponent = (
return <IAIInput onChange={handleValueChanged} value={field.value} />;
};
export default memo(StringInputFieldComponent);
export default memo(StringFieldInputComponent);

View File

@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
T2IAdapterModelInputFieldTemplate,
T2IAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
T2IAdapterModelFieldInputInstance,
T2IAdapterModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
const T2IAdapterModelInputFieldComponent = (
const T2IAdapterModelFieldInputComponent = (
props: FieldComponentProps<
T2IAdapterModelInputFieldValue,
T2IAdapterModelInputFieldTemplate
T2IAdapterModelFieldInputInstance,
T2IAdapterModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -97,4 +97,4 @@ const T2IAdapterModelInputFieldComponent = (
);
};
export default memo(T2IAdapterModelInputFieldComponent);
export default memo(T2IAdapterModelFieldInputComponent);

View File

@ -4,20 +4,20 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
VAEModelFieldInputTemplate,
VAEModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
const VaeModelInputFieldComponent = (
const VAEModelFieldInputComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
VAEModelFieldInputInstance,
VAEModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -105,4 +105,4 @@ const VaeModelInputFieldComponent = (
);
};
export default memo(VaeModelInputFieldComponent);
export default memo(VAEModelFieldInputComponent);

View File

@ -0,0 +1,13 @@
import {
FieldInputInstance,
FieldInputTemplate,
} from 'features/nodes/types/field';
export type FieldComponentProps<
V extends FieldInputInstance,
T extends FieldInputTemplate,
> = {
nodeId: string;
field: V;
fieldTemplate: T;
};

View File

@ -2,7 +2,7 @@ import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice';
import { NotesNodeData } from 'features/nodes/types/types';
import { NotesNodeData } from 'features/nodes/types/invocation';
import { ChangeEvent, memo, useCallback } from 'react';
import { NodeProps } from 'reactflow';
import NodeWrapper from '../common/NodeWrapper';

View File

@ -14,7 +14,7 @@ import {
DRAG_HANDLE_CLASSNAME,
NODE_WIDTH,
} from 'features/nodes/types/constants';
import { NodeStatus } from 'features/nodes/types/types';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import {
MouseEvent,
@ -40,7 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
createSelector(
stateSelector,
({ nodes }) =>
nodes.nodeExecutionStates[nodeId]?.status === NodeStatus.IN_PROGRESS
nodes.nodeExecutionStates[nodeId]?.status ===
zNodeStatus.enum.IN_PROGRESS
),
[nodeId]
);

View File

@ -8,7 +8,7 @@ import { FaUpload } from 'react-icons/fa';
const LoadWorkflowButton = () => {
const { t } = useTranslation();
const resetRef = useRef<() => void>(null);
const loadWorkflowFromFile = useLoadWorkflowFromFile();
const loadWorkflowFromFile = useLoadWorkflowFromFile(resetRef);
return (
<FileButton
resetRef={resetRef}

View File

@ -1,31 +0,0 @@
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
import { FIELDS } from 'features/nodes/types/constants';
import { map } from 'lodash-es';
import { memo } from 'react';
import 'reactflow/dist/style.css';
const FieldTypeLegend = () => {
return (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge
sx={{
userSelect: 'none',
color:
parseInt(color.split('.')[1] ?? '0', 10) < 500
? 'base.800'
: 'base.50',
bg: color,
}}
textAlign="center"
>
{title}
</Badge>
</Tooltip>
))}
</Flex>
);
};
export default memo(FieldTypeLegend);

View File

@ -1,18 +1,11 @@
import { Flex } from '@chakra-ui/layout';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import FieldTypeLegend from './FieldTypeLegend';
import WorkflowEditorSettings from './WorkflowEditorSettings';
const TopRightPanel = () => {
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
return (
<Flex sx={{ gap: 2, position: 'absolute', top: 2, insetInlineEnd: 2 }}>
<WorkflowEditorSettings />
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
</Flex>
);
};

View File

@ -10,17 +10,15 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate';
import {
InvocationNodeData,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/types';
import { memo } from 'react';
} from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSync } from 'react-icons/fa';
import { Node } from 'reactflow';
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
import ScrollableContent from '../ScrollableContent';
@ -63,12 +61,17 @@ const InspectorDetailsTab = () => {
export default memo(InspectorDetailsTab);
const Content = (props: {
type ContentProps = {
node: Node<InvocationNodeData>;
template: InvocationTemplate;
}) => {
};
const Content = memo(({ node, template }: ContentProps) => {
const { t } = useTranslation();
const { needsUpdate, updateNode } = useNodeVersion(props.node.id);
const needsUpdate = useMemo(
() => getNeedsUpdate(node, template),
[node, template]
);
return (
<Box
sx={{
@ -87,12 +90,12 @@ const Content = (props: {
w: 'full',
}}
>
<EditableNodeTitle nodeId={props.node.data.id} />
<EditableNodeTitle nodeId={node.data.id} />
<HStack>
<FormControl>
<FormLabel>{t('nodes.nodeType')}</FormLabel>
<Text fontSize="sm" fontWeight={600}>
{props.template.title}
{template.title}
</Text>
</FormControl>
<Flex
@ -104,22 +107,16 @@ const Content = (props: {
<FormControl isInvalid={needsUpdate}>
<FormLabel>{t('nodes.nodeVersion')}</FormLabel>
<Text fontSize="sm" fontWeight={600}>
{props.node.data.version}
{node.data.version}
</Text>
</FormControl>
{needsUpdate && (
<IAIIconButton
aria-label={t('nodes.updateNode')}
tooltip={t('nodes.updateNode')}
icon={<FaSync />}
onClick={updateNode}
/>
)}
</Flex>
</HStack>
<NotesTextarea nodeId={props.node.data.id} />
<NotesTextarea nodeId={node.data.id} />
</Flex>
</ScrollableContent>
</Box>
);
};
});
Content.displayName = 'Content';

View File

@ -5,7 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { ImageOutput } from 'services/api/types';
import { AnyResult } from 'services/events/types';

View File

@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { isInvocationNode } from '../types/invocation';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const selector = useMemo(
@ -28,8 +25,8 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(['any', 'direct'].includes(field.input) ||
POLYMORPHIC_TYPES.includes(field.type)) &&
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
field.type.isPolymorphic) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
},

View File

@ -3,10 +3,13 @@ import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { Node, useReactFlow } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { buildNodeData } from '../store/util/buildNodeData';
import {
buildCurrentImageNode,
buildInvocationNode,
buildNotesNode,
} from '../store/util/buildNodeData';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants';
import { AnyNodeData, InvocationTemplate } from '../types/invocation';
const templatesSelector = createSelector(
[(state: RootState) => state.nodes],
(nodes) => nodes.nodeTemplates
@ -22,7 +25,8 @@ export const useBuildNodeData = () => {
const flow = useReactFlow();
return useCallback(
(type: AnyInvocationType | 'current_image' | 'notes') => {
// string here is "any invocation type"
(type: string | 'current_image' | 'notes'): Node<AnyNodeData> => {
let _x = window.innerWidth / 2;
let _y = window.innerHeight / 2;
@ -41,9 +45,19 @@ export const useBuildNodeData = () => {
y: _y,
});
const template = nodeTemplates[type];
if (type === 'current_image') {
return buildCurrentImageNode(position);
}
return buildNodeData(type, position, template);
if (type === 'notes') {
return buildNotesNode(position);
}
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = nodeTemplates[type] as InvocationTemplate;
return buildInvocationNode(position, template);
},
[nodeTemplates, flow]
);

View File

@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
export const useConnectionInputFieldNames = (nodeId: string) => {
const selector = useMemo(
@ -29,9 +26,8 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
// get the visible fields
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(field.input === 'connection' &&
!POLYMORPHIC_TYPES.includes(field.type)) ||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
(field.input === 'connection' && !field.type.isPolymorphic) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);

View File

@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
stateSelector,
({ nodes }) =>
nodes.currentConnectionFieldType !== null &&
nodes.connectionStartFieldType !== null &&
nodes.connectionStartParams !== null
);

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { compareVersions } from 'compare-versions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useDoNodeVersionsMatch = (nodeId: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useEmbedWorkflow = (nodeId: string) => {
const selector = useMemo(

View File

@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldData = (nodeId: string, fieldName: string) => {
export const useFieldInstance = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldLabel = (nodeId: string, fieldName: string) => {
const selector = useMemo(

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { KIND_MAP } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldTemplate = (
nodeId: string,

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
import { KIND_MAP } from '../types/constants';
export const useFieldTemplateTitle = (

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { KIND_MAP } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldType = (
nodeId: string,
@ -20,7 +20,8 @@ export const useFieldType = (
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
const field = node.data[KIND_MAP[kind]][fieldName];
return field?.type;
},
defaultSelectorOptions
),

View File

@ -2,7 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { getNeedsUpdate } from './useNodeVersion';
import { getNeedsUpdate } from '../store/util/nodeUpdate';
import { isInvocationNode } from '../types/invocation';
const selector = createSelector(
stateSelector,
@ -10,8 +11,11 @@ const selector = createSelector(
const nodes = state.nodes.nodes;
const templates = state.nodes.nodeTemplates;
const needsUpdate = nodes.some((node) => {
const needsUpdate = nodes.filter(isInvocationNode).some((node) => {
const template = templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node, template);
});
return needsUpdate;

View File

@ -4,8 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { useMemo } from 'react';
import { IMAGE_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useHasImageOutput = (nodeId: string) => {
const selector = useMemo(
@ -20,8 +19,8 @@ export const useHasImageOutput = (nodeId: string) => {
return some(
node.data.outputs,
(output) =>
IMAGE_FIELDS.includes(output.type) &&
// the image primitive node does not actually save the image, do not show the image-saving checkboxes
output.type.name === 'ImageField' &&
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
node.data.type !== 'image'
);
},

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useIsIntermediate = (nodeId: string) => {
const selector = useMemo(

View File

@ -4,7 +4,7 @@ import { useCallback } from 'react';
import { Connection, Node, useReactFlow } from 'reactflow';
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
import { InvocationNodeData } from '../types/types';
import { InvocationNodeData } from '../types/invocation';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
return false;
}
const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
const targetType = targetNode.data.inputs[targetHandle]?.type;
const sourceField = sourceNode.data.outputs[sourceHandle];
const targetField = targetNode.data.inputs[targetHandle];
if (!sourceType || !targetType) {
if (!sourceField || !targetField) {
// something has gone terribly awry
return false;
}
@ -70,12 +70,13 @@ export const useIsValidConnection = () => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType !== 'CollectionItem'
targetField.type.name !== 'CollectionItemField'
) {
return false;
}
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
// Must use the originalType here if it exists
if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
return false;
}

View File

@ -1,17 +1,15 @@
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { zWorkflow } from 'features/nodes/types/types';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { ZodError } from 'zod';
import { fromZodError, fromZodIssue } from 'zod-validation-error';
import { workflowLoadRequested } from '../store/actions';
import { RefObject, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ZodError } from 'zod';
import { fromZodIssue } from 'zod-validation-error';
import { workflowLoadRequested } from '../store/actions';
export const useLoadWorkflowFromFile = () => {
export const useLoadWorkflowFromFile = (resetRef: RefObject<() => void>) => {
const dispatch = useAppDispatch();
const logger = useLogger('nodes');
const { t } = useTranslation();
@ -26,33 +24,10 @@ export const useLoadWorkflowFromFile = () => {
try {
const parsedJSON = JSON.parse(String(rawJSON));
const result = zWorkflow.safeParse(parsedJSON);
if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: t('nodes.workflowValidation'),
});
logger.error({ error: parseify(result.error) }, message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
duration: 5000,
})
)
);
reader.abort();
return;
}
dispatch(workflowLoadRequested(result.data));
reader.abort();
} catch {
// file reader error
dispatch(workflowLoadRequested(parsedJSON));
} catch (e) {
// There was a problem reading the file
logger.error(t('nodes.unableToLoadWorkflow'));
dispatch(
addToast(
makeToast({
@ -61,12 +36,15 @@ export const useLoadWorkflowFromFile = () => {
})
)
);
reader.abort();
}
};
reader.readAsText(file);
// Reset the file picker internal state so that the same file can be loaded again
resetRef.current?.();
},
[dispatch, logger, t]
[dispatch, logger, resetRef, t]
);
return loadWorkflowFromFile;

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useNodeLabel = (nodeId: string) => {
const selector = useMemo(

View File

@ -0,0 +1,35 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/invocation';
import { getNeedsUpdate } from '../store/util/nodeUpdate';
export const useNodeNeedsUpdate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const template = nodes.nodeTemplates[node?.data.type ?? ''];
return { node, template };
},
defaultSelectorOptions
),
[nodeId]
);
const { node, template } = useAppSelector(selector);
const needsUpdate = useMemo(
() =>
isInvocationNode(node) && template
? getNeedsUpdate(node, template)
: false,
[node, template]
);
return needsUpdate;
};

View File

@ -3,16 +3,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { AnyInvocationType } from 'services/events/types';
import { InvocationTemplate } from '../types/invocation';
export const useNodeTemplateByType = (
type: AnyInvocationType | 'current_image' | 'notes'
) => {
export const useNodeTemplateByType = (type: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
({ nodes }): InvocationTemplate | undefined => {
const nodeTemplate = nodes.nodeTemplates[type];
return nodeTemplate;
},

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo(

View File

@ -1,119 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { satisfies } from 'compare-versions';
import { cloneDeep, defaultsDeep } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { Node } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { nodeReplaced } from '../store/nodesSlice';
import { buildNodeData } from '../store/util/buildNodeData';
import {
InvocationNodeData,
InvocationTemplate,
NodeData,
isInvocationNode,
zParsedSemver,
} from '../types/types';
import { useAppToaster } from 'app/components/Toaster';
import { useTranslation } from 'react-i18next';
export const getNeedsUpdate = (
node?: Node<NodeData>,
template?: InvocationTemplate
) => {
if (!isInvocationNode(node) || !template) {
return false;
}
return node.data.version !== template.version;
};
export const getMayUpdateNode = (
node?: Node<NodeData>,
template?: InvocationTemplate
) => {
const needsUpdate = getNeedsUpdate(node, template);
if (
!needsUpdate ||
!isInvocationNode(node) ||
!template ||
!node.data.version
) {
return false;
}
const templateMajor = zParsedSemver.parse(template.version).major;
return satisfies(node.data.version, `^${templateMajor}`);
};
export const updateNode = (
node?: Node<NodeData>,
template?: InvocationTemplate
) => {
const mayUpdate = getMayUpdateNode(node, template);
if (
!mayUpdate ||
!isInvocationNode(node) ||
!template ||
!node.data.version
) {
return;
}
const defaults = buildNodeData(
node.data.type as AnyInvocationType,
node.position,
template
) as Node<InvocationNodeData>;
const clone = cloneDeep(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults);
return clone;
};
export const useNodeVersion = (nodeId: string) => {
const dispatch = useAppDispatch();
const toast = useAppToaster();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return { node, nodeTemplate };
},
defaultSelectorOptions
),
[nodeId]
);
const { node, nodeTemplate } = useAppSelector(selector);
const needsUpdate = useMemo(
() => getNeedsUpdate(node, nodeTemplate),
[node, nodeTemplate]
);
const mayUpdate = useMemo(
() => getMayUpdateNode(node, nodeTemplate),
[node, nodeTemplate]
);
const _updateNode = useCallback(() => {
const needsUpdate = getNeedsUpdate(node, nodeTemplate);
const updatedNode = updateNode(node, nodeTemplate);
if (!updatedNode) {
if (needsUpdate) {
toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) });
}
return;
}
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
}, [dispatch, node, nodeTemplate, t, toast]);
return { needsUpdate, mayUpdate, updateNode: _updateNode };
};

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
export const useOutputFieldNames = (nodeId: string) => {

View File

@ -0,0 +1,23 @@
import { useTranslation } from 'react-i18next';
import { FieldType } from '../types/field';
import { useMemo } from 'react';
export const useFieldTypeName = (fieldType?: FieldType): string => {
const { t } = useTranslation();
const name = useMemo(() => {
if (!fieldType) {
return '';
}
const { name } = fieldType;
if (fieldType.isCollection) {
return t('nodes.collectionFieldType', { name });
}
if (fieldType.isPolymorphic) {
return t('nodes.polymorphicFieldType', { name });
}
return name;
}, [fieldType, t]);
return name;
};

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useUseCache = (nodeId: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useWithWorkflow = (nodeId: string) => {
const selector = useMemo(

View File

@ -1,6 +1,5 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api/types';
import { Workflow } from '../types/types';
export const textToImageGraphBuilt = createAction<Graph>(
'nodes/textToImageGraphBuilt'
@ -18,7 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
nodesGraphBuilt
);
export const workflowLoadRequested = createAction<Workflow>(
export const workflowLoadRequested = createAction<unknown>(
'nodes/workflowLoadRequested'
);

View File

@ -6,7 +6,7 @@ import { NodesState } from './types';
export const nodesPersistDenylist: (keyof NodesState)[] = [
'nodeTemplates',
'connectionStartParams',
'currentConnectionFieldType',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'isReady',

View File

@ -20,7 +20,6 @@ import {
XYPosition,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { ImageField } from 'services/api/types';
import {
appSocketGeneratorProgress,
appSocketInvocationComplete,
@ -31,60 +30,58 @@ import {
import { v4 as uuidv4 } from 'uuid';
import { DRAG_HANDLE_CLASSNAME } from '../types/constants';
import {
BoardInputFieldValue,
BooleanInputFieldValue,
ColorInputFieldValue,
ControlNetModelInputFieldValue,
CurrentImageNodeData,
EnumInputFieldValue,
BoardFieldValue,
BooleanFieldValue,
ColorFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
FieldIdentifier,
FloatInputFieldValue,
ImageInputFieldValue,
InputFieldValue,
IntegerInputFieldValue,
InvocationNodeData,
FieldValue,
FloatFieldValue,
ImageFieldValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
StringFieldValue,
T2IAdapterModelFieldValue,
VAEModelFieldValue,
} from '../types/field';
import {
AnyNodeData,
InvocationTemplate,
IPAdapterModelInputFieldValue,
isInvocationNode,
isNotesNode,
LoRAModelInputFieldValue,
MainModelInputFieldValue,
NodeExecutionState,
NodeStatus,
NotesNodeData,
SchedulerInputFieldValue,
SDXLRefinerModelInputFieldValue,
StringInputFieldValue,
T2IAdapterModelInputFieldValue,
VaeModelInputFieldValue,
Workflow,
} from '../types/types';
zNodeStatus,
} from '../types/invocation';
import { WorkflowV2 } from '../types/workflow';
import { NodesState } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: NodeStatus.PENDING,
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
export const initialWorkflow = {
meta: {
version: WORKFLOW_FORMAT_VERSION,
},
const INITIAL_WORKFLOW: WorkflowV2 = {
name: '',
author: '',
description: '',
notes: '',
tags: '',
contact: '',
version: '',
contact: '',
tags: '',
notes: '',
nodes: [],
edges: [],
exposedFields: [],
meta: { version: '2.0.0' },
};
export const initialNodesState: NodesState = {
@ -93,11 +90,10 @@ export const initialNodesState: NodesState = {
nodeTemplates: {},
isReady: false,
connectionStartParams: null,
currentConnectionFieldType: null,
connectionStartFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true,
shouldValidateGraph: true,
shouldAnimateEdges: true,
@ -107,7 +103,7 @@ export const initialNodesState: NodesState = {
nodeOpacity: 1,
selectedNodes: [],
selectedEdges: [],
workflow: initialWorkflow,
workflow: INITIAL_WORKFLOW,
nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 },
mouseOverField: null,
@ -117,13 +113,13 @@ export const initialNodesState: NodesState = {
selectionMode: SelectionMode.Partial,
};
type FieldValueAction<T extends InputFieldValue> = PayloadAction<{
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T['value'];
value: T;
}>;
const fieldValueReducer = <T extends InputFieldValue>(
const fieldValueReducer = <T extends FieldValue>(
state: NodesState,
action: FieldValueAction<T>
) => {
@ -161,12 +157,7 @@ const nodesSlice = createSlice({
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (
state,
action: PayloadAction<
Node<InvocationNodeData | CurrentImageNodeData | NotesNodeData>
>
) => {
nodeAdded: (state, action: PayloadAction<Node<AnyNodeData>>) => {
const node = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
@ -203,7 +194,7 @@ const nodesSlice = createSlice({
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
state.connectionStartFieldType
) {
const newConnection = findConnectionToValidHandle(
node,
@ -212,7 +203,7 @@ const nodesSlice = createSlice({
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge(
@ -224,7 +215,7 @@ const nodesSlice = createSlice({
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
@ -258,10 +249,10 @@ const nodesSlice = createSlice({
handleType === 'source'
? node.data.outputs[handleId]
: node.data.inputs[handleId];
state.currentConnectionFieldType = field?.type ?? null;
state.connectionStartFieldType = field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.currentConnectionFieldType;
const fieldType = state.connectionStartFieldType;
if (!fieldType) {
return;
}
@ -286,7 +277,7 @@ const nodesSlice = createSlice({
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
state.connectionStartFieldType
) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
@ -295,7 +286,7 @@ const nodesSlice = createSlice({
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge(
@ -306,14 +297,14 @@ const nodesSlice = createSlice({
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
} else {
state.addNewNodePosition = action.payload.cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
}
state.modifyingEdge = false;
},
@ -529,12 +520,7 @@ const nodesSlice = createSlice({
state.edges = applyEdgeChanges(edgeChanges, state.edges);
}
},
nodesDeleted: (
state,
action: PayloadAction<
Node<InvocationNodeData | NotesNodeData | CurrentImageNodeData>[]
>
) => {
nodesDeleted: (state, action: PayloadAction<Node<AnyNodeData>[]>) => {
action.payload.forEach((node) => {
state.workflow.exposedFields = state.workflow.exposedFields.filter(
(f) => f.nodeId !== node.id
@ -588,132 +574,94 @@ const nodesSlice = createSlice({
},
fieldStringValueChanged: (
state,
action: FieldValueAction<StringInputFieldValue>
action: FieldValueAction<StringFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldNumberValueChanged: (
state,
action: FieldValueAction<IntegerInputFieldValue | FloatInputFieldValue>
action: FieldValueAction<IntegerFieldValue | FloatFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldBooleanValueChanged: (
state,
action: FieldValueAction<BooleanInputFieldValue>
action: FieldValueAction<BooleanFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldBoardValueChanged: (
state,
action: FieldValueAction<BoardInputFieldValue>
action: FieldValueAction<BoardFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldImageValueChanged: (
state,
action: FieldValueAction<ImageInputFieldValue>
action: FieldValueAction<ImageFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldColorValueChanged: (
state,
action: FieldValueAction<ColorInputFieldValue>
action: FieldValueAction<ColorFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldMainModelValueChanged: (
state,
action: FieldValueAction<MainModelInputFieldValue>
action: FieldValueAction<MainModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldRefinerModelValueChanged: (
state,
action: FieldValueAction<SDXLRefinerModelInputFieldValue>
action: FieldValueAction<SDXLRefinerModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldVaeModelValueChanged: (
state,
action: FieldValueAction<VaeModelInputFieldValue>
action: FieldValueAction<VAEModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldLoRAModelValueChanged: (
state,
action: FieldValueAction<LoRAModelInputFieldValue>
action: FieldValueAction<LoRAModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldControlNetModelValueChanged: (
state,
action: FieldValueAction<ControlNetModelInputFieldValue>
action: FieldValueAction<ControlNetModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldIPAdapterModelValueChanged: (
state,
action: FieldValueAction<IPAdapterModelInputFieldValue>
action: FieldValueAction<IPAdapterModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldT2IAdapterModelValueChanged: (
state,
action: FieldValueAction<T2IAdapterModelInputFieldValue>
action: FieldValueAction<T2IAdapterModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: (
state,
action: FieldValueAction<EnumInputFieldValue>
action: FieldValueAction<EnumFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldSchedulerValueChanged: (
state,
action: FieldValueAction<SchedulerInputFieldValue>
action: FieldValueAction<SchedulerFieldValue>
) => {
fieldValueReducer(state, action);
},
imageCollectionFieldValueChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
value: ImageField[];
}>
) => {
const { nodeId, fieldName, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
if (nodeIndex === -1) {
return;
}
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
const input = node.data?.inputs[fieldName];
if (!input) {
return;
}
const currentValue = cloneDeep(input.value);
if (!currentValue) {
input.value = value;
return;
}
input.value = uniqBy(
(currentValue as ImageField[]).concat(value),
'image_name'
);
},
notesNodeValueChanged: (
state,
action: PayloadAction<{ nodeId: string; value: string }>
@ -726,12 +674,6 @@ const nodesSlice = createSlice({
}
node.data.notes = value;
},
shouldShowFieldTypeLegendChanged: (
state,
action: PayloadAction<boolean>
) => {
state.shouldShowFieldTypeLegend = action.payload;
},
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
@ -745,7 +687,7 @@ const nodesSlice = createSlice({
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
state.workflow = cloneDeep(initialWorkflow);
state.workflow = cloneDeep(INITIAL_WORKFLOW);
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
@ -783,7 +725,7 @@ const nodesSlice = createSlice({
workflowContactChanged: (state, action: PayloadAction<string>) => {
state.workflow.contact = action.payload;
},
workflowLoaded: (state, action: PayloadAction<Workflow>) => {
workflowLoaded: (state, action: PayloadAction<WorkflowV2>) => {
const { nodes, edges, ...workflow } = action.payload;
state.workflow = workflow;
@ -810,7 +752,7 @@ const nodesSlice = createSlice({
}, {});
},
workflowReset: (state) => {
state.workflow = cloneDeep(initialWorkflow);
state.workflow = cloneDeep(INITIAL_WORKFLOW);
},
viewportChanged: (state, action: PayloadAction<Viewport>) => {
state.viewport = action.payload;
@ -942,7 +884,7 @@ const nodesSlice = createSlice({
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
},
addNodePopoverToggled: (state) => {
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
@ -961,14 +903,14 @@ const nodesSlice = createSlice({
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = NodeStatus.IN_PROGRESS;
node.status = zNodeStatus.enum.IN_PROGRESS;
}
});
builder.addCase(appSocketInvocationComplete, (state, action) => {
const { source_node_id, result } = action.payload.data;
const nes = state.nodeExecutionStates[source_node_id];
if (nes) {
nes.status = NodeStatus.COMPLETED;
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
@ -979,7 +921,7 @@ const nodesSlice = createSlice({
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = NodeStatus.FAILED;
node.status = zNodeStatus.enum.FAILED;
node.error = action.payload.data.error;
node.progress = null;
node.progressImage = null;
@ -990,7 +932,7 @@ const nodesSlice = createSlice({
action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = NodeStatus.IN_PROGRESS;
node.status = zNodeStatus.enum.IN_PROGRESS;
node.progress = (step + 1) / total_steps;
node.progressImage = progress_image ?? null;
}
@ -998,7 +940,7 @@ const nodesSlice = createSlice({
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = NodeStatus.PENDING;
nes.status = zNodeStatus.enum.PENDING;
nes.error = null;
nes.progress = null;
nes.progressImage = null;
@ -1037,7 +979,6 @@ export const {
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
imageCollectionFieldValueChanged,
mouseOverFieldChanged,
mouseOverNodeChanged,
nodeAdded,
@ -1063,7 +1004,6 @@ export const {
selectionPasted,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,

View File

@ -6,25 +6,23 @@ import {
Viewport,
XYPosition,
} from 'reactflow';
import { FieldIdentifier, FieldType } from '../types/field';
import {
FieldIdentifier,
FieldType,
AnyNodeData,
InvocationEdgeExtra,
InvocationTemplate,
NodeData,
NodeExecutionState,
Workflow,
} from '../types/types';
} from '../types/invocation';
import { WorkflowV2 } from '../types/workflow';
export type NodesState = {
nodes: Node<NodeData>[];
nodes: Node<AnyNodeData>[];
edges: Edge<InvocationEdgeExtra>[];
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
currentConnectionFieldType: FieldType | null;
connectionStartFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
@ -33,13 +31,13 @@ export type NodesState = {
shouldColorEdges: boolean;
selectedNodes: string[];
selectedEdges: string[];
workflow: Omit<Workflow, 'nodes' | 'edges'>;
workflow: Omit<WorkflowV2, 'nodes' | 'edges'>;
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
isReady: boolean;
mouseOverField: FieldIdentifier | null;
mouseOverNode: string | null;
nodesToCopy: Node<NodeData>[];
nodesToCopy: Node<AnyNodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;

View File

@ -1,50 +1,25 @@
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
FieldInputInstance,
FieldOutputInstance,
} from 'features/nodes/types/field';
import {
CurrentImageNodeData,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
NotesNodeData,
OutputFieldValue,
} from 'features/nodes/types/types';
import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders';
} from 'features/nodes/types/invocation';
import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance';
import { reduce } from 'lodash-es';
import { Node, XYPosition } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
};
export const buildNodeData = (
type: AnyInvocationType | 'current_image' | 'notes',
position: XYPosition,
template?: InvocationTemplate
):
| Node<CurrentImageNodeData>
| Node<NotesNodeData>
| Node<InvocationNodeData>
| undefined => {
export const buildNotesNode = (position: XYPosition): Node<NotesNodeData> => {
const nodeId = uuidv4();
if (type === 'current_image') {
const node: Node<CurrentImageNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'current_image',
position,
data: {
id: nodeId,
type: 'current_image',
isOpen: true,
label: 'Current Image',
},
};
return node;
}
if (type === 'notes') {
const node: Node<NotesNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
@ -58,21 +33,41 @@ export const buildNodeData = (
type: 'notes',
},
};
return node;
}
};
if (template === undefined) {
console.error(`Unable to find template ${type}.`);
return;
}
export const buildCurrentImageNode = (
position: XYPosition
): Node<CurrentImageNodeData> => {
const nodeId = uuidv4();
const node: Node<CurrentImageNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'current_image',
position,
data: {
id: nodeId,
type: 'current_image',
isOpen: true,
label: 'Current Image',
},
};
return node;
};
export const buildInvocationNode = (
position: XYPosition,
template: InvocationTemplate
): Node<InvocationNodeData> => {
const nodeId = uuidv4();
const { type } = template;
const inputs = reduce(
template.inputs,
(inputsAccumulator, inputTemplate, inputName) => {
const fieldId = uuidv4();
const inputFieldValue: InputFieldValue = buildInputFieldValue(
const inputFieldValue: FieldInputInstance = buildFieldInputInstance(
fieldId,
inputTemplate
);
@ -81,7 +76,7 @@ export const buildNodeData = (
return inputsAccumulator;
},
{} as Record<string, InputFieldValue>
{} as Record<string, FieldInputInstance>
);
const outputs = reduce(
@ -89,7 +84,7 @@ export const buildNodeData = (
(outputsAccumulator, outputTemplate, outputName) => {
const fieldId = uuidv4();
const outputFieldValue: OutputFieldValue = {
const outputFieldValue: FieldOutputInstance = {
id: fieldId,
name: outputName,
type: outputTemplate.type,
@ -100,10 +95,10 @@ export const buildNodeData = (
return outputsAccumulator;
},
{} as Record<string, OutputFieldValue>
{} as Record<string, FieldOutputInstance>
);
const invocation: Node<InvocationNodeData> = {
const node: Node<InvocationNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'invocation',
@ -117,11 +112,11 @@ export const buildNodeData = (
isOpen: true,
embedWorkflow: false,
isIntermediate: type === 'save_image' ? false : true,
useCache: template.useCache,
inputs,
outputs,
useCache: template.useCache,
},
};
return invocation;
return node;
};

Some files were not shown because too many files have changed in this diff Show More