mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add support for custom field types
Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported. Two notes: 1. Your field type's class name must be unique. Suggest prefixing fields with something related to the node pack as a kind of namespace. 2. Custom field types function as connection-only fields. For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type. This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection. feat(ui): fix tooltips for custom types We need to hold onto the original type of the field so they don't all just show up as "Unknown". fix(ui): fix ts error with custom fields feat(ui): custom field types connection validation In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent. fix(ui): typo feat(ui): add CustomCollection and CustomPolymorphic field types feat(ui): add validation for CustomCollection & CustomPolymorphic types - Update connection validation for custom types - Use simple string parsing to determine if a field is a collection or polymorphic type. - No longer need to keep a list of collection and polymorphic types. - Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing chore(ui): remove errant console.log fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType' This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type. fix(ui): fix ts error feat(nodes): add runtime check for custom field names "Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names. chore(ui): add TODO for revising field type names wip refactor fieldtype structured wip refactor field types wip refactor types wip refactor types fix node layout refactor field types chore: mypy organisation organisation organisation fix(nodes): fix field orig_required, field_kind and input statuses feat(nodes): remove broken implementation of default_factory on InputField Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args. Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used. Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`. fix(nodes): fix InputField name validation workflow validation validation chore: ruff feat(nodes): fix up baseinvocation comments fix(ui): improve typing & logic of buildFieldInputTemplate improved error handling in parseFieldType fix: back compat for deprecated default_factory and UIType feat(nodes): do not show node packs loaded log if none loaded chore(ui): typegen
This commit is contained in:
parent
0d52430481
commit
86a74e929a
@ -1,11 +1,8 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
@ -22,6 +19,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
import socket
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
@ -29,7 +27,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
@ -58,9 +56,9 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputFieldJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
UIConfigBase,
|
||||
_InputField,
|
||||
_OutputField,
|
||||
)
|
||||
|
||||
if is_mps_available():
|
||||
@ -157,7 +155,11 @@ def custom_openapi() -> dict[str, Any]:
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = models_json_schema(
|
||||
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
|
||||
@ -165,7 +167,7 @@ def custom_openapi() -> dict[str, Any]:
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
|
@ -17,11 +17,15 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class InvalidVersionError(ValueError):
|
||||
pass
|
||||
@ -31,7 +35,7 @@ class InvalidFieldError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class Input(str, Enum):
|
||||
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
@ -45,86 +49,120 @@ class Input(str, Enum):
|
||||
Any = "any"
|
||||
|
||||
|
||||
class UIType(str, Enum):
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI.
|
||||
If a field should be provided a data type that does not exactly match the python type of the field, \
|
||||
use this to provide the type that should be used instead. See the node development docs for detail \
|
||||
on adding a new field type, which involves client-side changes.
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
# region Primitives
|
||||
Boolean = "boolean"
|
||||
Color = "ColorField"
|
||||
Conditioning = "ConditioningField"
|
||||
Control = "ControlField"
|
||||
Float = "float"
|
||||
Image = "ImageField"
|
||||
Integer = "integer"
|
||||
Latents = "LatentsField"
|
||||
String = "string"
|
||||
# endregion
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
# region Collection Primitives
|
||||
BooleanCollection = "BooleanCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ControlCollection = "ControlCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
ImageCollection = "ImageCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
StringCollection = "StringCollection"
|
||||
# endregion
|
||||
|
||||
# region Polymorphic Primitives
|
||||
BooleanPolymorphic = "BooleanPolymorphic"
|
||||
ColorPolymorphic = "ColorPolymorphic"
|
||||
ConditioningPolymorphic = "ConditioningPolymorphic"
|
||||
ControlPolymorphic = "ControlPolymorphic"
|
||||
FloatPolymorphic = "FloatPolymorphic"
|
||||
ImagePolymorphic = "ImagePolymorphic"
|
||||
IntegerPolymorphic = "IntegerPolymorphic"
|
||||
LatentsPolymorphic = "LatentsPolymorphic"
|
||||
StringPolymorphic = "StringPolymorphic"
|
||||
# endregion
|
||||
class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||
|
||||
# region Models
|
||||
MainModel = "MainModelField"
|
||||
- Model Fields
|
||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||
the field is an SDXL main model field.
|
||||
|
||||
- Any Field
|
||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||
|
||||
- Scheduler Field
|
||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||
|
||||
- Internal Fields
|
||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||
should not be used by node authors.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VaeModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
UNet = "UNetField"
|
||||
Vae = "VaeField"
|
||||
CLIP = "ClipField"
|
||||
# endregion
|
||||
|
||||
# region Iterate/Collect
|
||||
Collection = "Collection"
|
||||
CollectionItem = "CollectionItem"
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
# endregion
|
||||
|
||||
# region Misc
|
||||
Enum = "enum"
|
||||
Scheduler = "Scheduler"
|
||||
WorkflowField = "WorkflowField"
|
||||
IsIntermediate = "IsIntermediate"
|
||||
BoardField = "BoardField"
|
||||
Any = "Any"
|
||||
MetadataItem = "MetadataItem"
|
||||
MetadataItemCollection = "MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "MetadataItemPolymorphic"
|
||||
MetadataDict = "MetadataDict"
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
Boolean = "DEPRECATED_Boolean"
|
||||
Color = "DEPRECATED_Color"
|
||||
Conditioning = "DEPRECATED_Conditioning"
|
||||
Control = "DEPRECATED_Control"
|
||||
Float = "DEPRECATED_Float"
|
||||
Image = "DEPRECATED_Image"
|
||||
Integer = "DEPRECATED_Integer"
|
||||
Latents = "DEPRECATED_Latents"
|
||||
String = "DEPRECATED_String"
|
||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||
ColorCollection = "DEPRECATED_ColorCollection"
|
||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||
ControlCollection = "DEPRECATED_ControlCollection"
|
||||
FloatCollection = "DEPRECATED_FloatCollection"
|
||||
ImageCollection = "DEPRECATED_ImageCollection"
|
||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||
StringCollection = "DEPRECATED_StringCollection"
|
||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
Collection = "DEPRECATED_Collection"
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum):
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are \
|
||||
The type of UI component to use for a field, used to override the default components, which are
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
@ -133,7 +171,7 @@ class UIComponent(str, Enum):
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class _InputField(BaseModel):
|
||||
class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
*DO NOT USE*
|
||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||
@ -142,12 +180,15 @@ class _InputField(BaseModel):
|
||||
"""
|
||||
|
||||
input: Input
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_component: Optional[UIComponent]
|
||||
ui_order: Optional[int]
|
||||
ui_choice_labels: Optional[dict[str, str]]
|
||||
item_default: Optional[Any]
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
@ -155,7 +196,7 @@ class _InputField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
*DO NOT USE*
|
||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||
@ -163,6 +204,7 @@ class _OutputField(BaseModel):
|
||||
purpose in the backend.
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
@ -180,6 +222,7 @@ def get_type(klass: BaseModel) -> str:
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
@ -203,12 +246,11 @@ def InputField(
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
item_default: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
@ -230,26 +272,57 @@ def InputField(
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||
Ignored for non-collection fields.
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_: dict[str, Any] = {
|
||||
"input": input,
|
||||
"ui_type": ui_type,
|
||||
"ui_component": ui_component,
|
||||
"ui_hidden": ui_hidden,
|
||||
"ui_order": ui_order,
|
||||
"item_default": item_default,
|
||||
"ui_choice_labels": ui_choice_labels,
|
||||
"_field_kind": "input",
|
||||
}
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
|
||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||
any number of fields may be optional, because they may be provided by connections.
|
||||
|
||||
On calling of `invoke()`, however, those fields may be required.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
connection from an ancestor node, which outputs an image.
|
||||
|
||||
This means we want to type the `image` field as optional for the node class definition, but required
|
||||
for the `invoke()` function.
|
||||
|
||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||
any static type analysis tools will complain.
|
||||
|
||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||
"""
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
del default_factory
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
"default": default,
|
||||
"default_factory": default_factory,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
@ -266,70 +339,34 @@ def InputField(
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
"""
|
||||
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
||||
This typing is often more specific than the actual invocation definition requires, because
|
||||
fields may have values provided only by connections.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
an ancestor node that outputs the image.
|
||||
|
||||
So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
|
||||
we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
|
||||
value or not. This is very tedious.
|
||||
|
||||
Ideally, the invocation definition would be able to specify that the field is required during
|
||||
invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
|
||||
but when calling the `invoke()` function, we raise an error if the field is not present.
|
||||
|
||||
To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
|
||||
extra validation when calling `invoke()`.
|
||||
|
||||
There is some additional logic here to cleaning create the pydantic field via the wrapper.
|
||||
"""
|
||||
|
||||
# Filter out field args not provided
|
||||
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
|
||||
raise ValueError("Cannot specify both default and default_factory")
|
||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||
|
||||
# because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
||||
json_schema_extra_.update({"orig_required": True})
|
||||
else:
|
||||
json_schema_extra_.update({"orig_required": False})
|
||||
|
||||
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if input is Input.Any or input is Input.Connection:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
||||
json_schema_extra_.update({"default": default})
|
||||
json_schema_extra_.update({"orig_default": default})
|
||||
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||
json_schema_extra_.default = default
|
||||
json_schema_extra_.orig_default = default
|
||||
elif default is not PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.update({"orig_default": default_})
|
||||
elif default_factory is not PydanticUndefined:
|
||||
provided_args.update({"default_factory": default_factory})
|
||||
# TODO: cannot serialize default_factory...
|
||||
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
||||
json_schema_extra_.orig_default = default_
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
@ -368,7 +405,6 @@ def OutputField(
|
||||
"""
|
||||
return Field(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
@ -383,12 +419,12 @@ def OutputField(
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra={
|
||||
"ui_type": ui_type,
|
||||
"ui_hidden": ui_hidden,
|
||||
"ui_order": ui_order,
|
||||
"_field_kind": "output",
|
||||
},
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
@ -538,7 +574,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
@ -604,15 +640,17 @@ class BaseInvocation(ABC, BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=uuid_string,
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||
json_schema_extra={"_field_kind": "internal"},
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
is_intermediate: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
|
||||
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
use_cache: bool = Field(
|
||||
default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
|
||||
default=True,
|
||||
description="Whether or not to use the cache",
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
@ -629,12 +667,15 @@ class BaseInvocation(ABC, BaseModel):
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {
|
||||
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
"use_cache",
|
||||
"type",
|
||||
"workflow",
|
||||
}
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {
|
||||
"metadata",
|
||||
}
|
||||
|
||||
@ -653,39 +694,56 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
"""
|
||||
Validates the fields of an invocation or invocation output:
|
||||
- must not override any pydantic reserved fields
|
||||
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use
|
||||
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
||||
"""
|
||||
for name, field in model_fields.items():
|
||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||
|
||||
field_kind = (
|
||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
|
||||
if not field.annotation:
|
||||
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
|
||||
|
||||
if not isinstance(field.json_schema_extra, dict):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
|
||||
)
|
||||
|
||||
field_kind = field.json_schema_extra.get("field_kind", None)
|
||||
|
||||
# must have a field_kind
|
||||
if field_kind is None or field_kind not in {"input", "output", "internal"}:
|
||||
if not isinstance(field_kind, FieldKind):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||
)
|
||||
|
||||
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
|
||||
if field_kind is FieldKind.Input and (
|
||||
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||
|
||||
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||
|
||||
# internal fields *must* be in the reserved list
|
||||
if (
|
||||
field_kind == "internal"
|
||||
and name not in RESERVED_INPUT_FIELD_NAMES
|
||||
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||
):
|
||||
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||
)
|
||||
|
||||
# node attribute fields *must* be in the reserved list
|
||||
if (
|
||||
field_kind is FieldKind.NodeAttribute
|
||||
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
|
||||
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
|
||||
)
|
||||
|
||||
ui_type = field.json_schema_extra.get("ui_type", None)
|
||||
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
|
||||
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
return None
|
||||
|
||||
|
||||
@ -749,7 +807,7 @@ def invocation(
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = Field(
|
||||
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
|
||||
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
@ -795,7 +853,9 @@ def invocation_output(
|
||||
# Add the output type to the model.
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
|
||||
output_type_field = Field(
|
||||
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
@ -827,7 +887,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||
|
||||
class WithWorkflow(BaseModel):
|
||||
workflow: Optional[WorkflowField] = Field(
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
|
||||
@ -845,5 +905,11 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
|
||||
default=None,
|
||||
description=FieldDescriptions.metadata,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Connection,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
@ -55,7 +55,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
title="Random Range",
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
size: int = InputField(default=1, description="The number of values to generate")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
|
@ -39,6 +39,8 @@ for d in Path(__file__).parent.iterdir():
|
||||
logger.warn(f"Could not load {init}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loading node pack {spec.name}")
|
||||
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
@ -47,5 +49,5 @@ for d in Path(__file__).parent.iterdir():
|
||||
|
||||
del init, module_name
|
||||
|
||||
|
||||
logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}")
|
||||
if loaded_count > 0:
|
||||
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")
|
||||
|
@ -8,7 +8,7 @@ from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
@ -154,17 +154,17 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1")
|
||||
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -67,7 +66,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
|
||||
default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight"
|
||||
)
|
||||
|
||||
begin_step_percent: float = InputField(
|
||||
|
@ -274,7 +274,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ui_order=7,
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
ui_order=4,
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
|
@ -14,7 +14,6 @@ from .baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -395,7 +394,6 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.VaeModel,
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
|
@ -6,7 +6,7 @@ from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from .baseinvocation import (
|
||||
@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
title="Noise",
|
||||
tags=["latents", "noise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description=FieldDescriptions.seed,
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
width: int = InputField(
|
||||
default=512,
|
||||
|
@ -62,12 +62,12 @@ class BooleanInvocation(BaseInvocation):
|
||||
title="Boolean Collection Primitive",
|
||||
tags=["primitives", "boolean", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
|
||||
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
@ -111,12 +111,12 @@ class IntegerInvocation(BaseInvocation):
|
||||
title="Integer Collection Primitive",
|
||||
tags=["primitives", "integer", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
|
||||
collection: list[int] = InputField(default=[], description="The collection of integer values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
@ -158,12 +158,12 @@ class FloatInvocation(BaseInvocation):
|
||||
title="Float Collection Primitive",
|
||||
tags=["primitives", "float", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
|
||||
collection: list[float] = InputField(default=[], description="The collection of float values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
@ -205,12 +205,12 @@ class StringInvocation(BaseInvocation):
|
||||
title="String Collection Primitive",
|
||||
tags=["primitives", "string", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
|
||||
collection: list[str] = InputField(default=[], description="The collection of string values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
@ -467,13 +467,13 @@ class ConditioningInvocation(BaseInvocation):
|
||||
title="Conditioning Collection Primitive",
|
||||
tags=["primitives", "conditioning", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default_factory=list,
|
||||
default=[],
|
||||
description="The collection of conditioning tensors",
|
||||
)
|
||||
|
||||
|
@ -9,7 +9,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -59,7 +58,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
ui_order=-1,
|
||||
)
|
||||
weight: Union[float, list[float]] = InputField(
|
||||
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
|
||||
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
|
@ -205,7 +205,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||
|
||||
item: Any = OutputField(
|
||||
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
||||
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
|
||||
)
|
||||
|
||||
|
||||
@ -215,7 +215,7 @@ class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
|
||||
collection: list[Any] = InputField(
|
||||
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
||||
description="The list of items to iterate over", default=[], ui_type=UIType._Collection
|
||||
)
|
||||
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||
|
||||
@ -227,7 +227,7 @@ class IterateInvocation(BaseInvocation):
|
||||
@invocation_output("collect_output")
|
||||
class CollectInvocationOutput(BaseInvocationOutput):
|
||||
collection: list[Any] = OutputField(
|
||||
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
||||
description="The collection of input items", title="Collection", ui_type=UIType._Collection
|
||||
)
|
||||
|
||||
|
||||
@ -238,12 +238,12 @@ class CollectInvocation(BaseInvocation):
|
||||
item: Optional[Any] = InputField(
|
||||
default=None,
|
||||
description="The item to collect (all inputs must be of the same type)",
|
||||
ui_type=UIType.CollectionItem,
|
||||
ui_type=UIType._CollectionItem,
|
||||
title="Collection Item",
|
||||
input=Input.Connection,
|
||||
)
|
||||
collection: list[Any] = InputField(
|
||||
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
|
||||
description="The collection, will be provided on execution", default=[], ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||
|
@ -805,6 +805,8 @@
|
||||
"clipField": "Clip",
|
||||
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
|
||||
"collection": "Collection",
|
||||
"collectionFieldType": "{{name}} Collection",
|
||||
"polymorphicFieldType": "{{name}} Polymorphic",
|
||||
"collectionDescription": "TODO",
|
||||
"collectionItem": "Collection Item",
|
||||
"collectionItemDescription": "TODO",
|
||||
@ -891,10 +893,15 @@
|
||||
"mainModelField": "Model",
|
||||
"mainModelFieldDescription": "TODO",
|
||||
"maybeIncompatible": "May be Incompatible With Installed",
|
||||
"mismatchedVersion": "Has Mismatched Version",
|
||||
"mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)",
|
||||
"missingCanvaInitImage": "Missing canvas init image",
|
||||
"missingCanvaInitMaskImages": "Missing canvas init and mask images",
|
||||
"missingTemplate": "Missing Template",
|
||||
"missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)",
|
||||
"sourceNodeDoesNotExist": "Invalid edge: source/output node {{node}} does not exist",
|
||||
"targetNodeDoesNotExist": "Invalid edge: target/input node {{node}} does not exist",
|
||||
"sourceNodeFieldDoesNotExist": "Invalid edge: source/output field {{node}}.{{field}} does not exist",
|
||||
"targetNodeFieldDoesNotExist": "Invalid edge: target/input field {{node}}.{{field}} does not exist",
|
||||
"deletedInvalidEdge": "Deleted invalid edge {{source}} -> {{target}}",
|
||||
"noConnectionData": "No connection data",
|
||||
"noConnectionInProgress": "No connection in progress",
|
||||
"node": "Node",
|
||||
@ -954,10 +961,17 @@
|
||||
"stringDescription": "Strings are text.",
|
||||
"stringPolymorphic": "String Polymorphic",
|
||||
"stringPolymorphicDescription": "A collection of strings.",
|
||||
"unableToLoadWorkflow": "Unable to Validate Workflow",
|
||||
"unableToLoadWorkflow": "Unable to Load Workflow",
|
||||
"unableToParseEdge": "Unable to parse edge",
|
||||
"unableToParseNode": "Unable to parse node",
|
||||
"unableToUpdateNode": "Unable to update node",
|
||||
"unableToValidateWorkflow": "Unable to Validate Workflow",
|
||||
"unknownErrorValidatingWorkflow": "Unknown error validating workflow",
|
||||
"inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})",
|
||||
"outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})",
|
||||
"unableToExtractSchemaNameFromRef": "unable to extract schema name from ref",
|
||||
"unsupportedArrayItemType": "unsupported array item type \"{{type}}\"",
|
||||
"unableToParseFieldType": "unable to parse field type",
|
||||
"uNetField": "UNet",
|
||||
"uNetFieldDescription": "UNet submodel.",
|
||||
"unhandledInputProperty": "Unhandled input property",
|
||||
@ -971,8 +985,9 @@
|
||||
"unkownInvocation": "Unknown Invocation type",
|
||||
"unknownOutput": "Unknown output",
|
||||
"updateNode": "Update Node",
|
||||
"updateAllNodes": "Update All Nodes",
|
||||
"updateApp": "Update App",
|
||||
"updateAllNodes": "Update All Nodes",
|
||||
"allNodesUpdated": "All Nodes Updated",
|
||||
"unableToUpdateNodes_one": "Unable to update {{count}} node",
|
||||
"unableToUpdateNodes_other": "Unable to update {{count}} nodes",
|
||||
"vaeField": "Vae",
|
||||
@ -981,6 +996,8 @@
|
||||
"vaeModelFieldDescription": "TODO",
|
||||
"validateConnections": "Validate Connections and Graph",
|
||||
"validateConnectionsHelp": "Prevent invalid connections from being made, and invalid graphs from being invoked",
|
||||
"unableToGetWorkflowVersion": "Unable to get workflow schema version",
|
||||
"unrecognizedWorkflowVersion": "Unrecognized workflow schema version {{version}}",
|
||||
"version": "Version",
|
||||
"versionUnknown": " Version Unknown",
|
||||
"workflow": "Workflow",
|
||||
|
@ -71,7 +71,7 @@ import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } f
|
||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
import { addWorkflowLoadRequestedListener } from './listeners/workflowLoadRequested';
|
||||
import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -178,7 +178,7 @@ addBoardIdSelectedListener();
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
||||
// Workflows
|
||||
addWorkflowLoadedListener();
|
||||
addWorkflowLoadRequestedListener();
|
||||
addUpdateAllNodesRequestedListener();
|
||||
|
||||
// DND
|
||||
|
@ -12,10 +12,10 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
|
||||
export const addControlNetImageProcessedListener = () => {
|
||||
startAppListening({
|
||||
|
@ -5,19 +5,20 @@ import {
|
||||
controlAdapterProcessedImageChanged,
|
||||
selectControlAdapterAll,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { startAppListening } from '..';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
|
||||
export const addRequestedSingleImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
@ -121,7 +122,7 @@ export const addRequestedSingleImageDeletionListener = () => {
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (
|
||||
input.type === 'ImageField' &&
|
||||
isImageFieldInputInstance(input) &&
|
||||
input.value?.image_name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
@ -241,7 +242,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (
|
||||
input.type === 'ImageField' &&
|
||||
isImageFieldInputInstance(input) &&
|
||||
input.value?.image_name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
|
@ -12,12 +12,12 @@ import {
|
||||
setWidth,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { startAppListening } from '..';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const addModelSelectedListener = () => {
|
||||
startAppListening({
|
||||
@ -26,7 +26,7 @@ export const addModelSelectedListener = () => {
|
||||
const log = logger('models');
|
||||
|
||||
const state = getState();
|
||||
const result = zMainOrOnnxModel.safeParse(action.payload);
|
||||
const result = zParameterModel.safeParse(action.payload);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
|
@ -11,9 +11,9 @@ import {
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
zMainOrOnnxModel,
|
||||
zSDXLRefinerModel,
|
||||
zVaeModel,
|
||||
zParameterModel,
|
||||
zParameterSDXLRefinerModel,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
@ -67,7 +67,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zMainOrOnnxModel.safeParse(models[0]);
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
@ -119,7 +119,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zSDXLRefinerModel.safeParse(models[0]);
|
||||
const result = zParameterSDXLRefinerModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
@ -170,7 +170,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zVaeModel.safeParse(firstModel);
|
||||
const result = zParameterVAEModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
|
@ -15,6 +15,7 @@ export const addReceivedOpenAPISchemaListener = () => {
|
||||
|
||||
log.debug({ schemaJSON }, 'Received OpenAPI schema');
|
||||
const { nodesAllowlist, nodesDenylist } = getState().config;
|
||||
|
||||
const nodeTemplates = parseSchema(
|
||||
schemaJSON,
|
||||
nodesAllowlist,
|
||||
|
@ -13,13 +13,13 @@ import {
|
||||
} from 'features/nodes/util/graphBuilders/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import {
|
||||
appSocketInvocationComplete,
|
||||
socketInvocationComplete,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
|
||||
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
|
||||
const nodeTypeDenylist = ['load_image', 'image'];
|
||||
|
@ -1,14 +1,16 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
getNeedsUpdate,
|
||||
updateNode,
|
||||
} from 'features/nodes/hooks/useNodeVersion';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { logger } from 'app/logging/logger';
|
||||
} from 'features/nodes/store/util/nodeUpdate';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addUpdateAllNodesRequestedListener = () => {
|
||||
startAppListening({
|
||||
@ -20,22 +22,31 @@ export const addUpdateAllNodesRequestedListener = () => {
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
nodes.forEach((node) => {
|
||||
nodes.filter(isInvocationNode).forEach((node) => {
|
||||
const template = templates[node.data.type];
|
||||
const needsUpdate = getNeedsUpdate(node, template);
|
||||
const updatedNode = updateNode(node, template);
|
||||
if (!updatedNode) {
|
||||
if (needsUpdate) {
|
||||
if (!template) {
|
||||
unableToUpdateCount++;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!getNeedsUpdate(node, template)) {
|
||||
// No need to increment the count here, since we're not actually updating
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const updatedNode = updateNode(node, template);
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
} catch (e) {
|
||||
if (e instanceof NodeUpdateError) {
|
||||
unableToUpdateCount++;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (unableToUpdateCount) {
|
||||
log.warn(
|
||||
`Unable to update ${unableToUpdateCount} nodes. Please report this issue.`
|
||||
t('nodes.unableToUpdateNodes', {
|
||||
count: unableToUpdateCount,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addToast(
|
||||
@ -46,6 +57,15 @@ export const addUpdateAllNodesRequestedListener = () => {
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.allNodesUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -0,0 +1,105 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { t } from 'i18next';
|
||||
import { z } from 'zod';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addWorkflowLoadRequestedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
if (!warnings.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.workflowLoaded'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.loadedWithWarnings'),
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
warnings.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof WorkflowVersionError) {
|
||||
// The workflow version was not recognized in the valid list of versions
|
||||
log.error({ error: parseify(e) }, e.message);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
description: e.message,
|
||||
})
|
||||
)
|
||||
);
|
||||
} else if (e instanceof z.ZodError) {
|
||||
// There was a problem validating the workflow itself
|
||||
const { message } = fromZodError(e, {
|
||||
prefix: t('nodes.workflowValidation'),
|
||||
});
|
||||
log.error({ error: parseify(e) }, message);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
description: message,
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// Some other error occurred
|
||||
console.log(e);
|
||||
log.error(
|
||||
{ error: parseify(e) },
|
||||
t('nodes.unknownErrorValidatingWorkflow')
|
||||
);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
description: t('nodes.unknownErrorValidatingWorkflow'),
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -1,56 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addWorkflowLoadedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
const { workflow: validatedWorkflow, errors } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
|
||||
if (!errors.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.workflowLoaded'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.loadedWithWarnings'),
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
errors.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import i18n from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
|
@ -6,9 +6,9 @@ import {
|
||||
isAnyOf,
|
||||
} from '@reduxjs/toolkit';
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
IPAdapterModelParam,
|
||||
T2IAdapterModelParam,
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
@ -243,9 +243,9 @@ export const controlAdaptersSlice = createSlice({
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
model:
|
||||
| ControlNetModelParam
|
||||
| T2IAdapterModelParam
|
||||
| IPAdapterModelParam;
|
||||
| ParameterControlNetModel
|
||||
| ParameterT2IAdapterModel
|
||||
| ParameterIPAdapterModel;
|
||||
}>
|
||||
) => {
|
||||
const { id, model } = action.payload;
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { EntityState } from '@reduxjs/toolkit';
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
IPAdapterModelParam,
|
||||
T2IAdapterModelParam,
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { isObject } from 'lodash-es';
|
||||
import { components } from 'services/api/schema';
|
||||
@ -378,7 +378,7 @@ export type ControlNetConfig = {
|
||||
type: 'controlnet';
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ControlNetModelParam | null;
|
||||
model: ParameterControlNetModel | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -395,7 +395,7 @@ export type T2IAdapterConfig = {
|
||||
type: 't2i_adapter';
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: T2IAdapterModelParam | null;
|
||||
model: ParameterT2IAdapterModel | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -412,7 +412,7 @@ export type IPAdapterConfig = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
controlImage: string | null;
|
||||
model: IPAdapterModelParam | null;
|
||||
model: ParameterIPAdapterModel | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { some } from 'lodash-es';
|
||||
import { ImageUsage } from './types';
|
||||
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
|
||||
export const getImageUsage = (state: RootState, image_name: string) => {
|
||||
const { generation, canvas, nodes, controlAdapters } = state;
|
||||
@ -19,7 +20,8 @@ export const getImageUsage = (state: RootState, image_name: string) => {
|
||||
return some(
|
||||
node.data.inputs,
|
||||
(input) =>
|
||||
input.type === 'ImageField' && input.value?.image_name === image_name
|
||||
isImageFieldInputInstance(input) &&
|
||||
input.value?.image_name === image_name
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -11,9 +11,9 @@ import {
|
||||
useDroppable as useOriginalDroppable,
|
||||
} from '@dnd-kit/core';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
FieldInputTemplate,
|
||||
FieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type BaseDropData = {
|
||||
@ -93,8 +93,8 @@ export type NodeFieldDraggableData = BaseDragData & {
|
||||
payloadType: 'NODE_FIELD';
|
||||
payload: {
|
||||
nodeId: string;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
field: FieldInputInstance;
|
||||
fieldTemplate: FieldInputTemplate;
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -4,14 +4,14 @@ import {
|
||||
LoRAMetadataItem,
|
||||
IPAdapterMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
} from 'features/nodes/types/metadata';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useMemo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
isValidControlNetModel,
|
||||
isValidLoRAModel,
|
||||
isValidT2IAdapterModel,
|
||||
isParameterControlNetModel,
|
||||
isParameterLoRAModel,
|
||||
isParameterT2IAdapterModel,
|
||||
} from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
@ -132,7 +132,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) =>
|
||||
isValidControlNetModel(controlnet.control_model)
|
||||
isParameterControlNetModel(controlnet.control_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.controlnets]);
|
||||
@ -140,7 +140,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.ipAdapters
|
||||
? metadata.ipAdapters.filter((ipAdapter) =>
|
||||
isValidControlNetModel(ipAdapter.ip_adapter_model)
|
||||
isParameterControlNetModel(ipAdapter.ip_adapter_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.ipAdapters]);
|
||||
@ -148,7 +148,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.t2iAdapters
|
||||
? metadata.t2iAdapters.filter((t2iAdapter) =>
|
||||
isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model)
|
||||
isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.t2iAdapters]);
|
||||
@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
return null;
|
||||
}
|
||||
|
||||
console.log(metadata);
|
||||
|
||||
return (
|
||||
<>
|
||||
{metadata.created_by && (
|
||||
@ -275,7 +273,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
)}
|
||||
{metadata.loras &&
|
||||
metadata.loras.map((lora, index) => {
|
||||
if (isValidLoRAModel(lora.lora)) {
|
||||
if (isParameterLoRAModel(lora.lora)) {
|
||||
return (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
|
||||
export type LoRA = LoRAModelParam & {
|
||||
export type LoRA = ParameterLoRAModel & {
|
||||
id: string;
|
||||
weight: number;
|
||||
};
|
||||
|
@ -24,7 +24,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import 'reactflow/dist/style.css';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
|
||||
|
||||
type NodeTemplate = {
|
||||
@ -57,7 +56,7 @@ const AddNodePopover = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const fieldFilter = useAppSelector(
|
||||
(state) => state.nodes.currentConnectionFieldType
|
||||
(state) => state.nodes.connectionStartFieldType
|
||||
);
|
||||
const handleFilter = useAppSelector(
|
||||
(state) => state.nodes.connectionStartParams?.handleType
|
||||
@ -111,7 +110,7 @@ const AddNodePopover = () => {
|
||||
|
||||
data.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return { data, t };
|
||||
return { data };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
@ -121,7 +120,7 @@ const AddNodePopover = () => {
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: AnyInvocationType) => {
|
||||
(nodeType: string) => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
if (!invocation) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
@ -145,7 +144,7 @@ const AddNodePopover = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
addNode(v as AnyInvocationType);
|
||||
addNode(v);
|
||||
},
|
||||
[addNode]
|
||||
);
|
||||
|
@ -2,17 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { memo } from 'react';
|
||||
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
|
||||
import { getFieldColor } from '../edges/util/getEdgeColor';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
|
||||
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
|
||||
nodes;
|
||||
|
||||
const stroke =
|
||||
currentConnectionFieldType && shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
|
||||
const stroke = shouldColorEdges
|
||||
? getFieldColor(connectionStartFieldType)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
let className = 'react-flow__custom_connection-path';
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELD_COLORS } from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/field';
|
||||
|
||||
export const getFieldColor = (fieldType: FieldType | null): string => {
|
||||
if (!fieldType) {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
const color = FIELD_COLORS[fieldType.name];
|
||||
|
||||
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
|
||||
};
|
@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
source: string,
|
||||
@ -29,7 +29,7 @@ export const makeEdgeSelector = (
|
||||
|
||||
const stroke =
|
||||
sourceType && nodes.shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||
? getFieldColor(sourceType)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
return {
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { useColorModeValue } from '@chakra-ui/react';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { map } from 'lodash-es';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
|
@ -2,8 +2,8 @@ import { Flex, Icon, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
@ -13,7 +13,7 @@ interface Props {
|
||||
}
|
||||
|
||||
const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
|
||||
const { needsUpdate } = useNodeVersion(nodeId);
|
||||
const needsUpdate = useNodeNeedsUpdate(nodeId);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
@ -11,7 +11,10 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
|
||||
import {
|
||||
NodeExecutionState,
|
||||
zNodeStatus,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -74,10 +77,10 @@ type TooltipLabelProps = {
|
||||
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
const { status, progress, progressImage } = nodeExecutionState;
|
||||
const { t } = useTranslation();
|
||||
if (status === NodeStatus.PENDING) {
|
||||
if (status === zNodeStatus.enum.PENDING) {
|
||||
return <Text>{t('queue.pending')}</Text>;
|
||||
}
|
||||
if (status === NodeStatus.IN_PROGRESS) {
|
||||
if (status === zNodeStatus.enum.IN_PROGRESS) {
|
||||
if (progressImage) {
|
||||
return (
|
||||
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
|
||||
@ -108,11 +111,11 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
return <Text>{t('nodes.executionStateInProgress')}</Text>;
|
||||
}
|
||||
|
||||
if (status === NodeStatus.COMPLETED) {
|
||||
if (status === zNodeStatus.enum.COMPLETED) {
|
||||
return <Text>{t('nodes.executionStateCompleted')}</Text>;
|
||||
}
|
||||
|
||||
if (status === NodeStatus.FAILED) {
|
||||
if (status === zNodeStatus.enum.FAILED) {
|
||||
return <Text>{t('nodes.executionStateError')}</Text>;
|
||||
}
|
||||
|
||||
@ -127,7 +130,7 @@ type StatusIconProps = {
|
||||
|
||||
const StatusIcon = memo((props: StatusIconProps) => {
|
||||
const { progress, status } = props.nodeExecutionState;
|
||||
if (status === NodeStatus.PENDING) {
|
||||
if (status === zNodeStatus.enum.PENDING) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaEllipsisH}
|
||||
@ -139,7 +142,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.IN_PROGRESS) {
|
||||
if (status === zNodeStatus.enum.IN_PROGRESS) {
|
||||
return progress === null ? (
|
||||
<CircularProgress
|
||||
isIndeterminate
|
||||
@ -158,7 +161,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.COMPLETED) {
|
||||
if (status === zNodeStatus.enum.COMPLETED) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaCheck}
|
||||
@ -170,7 +173,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.FAILED) {
|
||||
if (status === zNodeStatus.enum.FAILED) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaExclamation}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InvocationNodeData } from 'features/nodes/types/types';
|
||||
import { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import InvocationNode from '../Invocation/InvocationNode';
|
||||
|
@ -3,7 +3,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
|
@ -56,7 +56,7 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
|
||||
);
|
||||
|
||||
const mayExpose = useMemo(
|
||||
() => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'),
|
||||
() => input && ['any', 'direct'].includes(input),
|
||||
[input]
|
||||
);
|
||||
|
||||
|
@ -1,18 +1,17 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import {
|
||||
COLLECTION_TYPES,
|
||||
FIELDS,
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
MODEL_TYPES,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, HandleType, Position } from 'reactflow';
|
||||
import { getFieldColor } from '../../../edges/util/getEdgeColor';
|
||||
|
||||
export const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
@ -32,11 +31,11 @@ export const outputHandleStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
type FieldHandleProps = {
|
||||
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
handleType: HandleType;
|
||||
isConnectionInProgress: boolean;
|
||||
isConnectionStartField: boolean;
|
||||
connectionError: string | null;
|
||||
connectionError?: string;
|
||||
};
|
||||
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
@ -47,23 +46,21 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
isConnectionStartField,
|
||||
connectionError,
|
||||
} = props;
|
||||
const { name, type } = fieldTemplate;
|
||||
const { color: typeColor, title } = FIELDS[type];
|
||||
|
||||
const { name } = fieldTemplate;
|
||||
const type = fieldTemplate.type;
|
||||
const fieldTypeName = useFieldTypeName(type);
|
||||
const styles: CSSProperties = useMemo(() => {
|
||||
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||
const isModelType = MODEL_TYPES.includes(type);
|
||||
const color = colorTokenToCssVar(typeColor);
|
||||
const isModelType = MODEL_TYPES.some((t) => t === type.name);
|
||||
const color = getFieldColor(type);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor:
|
||||
isCollectionType || isPolymorphicType
|
||||
? 'var(--invokeai-colors-base-900)'
|
||||
type.isCollection || type.isPolymorphic
|
||||
? colorTokenToCssVar('base.900')
|
||||
: color,
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||
borderWidth: type.isCollection || type.isPolymorphic ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
@ -97,18 +94,14 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
type,
|
||||
typeColor,
|
||||
]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && isConnectionStartField) {
|
||||
return title;
|
||||
}
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
return connectionError ?? title;
|
||||
return connectionError;
|
||||
}
|
||||
return title;
|
||||
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
|
||||
return fieldTypeName;
|
||||
}, [connectionError, fieldTypeName, isConnectionInProgress]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
@ -1,15 +1,14 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import {
|
||||
isInputFieldTemplate,
|
||||
isInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
isFieldInputInstance,
|
||||
isFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@ -17,12 +16,13 @@ interface Props {
|
||||
}
|
||||
|
||||
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const field = useFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
|
||||
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
|
||||
const { t } = useTranslation();
|
||||
const fieldTitle = useMemo(() => {
|
||||
if (isInputFieldValue(field)) {
|
||||
if (isFieldInputInstance(field)) {
|
||||
if (field.label && fieldTemplate?.title) {
|
||||
return `${field.label} (${fieldTemplate.title})`;
|
||||
}
|
||||
@ -49,9 +49,9 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
{fieldTemplate.description}
|
||||
</Text>
|
||||
)}
|
||||
{fieldTemplate && (
|
||||
{fieldTypeName && (
|
||||
<Text>
|
||||
{t('parameters.type')}: {FIELDS[fieldTemplate.type].title}
|
||||
{t('parameters.type')}: {fieldTypeName}
|
||||
</Text>
|
||||
)}
|
||||
{isInputTemplate && (
|
||||
|
@ -77,10 +77,10 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
sx={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
mb: 0,
|
||||
px: 1,
|
||||
gap: 2,
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<EditableFieldTitle
|
||||
|
@ -1,24 +1,60 @@
|
||||
import { Box, Text } from '@chakra-ui/react';
|
||||
import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
isLoRAModelFieldInputTemplate,
|
||||
isMainModelFieldInputInstance,
|
||||
isMainModelFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
isT2IAdapterModelFieldInputTemplate,
|
||||
isVAEModelFieldInputInstance,
|
||||
isVAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo } from 'react';
|
||||
import BooleanInputField from './inputs/BooleanInputField';
|
||||
import ColorInputField from './inputs/ColorInputField';
|
||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||
import EnumInputField from './inputs/EnumInputField';
|
||||
import ImageInputField from './inputs/ImageInputField';
|
||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||
import MainModelInputField from './inputs/MainModelInputField';
|
||||
import NumberInputField from './inputs/NumberInputField';
|
||||
import RefinerModelInputField from './inputs/RefinerModelInputField';
|
||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
||||
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
|
||||
import BoardInputField from './inputs/BoardInputField';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
|
||||
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
@ -27,220 +63,227 @@ type InputFieldProps = {
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const { t } = useTranslation();
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldInstance = useFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
|
||||
if (fieldTemplate?.fieldKind === 'output') {
|
||||
return (
|
||||
<Box p={2}>
|
||||
{t('nodes.outputFieldInInput')}: {field?.type}
|
||||
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
|
||||
(field?.type === 'StringPolymorphic' &&
|
||||
fieldTemplate?.type === 'StringPolymorphic')
|
||||
isStringFieldInputInstance(fieldInstance) &&
|
||||
isStringFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<StringInputField
|
||||
<StringFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
|
||||
(field?.type === 'BooleanPolymorphic' &&
|
||||
fieldTemplate?.type === 'BooleanPolymorphic')
|
||||
isBooleanFieldInputInstance(fieldInstance) &&
|
||||
isBooleanFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<BooleanInputField
|
||||
<BooleanFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
|
||||
(field?.type === 'float' && fieldTemplate?.type === 'float') ||
|
||||
(field?.type === 'FloatPolymorphic' &&
|
||||
fieldTemplate?.type === 'FloatPolymorphic') ||
|
||||
(field?.type === 'IntegerPolymorphic' &&
|
||||
fieldTemplate?.type === 'IntegerPolymorphic')
|
||||
(isIntegerFieldInputInstance(fieldInstance) &&
|
||||
isIntegerFieldInputTemplate(fieldTemplate)) ||
|
||||
(isFloatFieldInputInstance(fieldInstance) &&
|
||||
isFloatFieldInputTemplate(fieldTemplate))
|
||||
) {
|
||||
return (
|
||||
<NumberInputField
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
|
||||
return (
|
||||
<EnumInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
|
||||
(field?.type === 'ImagePolymorphic' &&
|
||||
fieldTemplate?.type === 'ImagePolymorphic')
|
||||
isEnumFieldInputInstance(fieldInstance) &&
|
||||
isEnumFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<ImageInputField
|
||||
<EnumFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'BoardField' && fieldTemplate?.type === 'BoardField') {
|
||||
return (
|
||||
<BoardInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'MainModelField' &&
|
||||
fieldTemplate?.type === 'MainModelField'
|
||||
isImageFieldInputInstance(fieldInstance) &&
|
||||
isImageFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<MainModelInputField
|
||||
<ImageFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLRefinerModelField' &&
|
||||
fieldTemplate?.type === 'SDXLRefinerModelField'
|
||||
isBoardFieldInputInstance(fieldInstance) &&
|
||||
isBoardFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<RefinerModelInputField
|
||||
<BoardFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'VaeModelField' &&
|
||||
fieldTemplate?.type === 'VaeModelField'
|
||||
isMainModelFieldInputInstance(fieldInstance) &&
|
||||
isMainModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<VaeModelInputField
|
||||
<MainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LoRAModelField' &&
|
||||
fieldTemplate?.type === 'LoRAModelField'
|
||||
isSDXLRefinerModelFieldInputInstance(fieldInstance) &&
|
||||
isSDXLRefinerModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<LoRAModelInputField
|
||||
<RefinerModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlNetModelField' &&
|
||||
fieldTemplate?.type === 'ControlNetModelField'
|
||||
isVAEModelFieldInputInstance(fieldInstance) &&
|
||||
isVAEModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<ControlNetModelInputField
|
||||
<VAEModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'IPAdapterModelField' &&
|
||||
fieldTemplate?.type === 'IPAdapterModelField'
|
||||
isLoRAModelFieldInputInstance(fieldInstance) &&
|
||||
isLoRAModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<IPAdapterModelInputField
|
||||
<LoRAModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'T2IAdapterModelField' &&
|
||||
fieldTemplate?.type === 'T2IAdapterModelField'
|
||||
isControlNetModelFieldInputInstance(fieldInstance) &&
|
||||
isControlNetModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<T2IAdapterModelInputField
|
||||
<ControlNetModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLMainModelField' &&
|
||||
fieldTemplate?.type === 'SDXLMainModelField'
|
||||
isIPAdapterModelFieldInputInstance(fieldInstance) &&
|
||||
isIPAdapterModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SDXLMainModelInputField
|
||||
<IPAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
|
||||
if (
|
||||
isT2IAdapterModelFieldInputInstance(fieldInstance) &&
|
||||
isT2IAdapterModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SchedulerInputField
|
||||
<T2IAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (
|
||||
isColorFieldInputInstance(fieldInstance) &&
|
||||
isColorFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<ColorFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field && fieldTemplate) {
|
||||
if (
|
||||
isSDXLMainModelFieldInputInstance(fieldInstance) &&
|
||||
isSDXLMainModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SDXLMainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
isSchedulerFieldInputInstance(fieldInstance) &&
|
||||
isSchedulerFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SchedulerFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldInstance && fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
@ -255,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
_dark: { color: 'error.300' },
|
||||
}}
|
||||
>
|
||||
{t('nodes.unknownFieldType')}: {field?.type}
|
||||
{t('nodes.unknownFieldType')}: {fieldInstance?.type.name}
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
|
@ -3,15 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BoardInputFieldTemplate,
|
||||
BoardInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
BoardFieldInputTemplate,
|
||||
BoardFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
const BoardInputFieldComponent = (
|
||||
props: FieldComponentProps<BoardInputFieldValue, BoardInputFieldTemplate>
|
||||
const BoardFieldInputComponent = (
|
||||
props: FieldComponentProps<BoardFieldInputInstance, BoardFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -61,4 +61,4 @@ const BoardInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BoardInputFieldComponent);
|
||||
export default memo(BoardFieldInputComponent);
|
@ -2,18 +2,16 @@ import { Switch } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
BooleanInputFieldValue,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
BooleanFieldInputInstance,
|
||||
BooleanFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const BooleanInputFieldComponent = (
|
||||
const BooleanFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
|
||||
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
|
||||
BooleanFieldInputInstance,
|
||||
BooleanFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -42,4 +40,4 @@ const BooleanInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BooleanInputFieldComponent);
|
||||
export default memo(BooleanFieldInputComponent);
|
@ -1,15 +1,15 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ColorInputFieldTemplate,
|
||||
ColorInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
ColorFieldInputTemplate,
|
||||
ColorFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
|
||||
|
||||
const ColorInputFieldComponent = (
|
||||
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
|
||||
const ColorFieldInputComponent = (
|
||||
props: FieldComponentProps<ColorFieldInputInstance, ColorFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
@ -37,4 +37,4 @@ const ColorInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ColorInputFieldComponent);
|
||||
export default memo(ColorFieldInputComponent);
|
@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlNetModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
ControlNetModelFieldInputTemplate,
|
||||
ControlNetModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const ControlNetModelInputFieldComponent = (
|
||||
const ControlNetModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
ControlNetModelInputFieldValue,
|
||||
ControlNetModelInputFieldTemplate
|
||||
ControlNetModelFieldInputInstance,
|
||||
ControlNetModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -97,4 +97,4 @@ const ControlNetModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelInputFieldComponent);
|
||||
export default memo(ControlNetModelFieldInputComponent);
|
@ -2,14 +2,14 @@ import { Select } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
EnumInputFieldTemplate,
|
||||
EnumInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
EnumFieldInputInstance,
|
||||
EnumFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const EnumInputFieldComponent = (
|
||||
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
||||
const EnumFieldInputComponent = (
|
||||
props: FieldComponentProps<EnumFieldInputInstance, EnumFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
|
||||
@ -45,4 +45,4 @@ const EnumInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(EnumInputFieldComponent);
|
||||
export default memo(EnumFieldInputComponent);
|
@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
IPAdapterModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
IPAdapterModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const IPAdapterModelInputFieldComponent = (
|
||||
const IPAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
IPAdapterModelInputFieldValue,
|
||||
IPAdapterModelInputFieldTemplate
|
||||
IPAdapterModelFieldInputInstance,
|
||||
IPAdapterModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -97,4 +97,4 @@ const IPAdapterModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IPAdapterModelInputFieldComponent);
|
||||
export default memo(IPAdapterModelFieldInputComponent);
|
@ -9,23 +9,18 @@ import {
|
||||
} from 'features/dnd/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
ImageFieldInputInstance,
|
||||
ImageFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
|
||||
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
|
||||
>
|
||||
const ImageFieldInputComponent = (
|
||||
props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -102,7 +97,7 @@ const ImageInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageInputFieldComponent);
|
||||
export default memo(ImageFieldInputComponent);
|
||||
|
||||
const UploadElement = memo(() => {
|
||||
const { t } = useTranslation();
|
@ -5,10 +5,10 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
LoRAModelInputFieldTemplate,
|
||||
LoRAModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
LoRAModelFieldInputTemplate,
|
||||
LoRAModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
@ -16,10 +16,10 @@ import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const LoRAModelInputFieldComponent = (
|
||||
const LoRAModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
LoRAModelInputFieldValue,
|
||||
LoRAModelInputFieldTemplate
|
||||
LoRAModelFieldInputInstance,
|
||||
LoRAModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -121,4 +121,4 @@ const LoRAModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoRAModelInputFieldComponent);
|
||||
export default memo(LoRAModelFieldInputComponent);
|
@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
MainModelInputFieldTemplate,
|
||||
MainModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
MainModelFieldInputTemplate,
|
||||
MainModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -21,10 +21,10 @@ import {
|
||||
} from 'services/api/endpoints/models';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const MainModelInputFieldComponent = (
|
||||
const MainModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
MainModelInputFieldValue,
|
||||
MainModelInputFieldTemplate
|
||||
MainModelFieldInputInstance,
|
||||
MainModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -149,4 +149,4 @@ const MainModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(MainModelInputFieldComponent);
|
||||
export default memo(MainModelFieldInputComponent);
|
@ -9,28 +9,18 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
FloatFieldInputInstance,
|
||||
FloatFieldInputTemplate,
|
||||
IntegerFieldInputInstance,
|
||||
IntegerFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
const NumberInputFieldComponent = (
|
||||
const NumberFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
| IntegerInputFieldValue
|
||||
| IntegerPolymorphicInputFieldValue
|
||||
| FloatInputFieldValue
|
||||
| FloatPolymorphicInputFieldValue,
|
||||
| IntegerInputFieldTemplate
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| FloatPolymorphicInputFieldTemplate
|
||||
IntegerFieldInputInstance | FloatFieldInputInstance,
|
||||
IntegerFieldInputTemplate | FloatFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
@ -39,7 +29,7 @@ const NumberInputFieldComponent = (
|
||||
String(field.value)
|
||||
);
|
||||
const isIntegerField = useMemo(
|
||||
() => fieldTemplate.type === 'integer',
|
||||
() => fieldTemplate.type.name === 'IntegerField',
|
||||
[fieldTemplate.type]
|
||||
);
|
||||
|
||||
@ -86,4 +76,4 @@ const NumberInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NumberInputFieldComponent);
|
||||
export default memo(NumberFieldInputComponent);
|
@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -18,10 +18,10 @@ import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const RefinerModelInputFieldComponent = (
|
||||
const RefinerModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
SDXLRefinerModelInputFieldTemplate
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
SDXLRefinerModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -120,4 +120,4 @@ const RefinerModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelInputFieldComponent);
|
||||
export default memo(RefinerModelFieldInputComponent);
|
@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLMainModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLMainModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -21,10 +21,10 @@ import {
|
||||
useGetOnnxModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
const SDXLMainModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
SDXLMainModelInputFieldValue,
|
||||
SDXLMainModelInputFieldTemplate
|
||||
SDXLMainModelFieldInputInstance,
|
||||
SDXLMainModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -147,4 +147,4 @@ const ModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelInputFieldComponent);
|
||||
export default memo(SDXLMainModelFieldInputComponent);
|
@ -5,14 +5,12 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
SchedulerInputFieldTemplate,
|
||||
SchedulerInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
SchedulerFieldInputTemplate,
|
||||
SchedulerFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
@ -24,7 +22,7 @@ const selector = createSelector(
|
||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||
value: name,
|
||||
label: label,
|
||||
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||
group: enabledSchedulers.includes(name as ParameterScheduler)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
@ -36,10 +34,10 @@ const selector = createSelector(
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const SchedulerInputField = (
|
||||
const SchedulerFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
SchedulerInputFieldValue,
|
||||
SchedulerInputFieldTemplate
|
||||
SchedulerFieldInputInstance,
|
||||
SchedulerFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -55,7 +53,7 @@ const SchedulerInputField = (
|
||||
fieldSchedulerValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: value as SchedulerParam,
|
||||
value: value as ParameterScheduler,
|
||||
})
|
||||
);
|
||||
},
|
||||
@ -72,4 +70,4 @@ const SchedulerInputField = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SchedulerInputField);
|
||||
export default memo(SchedulerFieldInputComponent);
|
@ -3,19 +3,14 @@ import IAIInput from 'common/components/IAIInput';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
StringInputFieldTemplate,
|
||||
StringInputFieldValue,
|
||||
FieldComponentProps,
|
||||
StringPolymorphicInputFieldValue,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
StringFieldInputInstance,
|
||||
StringFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const StringInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
StringInputFieldValue | StringPolymorphicInputFieldValue,
|
||||
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
|
||||
>
|
||||
const StringFieldInputComponent = (
|
||||
props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -48,4 +43,4 @@ const StringInputFieldComponent = (
|
||||
return <IAIInput onChange={handleValueChanged} value={field.value} />;
|
||||
};
|
||||
|
||||
export default memo(StringInputFieldComponent);
|
||||
export default memo(StringFieldInputComponent);
|
@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
T2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
T2IAdapterModelFieldInputInstance,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const T2IAdapterModelInputFieldComponent = (
|
||||
const T2IAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
T2IAdapterModelInputFieldValue,
|
||||
T2IAdapterModelInputFieldTemplate
|
||||
T2IAdapterModelFieldInputInstance,
|
||||
T2IAdapterModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -97,4 +97,4 @@ const T2IAdapterModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T2IAdapterModelInputFieldComponent);
|
||||
export default memo(T2IAdapterModelFieldInputComponent);
|
@ -4,20 +4,20 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
VAEModelFieldInputTemplate,
|
||||
VAEModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const VaeModelInputFieldComponent = (
|
||||
const VAEModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
VaeModelInputFieldValue,
|
||||
VaeModelInputFieldTemplate
|
||||
VAEModelFieldInputInstance,
|
||||
VAEModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -105,4 +105,4 @@ const VaeModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VaeModelInputFieldComponent);
|
||||
export default memo(VAEModelFieldInputComponent);
|
@ -0,0 +1,13 @@
|
||||
import {
|
||||
FieldInputInstance,
|
||||
FieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
|
||||
export type FieldComponentProps<
|
||||
V extends FieldInputInstance,
|
||||
T extends FieldInputTemplate,
|
||||
> = {
|
||||
nodeId: string;
|
||||
field: V;
|
||||
fieldTemplate: T;
|
||||
};
|
@ -2,7 +2,7 @@ import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NotesNodeData } from 'features/nodes/types/types';
|
||||
import { NotesNodeData } from 'features/nodes/types/invocation';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import NodeWrapper from '../common/NodeWrapper';
|
||||
|
@ -14,7 +14,7 @@ import {
|
||||
DRAG_HANDLE_CLASSNAME,
|
||||
NODE_WIDTH,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { NodeStatus } from 'features/nodes/types/types';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import {
|
||||
MouseEvent,
|
||||
@ -40,7 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) =>
|
||||
nodes.nodeExecutionStates[nodeId]?.status === NodeStatus.IN_PROGRESS
|
||||
nodes.nodeExecutionStates[nodeId]?.status ===
|
||||
zNodeStatus.enum.IN_PROGRESS
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -8,7 +8,7 @@ import { FaUpload } from 'react-icons/fa';
|
||||
const LoadWorkflowButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const resetRef = useRef<() => void>(null);
|
||||
const loadWorkflowFromFile = useLoadWorkflowFromFile();
|
||||
const loadWorkflowFromFile = useLoadWorkflowFromFile(resetRef);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={resetRef}
|
||||
|
@ -1,31 +0,0 @@
|
||||
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import 'reactflow/dist/style.css';
|
||||
|
||||
const FieldTypeLegend = () => {
|
||||
return (
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
<Tooltip key={key} label={description}>
|
||||
<Badge
|
||||
sx={{
|
||||
userSelect: 'none',
|
||||
color:
|
||||
parseInt(color.split('.')[1] ?? '0', 10) < 500
|
||||
? 'base.800'
|
||||
: 'base.50',
|
||||
bg: color,
|
||||
}}
|
||||
textAlign="center"
|
||||
>
|
||||
{title}
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FieldTypeLegend);
|
@ -1,18 +1,11 @@
|
||||
import { Flex } from '@chakra-ui/layout';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import FieldTypeLegend from './FieldTypeLegend';
|
||||
import WorkflowEditorSettings from './WorkflowEditorSettings';
|
||||
|
||||
const TopRightPanel = () => {
|
||||
const shouldShowFieldTypeLegend = useAppSelector(
|
||||
(state) => state.nodes.shouldShowFieldTypeLegend
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex sx={{ gap: 2, position: 'absolute', top: 2, insetInlineEnd: 2 }}>
|
||||
<WorkflowEditorSettings />
|
||||
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -10,17 +10,15 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
isInvocationNode,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSync } from 'react-icons/fa';
|
||||
import { Node } from 'reactflow';
|
||||
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
|
||||
import ScrollableContent from '../ScrollableContent';
|
||||
@ -63,12 +61,17 @@ const InspectorDetailsTab = () => {
|
||||
|
||||
export default memo(InspectorDetailsTab);
|
||||
|
||||
const Content = (props: {
|
||||
type ContentProps = {
|
||||
node: Node<InvocationNodeData>;
|
||||
template: InvocationTemplate;
|
||||
}) => {
|
||||
};
|
||||
|
||||
const Content = memo(({ node, template }: ContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { needsUpdate, updateNode } = useNodeVersion(props.node.id);
|
||||
const needsUpdate = useMemo(
|
||||
() => getNeedsUpdate(node, template),
|
||||
[node, template]
|
||||
);
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@ -87,12 +90,12 @@ const Content = (props: {
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<EditableNodeTitle nodeId={props.node.data.id} />
|
||||
<EditableNodeTitle nodeId={node.data.id} />
|
||||
<HStack>
|
||||
<FormControl>
|
||||
<FormLabel>{t('nodes.nodeType')}</FormLabel>
|
||||
<Text fontSize="sm" fontWeight={600}>
|
||||
{props.template.title}
|
||||
{template.title}
|
||||
</Text>
|
||||
</FormControl>
|
||||
<Flex
|
||||
@ -104,22 +107,16 @@ const Content = (props: {
|
||||
<FormControl isInvalid={needsUpdate}>
|
||||
<FormLabel>{t('nodes.nodeVersion')}</FormLabel>
|
||||
<Text fontSize="sm" fontWeight={600}>
|
||||
{props.node.data.version}
|
||||
{node.data.version}
|
||||
</Text>
|
||||
</FormControl>
|
||||
{needsUpdate && (
|
||||
<IAIIconButton
|
||||
aria-label={t('nodes.updateNode')}
|
||||
tooltip={t('nodes.updateNode')}
|
||||
icon={<FaSync />}
|
||||
onClick={updateNode}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</HStack>
|
||||
<NotesTextarea nodeId={props.node.data.id} />
|
||||
<NotesTextarea nodeId={node.data.id} />
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
Content.displayName = 'Content';
|
||||
|
@ -5,7 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { ImageOutput } from 'services/api/types';
|
||||
import { AnyResult } from 'services/events/types';
|
||||
|
@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -28,8 +25,8 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) ||
|
||||
POLYMORPHIC_TYPES.includes(field.type)) &&
|
||||
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
field.type.isPolymorphic) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
},
|
||||
|
@ -3,10 +3,13 @@ import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { Node, useReactFlow } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { buildNodeData } from '../store/util/buildNodeData';
|
||||
import {
|
||||
buildCurrentImageNode,
|
||||
buildInvocationNode,
|
||||
buildNotesNode,
|
||||
} from '../store/util/buildNodeData';
|
||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants';
|
||||
|
||||
import { AnyNodeData, InvocationTemplate } from '../types/invocation';
|
||||
const templatesSelector = createSelector(
|
||||
[(state: RootState) => state.nodes],
|
||||
(nodes) => nodes.nodeTemplates
|
||||
@ -22,7 +25,8 @@ export const useBuildNodeData = () => {
|
||||
const flow = useReactFlow();
|
||||
|
||||
return useCallback(
|
||||
(type: AnyInvocationType | 'current_image' | 'notes') => {
|
||||
// string here is "any invocation type"
|
||||
(type: string | 'current_image' | 'notes'): Node<AnyNodeData> => {
|
||||
let _x = window.innerWidth / 2;
|
||||
let _y = window.innerHeight / 2;
|
||||
|
||||
@ -41,9 +45,19 @@ export const useBuildNodeData = () => {
|
||||
y: _y,
|
||||
});
|
||||
|
||||
const template = nodeTemplates[type];
|
||||
if (type === 'current_image') {
|
||||
return buildCurrentImageNode(position);
|
||||
}
|
||||
|
||||
return buildNodeData(type, position, template);
|
||||
if (type === 'notes') {
|
||||
return buildNotesNode(position);
|
||||
}
|
||||
|
||||
// TODO: Keep track of invocation types so we do not need to cast this
|
||||
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
|
||||
const template = nodeTemplates[type] as InvocationTemplate;
|
||||
|
||||
return buildInvocationNode(position, template);
|
||||
},
|
||||
[nodeTemplates, flow]
|
||||
);
|
||||
|
@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -29,9 +26,8 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
// get the visible fields
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' &&
|
||||
!POLYMORPHIC_TYPES.includes(field.type)) ||
|
||||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
(field.input === 'connection' && !field.type.isPolymorphic) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
|
@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
|
||||
const selectIsConnectionInProgress = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) =>
|
||||
nodes.currentConnectionFieldType !== null &&
|
||||
nodes.connectionStartFieldType !== null &&
|
||||
nodes.connectionStartParams !== null
|
||||
);
|
||||
|
||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useEmbedWorkflow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldData = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldInstance = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldLabel = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { KIND_MAP } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldTemplate = (
|
||||
nodeId: string,
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { KIND_MAP } from '../types/constants';
|
||||
|
||||
export const useFieldTemplateTitle = (
|
||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { KIND_MAP } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldType = (
|
||||
nodeId: string,
|
||||
@ -20,7 +20,8 @@ export const useFieldType = (
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data[KIND_MAP[kind]][fieldName]?.type;
|
||||
const field = node.data[KIND_MAP[kind]][fieldName];
|
||||
return field?.type;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
@ -2,7 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { getNeedsUpdate } from './useNodeVersion';
|
||||
import { getNeedsUpdate } from '../store/util/nodeUpdate';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@ -10,8 +11,11 @@ const selector = createSelector(
|
||||
const nodes = state.nodes.nodes;
|
||||
const templates = state.nodes.nodeTemplates;
|
||||
|
||||
const needsUpdate = nodes.some((node) => {
|
||||
const needsUpdate = nodes.filter(isInvocationNode).some((node) => {
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return false;
|
||||
}
|
||||
return getNeedsUpdate(node, template);
|
||||
});
|
||||
return needsUpdate;
|
||||
|
@ -4,8 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { IMAGE_FIELDS } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useHasImageOutput = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -20,8 +19,8 @@ export const useHasImageOutput = (nodeId: string) => {
|
||||
return some(
|
||||
node.data.outputs,
|
||||
(output) =>
|
||||
IMAGE_FIELDS.includes(output.type) &&
|
||||
// the image primitive node does not actually save the image, do not show the image-saving checkboxes
|
||||
output.type.name === 'ImageField' &&
|
||||
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
|
||||
node.data.type !== 'image'
|
||||
);
|
||||
},
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useIsIntermediate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -4,7 +4,7 @@ import { useCallback } from 'react';
|
||||
import { Connection, Node, useReactFlow } from 'reactflow';
|
||||
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
|
||||
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
|
||||
import { InvocationNodeData } from '../types/types';
|
||||
import { InvocationNodeData } from '../types/invocation';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
|
||||
const targetType = targetNode.data.inputs[targetHandle]?.type;
|
||||
const sourceField = sourceNode.data.outputs[sourceHandle];
|
||||
const targetField = targetNode.data.inputs[targetHandle];
|
||||
|
||||
if (!sourceType || !targetType) {
|
||||
if (!sourceField || !targetField) {
|
||||
// something has gone terribly awry
|
||||
return false;
|
||||
}
|
||||
@ -70,12 +70,13 @@ export const useIsValidConnection = () => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
targetField.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
||||
// Must use the originalType here if it exists
|
||||
if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1,17 +1,15 @@
|
||||
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { zWorkflow } from 'features/nodes/types/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
import { RefObject, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
|
||||
export const useLoadWorkflowFromFile = () => {
|
||||
export const useLoadWorkflowFromFile = (resetRef: RefObject<() => void>) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const logger = useLogger('nodes');
|
||||
const { t } = useTranslation();
|
||||
@ -26,33 +24,10 @@ export const useLoadWorkflowFromFile = () => {
|
||||
|
||||
try {
|
||||
const parsedJSON = JSON.parse(String(rawJSON));
|
||||
const result = zWorkflow.safeParse(parsedJSON);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: t('nodes.workflowValidation'),
|
||||
});
|
||||
|
||||
logger.error({ error: parseify(result.error) }, message);
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
duration: 5000,
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(workflowLoadRequested(result.data));
|
||||
|
||||
reader.abort();
|
||||
} catch {
|
||||
// file reader error
|
||||
dispatch(workflowLoadRequested(parsedJSON));
|
||||
} catch (e) {
|
||||
// There was a problem reading the file
|
||||
logger.error(t('nodes.unableToLoadWorkflow'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
@ -61,12 +36,15 @@ export const useLoadWorkflowFromFile = () => {
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
}
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
// Reset the file picker internal state so that the same file can be loaded again
|
||||
resetRef.current?.();
|
||||
},
|
||||
[dispatch, logger, t]
|
||||
[dispatch, logger, resetRef, t]
|
||||
);
|
||||
|
||||
return loadWorkflowFromFile;
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useNodeLabel = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -0,0 +1,35 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getNeedsUpdate } from '../store/util/nodeUpdate';
|
||||
|
||||
export const useNodeNeedsUpdate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const template = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
return { node, template };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const { node, template } = useAppSelector(selector);
|
||||
|
||||
const needsUpdate = useMemo(
|
||||
() =>
|
||||
isInvocationNode(node) && template
|
||||
? getNeedsUpdate(node, template)
|
||||
: false,
|
||||
[node, template]
|
||||
);
|
||||
|
||||
return needsUpdate;
|
||||
};
|
@ -3,16 +3,14 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { InvocationTemplate } from '../types/invocation';
|
||||
|
||||
export const useNodeTemplateByType = (
|
||||
type: AnyInvocationType | 'current_image' | 'notes'
|
||||
) => {
|
||||
export const useNodeTemplateByType = (type: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
({ nodes }): InvocationTemplate | undefined => {
|
||||
const nodeTemplate = nodes.nodeTemplates[type];
|
||||
return nodeTemplate;
|
||||
},
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -1,119 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { satisfies } from 'compare-versions';
|
||||
import { cloneDeep, defaultsDeep } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { Node } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { nodeReplaced } from '../store/nodesSlice';
|
||||
import { buildNodeData } from '../store/util/buildNodeData';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
NodeData,
|
||||
isInvocationNode,
|
||||
zParsedSemver,
|
||||
} from '../types/types';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const getNeedsUpdate = (
|
||||
node?: Node<NodeData>,
|
||||
template?: InvocationTemplate
|
||||
) => {
|
||||
if (!isInvocationNode(node) || !template) {
|
||||
return false;
|
||||
}
|
||||
return node.data.version !== template.version;
|
||||
};
|
||||
|
||||
export const getMayUpdateNode = (
|
||||
node?: Node<NodeData>,
|
||||
template?: InvocationTemplate
|
||||
) => {
|
||||
const needsUpdate = getNeedsUpdate(node, template);
|
||||
if (
|
||||
!needsUpdate ||
|
||||
!isInvocationNode(node) ||
|
||||
!template ||
|
||||
!node.data.version
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
const templateMajor = zParsedSemver.parse(template.version).major;
|
||||
|
||||
return satisfies(node.data.version, `^${templateMajor}`);
|
||||
};
|
||||
|
||||
export const updateNode = (
|
||||
node?: Node<NodeData>,
|
||||
template?: InvocationTemplate
|
||||
) => {
|
||||
const mayUpdate = getMayUpdateNode(node, template);
|
||||
if (
|
||||
!mayUpdate ||
|
||||
!isInvocationNode(node) ||
|
||||
!template ||
|
||||
!node.data.version
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const defaults = buildNodeData(
|
||||
node.data.type as AnyInvocationType,
|
||||
node.position,
|
||||
template
|
||||
) as Node<InvocationNodeData>;
|
||||
|
||||
const clone = cloneDeep(node);
|
||||
clone.data.version = template.version;
|
||||
defaultsDeep(clone, defaults);
|
||||
return clone;
|
||||
};
|
||||
|
||||
export const useNodeVersion = (nodeId: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toast = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
return { node, nodeTemplate };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const { node, nodeTemplate } = useAppSelector(selector);
|
||||
|
||||
const needsUpdate = useMemo(
|
||||
() => getNeedsUpdate(node, nodeTemplate),
|
||||
[node, nodeTemplate]
|
||||
);
|
||||
|
||||
const mayUpdate = useMemo(
|
||||
() => getMayUpdateNode(node, nodeTemplate),
|
||||
[node, nodeTemplate]
|
||||
);
|
||||
|
||||
const _updateNode = useCallback(() => {
|
||||
const needsUpdate = getNeedsUpdate(node, nodeTemplate);
|
||||
const updatedNode = updateNode(node, nodeTemplate);
|
||||
if (!updatedNode) {
|
||||
if (needsUpdate) {
|
||||
toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) });
|
||||
}
|
||||
return;
|
||||
}
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
}, [dispatch, node, nodeTemplate, t, toast]);
|
||||
|
||||
return { needsUpdate, mayUpdate, updateNode: _updateNode };
|
||||
};
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
|
||||
export const useOutputFieldNames = (nodeId: string) => {
|
||||
|
@ -0,0 +1,23 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FieldType } from '../types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTypeName = (fieldType?: FieldType): string => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const name = useMemo(() => {
|
||||
if (!fieldType) {
|
||||
return '';
|
||||
}
|
||||
const { name } = fieldType;
|
||||
if (fieldType.isCollection) {
|
||||
return t('nodes.collectionFieldType', { name });
|
||||
}
|
||||
if (fieldType.isPolymorphic) {
|
||||
return t('nodes.polymorphicFieldType', { name });
|
||||
}
|
||||
return name;
|
||||
}, [fieldType, t]);
|
||||
|
||||
return name;
|
||||
};
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useUseCache = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useWithWorkflow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { Workflow } from '../types/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>(
|
||||
'nodes/textToImageGraphBuilt'
|
||||
@ -18,7 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
|
||||
nodesGraphBuilt
|
||||
);
|
||||
|
||||
export const workflowLoadRequested = createAction<Workflow>(
|
||||
export const workflowLoadRequested = createAction<unknown>(
|
||||
'nodes/workflowLoadRequested'
|
||||
);
|
||||
|
||||
|
@ -6,7 +6,7 @@ import { NodesState } from './types';
|
||||
export const nodesPersistDenylist: (keyof NodesState)[] = [
|
||||
'nodeTemplates',
|
||||
'connectionStartParams',
|
||||
'currentConnectionFieldType',
|
||||
'connectionStartFieldType',
|
||||
'selectedNodes',
|
||||
'selectedEdges',
|
||||
'isReady',
|
||||
|
@ -20,7 +20,6 @@ import {
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { ImageField } from 'services/api/types';
|
||||
import {
|
||||
appSocketGeneratorProgress,
|
||||
appSocketInvocationComplete,
|
||||
@ -31,60 +30,58 @@ import {
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { DRAG_HANDLE_CLASSNAME } from '../types/constants';
|
||||
import {
|
||||
BoardInputFieldValue,
|
||||
BooleanInputFieldValue,
|
||||
ColorInputFieldValue,
|
||||
ControlNetModelInputFieldValue,
|
||||
CurrentImageNodeData,
|
||||
EnumInputFieldValue,
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
FieldIdentifier,
|
||||
FloatInputFieldValue,
|
||||
ImageInputFieldValue,
|
||||
InputFieldValue,
|
||||
IntegerInputFieldValue,
|
||||
InvocationNodeData,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
SchedulerFieldValue,
|
||||
SDXLRefinerModelFieldValue,
|
||||
StringFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
VAEModelFieldValue,
|
||||
} from '../types/field';
|
||||
import {
|
||||
AnyNodeData,
|
||||
InvocationTemplate,
|
||||
IPAdapterModelInputFieldValue,
|
||||
isInvocationNode,
|
||||
isNotesNode,
|
||||
LoRAModelInputFieldValue,
|
||||
MainModelInputFieldValue,
|
||||
NodeExecutionState,
|
||||
NodeStatus,
|
||||
NotesNodeData,
|
||||
SchedulerInputFieldValue,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
StringInputFieldValue,
|
||||
T2IAdapterModelInputFieldValue,
|
||||
VaeModelInputFieldValue,
|
||||
Workflow,
|
||||
} from '../types/types';
|
||||
zNodeStatus,
|
||||
} from '../types/invocation';
|
||||
import { WorkflowV2 } from '../types/workflow';
|
||||
import { NodesState } from './types';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||
|
||||
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
status: NodeStatus.PENDING,
|
||||
status: zNodeStatus.enum.PENDING,
|
||||
error: null,
|
||||
progress: null,
|
||||
progressImage: null,
|
||||
outputs: [],
|
||||
};
|
||||
|
||||
export const initialWorkflow = {
|
||||
meta: {
|
||||
version: WORKFLOW_FORMAT_VERSION,
|
||||
},
|
||||
const INITIAL_WORKFLOW: WorkflowV2 = {
|
||||
name: '',
|
||||
author: '',
|
||||
description: '',
|
||||
notes: '',
|
||||
tags: '',
|
||||
contact: '',
|
||||
version: '',
|
||||
contact: '',
|
||||
tags: '',
|
||||
notes: '',
|
||||
nodes: [],
|
||||
edges: [],
|
||||
exposedFields: [],
|
||||
meta: { version: '2.0.0' },
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
@ -93,11 +90,10 @@ export const initialNodesState: NodesState = {
|
||||
nodeTemplates: {},
|
||||
isReady: false,
|
||||
connectionStartParams: null,
|
||||
currentConnectionFieldType: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
modifyingEdge: false,
|
||||
addNewNodePosition: null,
|
||||
shouldShowFieldTypeLegend: false,
|
||||
shouldShowMinimapPanel: true,
|
||||
shouldValidateGraph: true,
|
||||
shouldAnimateEdges: true,
|
||||
@ -107,7 +103,7 @@ export const initialNodesState: NodesState = {
|
||||
nodeOpacity: 1,
|
||||
selectedNodes: [],
|
||||
selectedEdges: [],
|
||||
workflow: initialWorkflow,
|
||||
workflow: INITIAL_WORKFLOW,
|
||||
nodeExecutionStates: {},
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
mouseOverField: null,
|
||||
@ -117,13 +113,13 @@ export const initialNodesState: NodesState = {
|
||||
selectionMode: SelectionMode.Partial,
|
||||
};
|
||||
|
||||
type FieldValueAction<T extends InputFieldValue> = PayloadAction<{
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: T['value'];
|
||||
value: T;
|
||||
}>;
|
||||
|
||||
const fieldValueReducer = <T extends InputFieldValue>(
|
||||
const fieldValueReducer = <T extends FieldValue>(
|
||||
state: NodesState,
|
||||
action: FieldValueAction<T>
|
||||
) => {
|
||||
@ -161,12 +157,7 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
state.nodes[nodeIndex] = action.payload.node;
|
||||
},
|
||||
nodeAdded: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
Node<InvocationNodeData | CurrentImageNodeData | NotesNodeData>
|
||||
>
|
||||
) => {
|
||||
nodeAdded: (state, action: PayloadAction<Node<AnyNodeData>>) => {
|
||||
const node = action.payload;
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
@ -203,7 +194,7 @@ const nodesSlice = createSlice({
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
node,
|
||||
@ -212,7 +203,7 @@ const nodesSlice = createSlice({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
@ -224,7 +215,7 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
edgeChangeStarted: (state) => {
|
||||
state.modifyingEdge = true;
|
||||
@ -258,10 +249,10 @@ const nodesSlice = createSlice({
|
||||
handleType === 'source'
|
||||
? node.data.outputs[handleId]
|
||||
: node.data.inputs[handleId];
|
||||
state.currentConnectionFieldType = field?.type ?? null;
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
const fieldType = state.currentConnectionFieldType;
|
||||
const fieldType = state.connectionStartFieldType;
|
||||
if (!fieldType) {
|
||||
return;
|
||||
}
|
||||
@ -286,7 +277,7 @@ const nodesSlice = createSlice({
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
mouseOverNode,
|
||||
@ -295,7 +286,7 @@ const nodesSlice = createSlice({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
@ -306,14 +297,14 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
}
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
} else {
|
||||
state.addNewNodePosition = action.payload.cursorPosition;
|
||||
state.isAddNodePopoverOpen = true;
|
||||
}
|
||||
} else {
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
}
|
||||
state.modifyingEdge = false;
|
||||
},
|
||||
@ -529,12 +520,7 @@ const nodesSlice = createSlice({
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
}
|
||||
},
|
||||
nodesDeleted: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
Node<InvocationNodeData | NotesNodeData | CurrentImageNodeData>[]
|
||||
>
|
||||
) => {
|
||||
nodesDeleted: (state, action: PayloadAction<Node<AnyNodeData>[]>) => {
|
||||
action.payload.forEach((node) => {
|
||||
state.workflow.exposedFields = state.workflow.exposedFields.filter(
|
||||
(f) => f.nodeId !== node.id
|
||||
@ -588,132 +574,94 @@ const nodesSlice = createSlice({
|
||||
},
|
||||
fieldStringValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<StringInputFieldValue>
|
||||
action: FieldValueAction<StringFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldNumberValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<IntegerInputFieldValue | FloatInputFieldValue>
|
||||
action: FieldValueAction<IntegerFieldValue | FloatFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldBooleanValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<BooleanInputFieldValue>
|
||||
action: FieldValueAction<BooleanFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldBoardValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<BoardInputFieldValue>
|
||||
action: FieldValueAction<BoardFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldImageValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<ImageInputFieldValue>
|
||||
action: FieldValueAction<ImageFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldColorValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<ColorInputFieldValue>
|
||||
action: FieldValueAction<ColorFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldMainModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<MainModelInputFieldValue>
|
||||
action: FieldValueAction<MainModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldRefinerModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SDXLRefinerModelInputFieldValue>
|
||||
action: FieldValueAction<SDXLRefinerModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldVaeModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<VaeModelInputFieldValue>
|
||||
action: FieldValueAction<VAEModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldLoRAModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<LoRAModelInputFieldValue>
|
||||
action: FieldValueAction<LoRAModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldControlNetModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<ControlNetModelInputFieldValue>
|
||||
action: FieldValueAction<ControlNetModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldIPAdapterModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<IPAdapterModelInputFieldValue>
|
||||
action: FieldValueAction<IPAdapterModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldT2IAdapterModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<T2IAdapterModelInputFieldValue>
|
||||
action: FieldValueAction<T2IAdapterModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldEnumModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<EnumInputFieldValue>
|
||||
action: FieldValueAction<EnumFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldSchedulerValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SchedulerInputFieldValue>
|
||||
action: FieldValueAction<SchedulerFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
imageCollectionFieldValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: ImageField[];
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
if (nodeIndex === -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const input = node.data?.inputs[fieldName];
|
||||
if (!input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentValue = cloneDeep(input.value);
|
||||
|
||||
if (!currentValue) {
|
||||
input.value = value;
|
||||
return;
|
||||
}
|
||||
|
||||
input.value = uniqBy(
|
||||
(currentValue as ImageField[]).concat(value),
|
||||
'image_name'
|
||||
);
|
||||
},
|
||||
notesNodeValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; value: string }>
|
||||
@ -726,12 +674,6 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = value;
|
||||
},
|
||||
shouldShowFieldTypeLegendChanged: (
|
||||
state,
|
||||
action: PayloadAction<boolean>
|
||||
) => {
|
||||
state.shouldShowFieldTypeLegend = action.payload;
|
||||
},
|
||||
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowMinimapPanel = action.payload;
|
||||
},
|
||||
@ -745,7 +687,7 @@ const nodesSlice = createSlice({
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
state.workflow = cloneDeep(initialWorkflow);
|
||||
state.workflow = cloneDeep(INITIAL_WORKFLOW);
|
||||
},
|
||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldValidateGraph = action.payload;
|
||||
@ -783,7 +725,7 @@ const nodesSlice = createSlice({
|
||||
workflowContactChanged: (state, action: PayloadAction<string>) => {
|
||||
state.workflow.contact = action.payload;
|
||||
},
|
||||
workflowLoaded: (state, action: PayloadAction<Workflow>) => {
|
||||
workflowLoaded: (state, action: PayloadAction<WorkflowV2>) => {
|
||||
const { nodes, edges, ...workflow } = action.payload;
|
||||
state.workflow = workflow;
|
||||
|
||||
@ -810,7 +752,7 @@ const nodesSlice = createSlice({
|
||||
}, {});
|
||||
},
|
||||
workflowReset: (state) => {
|
||||
state.workflow = cloneDeep(initialWorkflow);
|
||||
state.workflow = cloneDeep(INITIAL_WORKFLOW);
|
||||
},
|
||||
viewportChanged: (state, action: PayloadAction<Viewport>) => {
|
||||
state.viewport = action.payload;
|
||||
@ -942,7 +884,7 @@ const nodesSlice = createSlice({
|
||||
|
||||
//Make sure these get reset if we close the popover and haven't selected a node
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
addNodePopoverToggled: (state) => {
|
||||
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
|
||||
@ -961,14 +903,14 @@ const nodesSlice = createSlice({
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = NodeStatus.IN_PROGRESS;
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
}
|
||||
});
|
||||
builder.addCase(appSocketInvocationComplete, (state, action) => {
|
||||
const { source_node_id, result } = action.payload.data;
|
||||
const nes = state.nodeExecutionStates[source_node_id];
|
||||
if (nes) {
|
||||
nes.status = NodeStatus.COMPLETED;
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
@ -979,7 +921,7 @@ const nodesSlice = createSlice({
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = NodeStatus.FAILED;
|
||||
node.status = zNodeStatus.enum.FAILED;
|
||||
node.error = action.payload.data.error;
|
||||
node.progress = null;
|
||||
node.progressImage = null;
|
||||
@ -990,7 +932,7 @@ const nodesSlice = createSlice({
|
||||
action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = NodeStatus.IN_PROGRESS;
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
node.progress = (step + 1) / total_steps;
|
||||
node.progressImage = progress_image ?? null;
|
||||
}
|
||||
@ -998,7 +940,7 @@ const nodesSlice = createSlice({
|
||||
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach(state.nodeExecutionStates, (nes) => {
|
||||
nes.status = NodeStatus.PENDING;
|
||||
nes.status = zNodeStatus.enum.PENDING;
|
||||
nes.error = null;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
@ -1037,7 +979,6 @@ export const {
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
imageCollectionFieldValueChanged,
|
||||
mouseOverFieldChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeAdded,
|
||||
@ -1063,7 +1004,6 @@ export const {
|
||||
selectionPasted,
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
|
@ -6,25 +6,23 @@ import {
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { FieldIdentifier, FieldType } from '../types/field';
|
||||
import {
|
||||
FieldIdentifier,
|
||||
FieldType,
|
||||
AnyNodeData,
|
||||
InvocationEdgeExtra,
|
||||
InvocationTemplate,
|
||||
NodeData,
|
||||
NodeExecutionState,
|
||||
Workflow,
|
||||
} from '../types/types';
|
||||
} from '../types/invocation';
|
||||
import { WorkflowV2 } from '../types/workflow';
|
||||
|
||||
export type NodesState = {
|
||||
nodes: Node<NodeData>[];
|
||||
nodes: Node<AnyNodeData>[];
|
||||
edges: Edge<InvocationEdgeExtra>[];
|
||||
nodeTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
currentConnectionFieldType: FieldType | null;
|
||||
connectionStartFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
modifyingEdge: boolean;
|
||||
shouldShowFieldTypeLegend: boolean;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
shouldValidateGraph: boolean;
|
||||
shouldAnimateEdges: boolean;
|
||||
@ -33,13 +31,13 @@ export type NodesState = {
|
||||
shouldColorEdges: boolean;
|
||||
selectedNodes: string[];
|
||||
selectedEdges: string[];
|
||||
workflow: Omit<Workflow, 'nodes' | 'edges'>;
|
||||
workflow: Omit<WorkflowV2, 'nodes' | 'edges'>;
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
viewport: Viewport;
|
||||
isReady: boolean;
|
||||
mouseOverField: FieldIdentifier | null;
|
||||
mouseOverNode: string | null;
|
||||
nodesToCopy: Node<NodeData>[];
|
||||
nodesToCopy: Node<AnyNodeData>[];
|
||||
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
||||
isAddNodePopoverOpen: boolean;
|
||||
addNewNodePosition: XYPosition | null;
|
||||
|
@ -1,50 +1,25 @@
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import {
|
||||
FieldInputInstance,
|
||||
FieldOutputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
CurrentImageNodeData,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
NotesNodeData,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders';
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance';
|
||||
import { reduce } from 'lodash-es';
|
||||
import { Node, XYPosition } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||
};
|
||||
export const buildNodeData = (
|
||||
type: AnyInvocationType | 'current_image' | 'notes',
|
||||
position: XYPosition,
|
||||
template?: InvocationTemplate
|
||||
):
|
||||
| Node<CurrentImageNodeData>
|
||||
| Node<NotesNodeData>
|
||||
| Node<InvocationNodeData>
|
||||
| undefined => {
|
||||
|
||||
export const buildNotesNode = (position: XYPosition): Node<NotesNodeData> => {
|
||||
const nodeId = uuidv4();
|
||||
|
||||
if (type === 'current_image') {
|
||||
const node: Node<CurrentImageNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
isOpen: true,
|
||||
label: 'Current Image',
|
||||
},
|
||||
};
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
if (type === 'notes') {
|
||||
const node: Node<NotesNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
@ -58,21 +33,41 @@ export const buildNodeData = (
|
||||
type: 'notes',
|
||||
},
|
||||
};
|
||||
|
||||
return node;
|
||||
}
|
||||
};
|
||||
|
||||
if (template === undefined) {
|
||||
console.error(`Unable to find template ${type}.`);
|
||||
return;
|
||||
}
|
||||
export const buildCurrentImageNode = (
|
||||
position: XYPosition
|
||||
): Node<CurrentImageNodeData> => {
|
||||
const nodeId = uuidv4();
|
||||
const node: Node<CurrentImageNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
isOpen: true,
|
||||
label: 'Current Image',
|
||||
},
|
||||
};
|
||||
return node;
|
||||
};
|
||||
|
||||
export const buildInvocationNode = (
|
||||
position: XYPosition,
|
||||
template: InvocationTemplate
|
||||
): Node<InvocationNodeData> => {
|
||||
const nodeId = uuidv4();
|
||||
const { type } = template;
|
||||
|
||||
const inputs = reduce(
|
||||
template.inputs,
|
||||
(inputsAccumulator, inputTemplate, inputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const inputFieldValue: InputFieldValue = buildInputFieldValue(
|
||||
const inputFieldValue: FieldInputInstance = buildFieldInputInstance(
|
||||
fieldId,
|
||||
inputTemplate
|
||||
);
|
||||
@ -81,7 +76,7 @@ export const buildNodeData = (
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldValue>
|
||||
{} as Record<string, FieldInputInstance>
|
||||
);
|
||||
|
||||
const outputs = reduce(
|
||||
@ -89,7 +84,7 @@ export const buildNodeData = (
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const outputFieldValue: OutputFieldValue = {
|
||||
const outputFieldValue: FieldOutputInstance = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
@ -100,10 +95,10 @@ export const buildNodeData = (
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldValue>
|
||||
{} as Record<string, FieldOutputInstance>
|
||||
);
|
||||
|
||||
const invocation: Node<InvocationNodeData> = {
|
||||
const node: Node<InvocationNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'invocation',
|
||||
@ -117,11 +112,11 @@ export const buildNodeData = (
|
||||
isOpen: true,
|
||||
embedWorkflow: false,
|
||||
isIntermediate: type === 'save_image' ? false : true,
|
||||
useCache: template.useCache,
|
||||
inputs,
|
||||
outputs,
|
||||
useCache: template.useCache,
|
||||
},
|
||||
};
|
||||
|
||||
return invocation;
|
||||
return node;
|
||||
};
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user