feat(nodes): polymorphic fields (#4423)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

### Polymorphic Fields

Initial support for polymorphic field types. Polymorphic types are a
single of or list of a specific type. For example, `Union[str,
list[str]]`.

Polymorphics do not yet have support for direct input in the UI (will
come in the future). They will be forcibly set as Connection-only
fields, in which case users will not be able to provide direct input to
the field.

If a polymorphic should present as a singleton type - which would allow
direct input - the node must provide an explicit type hint.

For example, `DenoiseLatents`' `CFG Scale` is polymorphic, but in the
node editor, we want to present this as a number input. In the node
definition, the field is given `ui_type=UIType.Float`, which tells the
UI to treat this as a `float` field.

The connection validation logic will prevent connecting a collection to
`CFG Scale` in this situation, because it is typed as `float`. The
workaround is to disable validation from the settings to make this
specific connection. A future improvement will resolve this.

### Collection Fields

This also introduces better support for collection field types. Like
polymorphics, collection types are parsed automatically by the client
and do not need any specific type hints.

Also like polymorphics, there is no support yet for direct input of
collection types in the UI.

### Other Changes

- Disabling validation in workflow editor now displays the visual hints
for valid connections, but lets you connect to anything.
- Added `ui_order: int` to `InputField` and `OutputField`. The UI will
use this, if present, to order fields in a node UI. See usage in
`DenoiseLatents` for an example.
- Updated the field colors - duplicate colors have just been lightened a
bit. It's not perfect but it was a quick fix.
- Field handles for collections are the same color as their single
counterparts, but have a dark dot in the center of them.
- Field handles for polymorphics are a rounded square with dot in the
middle.
- Removed all fields that just render `null` from `InputFieldRenderer`,
replaced with a single fallback
- Removed logic in `zValidatedWorkflow`, which checked for existence of
node templates for each node in a workflow. This logic introduced a
circular dependency, due to importing the global redux `store` in order
to get the node templates within a zod schema. It's actually fine to
just leave this out entirely; The case of a missing node template is
handled by the UI. Fixing it otherwise would introduce a substantial
headache.
- Fixed the `ControlNetInvocation.control_model` field default, which
was a string when it shouldn't have one.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Closes #4266 

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

Add this polymorphic float node to the end of your
`invokeai/app/invocations/primitives.py`:
```py
@invocation("float_poly", title="Float Poly Test", tags=["primitives", "float"], category="primitives")
class FloatPolyInvocation(BaseInvocation):
    """A float polymorphic primitive value"""

    value: Union[float, list[float]] = InputField(default_factory=list, description="The float value")

    def invoke(self, context: InvocationContext) -> FloatOutput:
        return FloatOutput(value=self.value[0] if isinstance(self.value, list) else self.value)
``

Head over to nodes and try to connecting up some collection and polymorphic inputs.
This commit is contained in:
blessedcoolant 2023-09-05 09:39:04 +12:00 committed by GitHub
commit 1f6c868212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1494 additions and 724 deletions

View File

@ -105,24 +105,39 @@ class UIType(str, Enum):
""" """
# region Primitives # region Primitives
Integer = "integer"
Float = "float"
Boolean = "boolean" Boolean = "boolean"
String = "string" Color = "ColorField"
Array = "array"
Image = "ImageField"
Latents = "LatentsField"
Conditioning = "ConditioningField" Conditioning = "ConditioningField"
Control = "ControlField" Control = "ControlField"
Color = "ColorField" Float = "float"
ImageCollection = "ImageCollection" Image = "ImageField"
ConditioningCollection = "ConditioningCollection" Integer = "integer"
ColorCollection = "ColorCollection" Latents = "LatentsField"
LatentsCollection = "LatentsCollection" String = "string"
IntegerCollection = "IntegerCollection" # endregion
FloatCollection = "FloatCollection"
StringCollection = "StringCollection" # region Collection Primitives
BooleanCollection = "BooleanCollection" BooleanCollection = "BooleanCollection"
ColorCollection = "ColorCollection"
ConditioningCollection = "ConditioningCollection"
ControlCollection = "ControlCollection"
FloatCollection = "FloatCollection"
ImageCollection = "ImageCollection"
IntegerCollection = "IntegerCollection"
LatentsCollection = "LatentsCollection"
StringCollection = "StringCollection"
# endregion
# region Polymorphic Primitives
BooleanPolymorphic = "BooleanPolymorphic"
ColorPolymorphic = "ColorPolymorphic"
ConditioningPolymorphic = "ConditioningPolymorphic"
ControlPolymorphic = "ControlPolymorphic"
FloatPolymorphic = "FloatPolymorphic"
ImagePolymorphic = "ImagePolymorphic"
IntegerPolymorphic = "IntegerPolymorphic"
LatentsPolymorphic = "LatentsPolymorphic"
StringPolymorphic = "StringPolymorphic"
# endregion # endregion
# region Models # region Models
@ -176,6 +191,7 @@ class _InputField(BaseModel):
ui_type: Optional[UIType] ui_type: Optional[UIType]
ui_component: Optional[UIComponent] ui_component: Optional[UIComponent]
ui_order: Optional[int] ui_order: Optional[int]
item_default: Optional[Any]
class _OutputField(BaseModel): class _OutputField(BaseModel):
@ -223,6 +239,7 @@ def InputField(
ui_component: Optional[UIComponent] = None, ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False, ui_hidden: bool = False,
ui_order: Optional[int] = None, ui_order: Optional[int] = None,
item_default: Optional[Any] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
""" """
@ -249,6 +266,11 @@ def InputField(
For this case, you could provide `UIComponent.Textarea`. For this case, you could provide `UIComponent.Textarea`.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
Ignored for non-collection fields..
""" """
return Field( return Field(
*args, *args,
@ -282,6 +304,7 @@ def InputField(
ui_component=ui_component, ui_component=ui_component,
ui_hidden=ui_hidden, ui_hidden=ui_hidden,
ui_order=ui_order, ui_order=ui_order,
item_default=item_default,
**kwargs, **kwargs,
) )
@ -332,6 +355,8 @@ def OutputField(
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
""" """
return Field( return Field(
*args, *args,

View File

@ -100,9 +100,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField( control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
)
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
) )

View File

@ -208,12 +208,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2) unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
control: Union[ControlField, list[ControlField]] = InputField( control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5 default=None,
description=FieldDescriptions.control,
input=Input.Connection,
ui_order=5,
) )
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField( denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
description=FieldDescriptions.mask,
) )
@validator("cfg_scale") @validator("cfg_scale")
@ -317,7 +319,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext, context: InvocationContext,
# really only need model for dtype and device # really only need model for dtype and device
model: StableDiffusionGeneratorPipeline, model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField], control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack, exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,

View File

@ -14,7 +14,6 @@ from .baseinvocation import (
InvocationContext, InvocationContext,
OutputField, OutputField,
UIComponent, UIComponent,
UIType,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -40,7 +39,9 @@ class BooleanOutput(BaseInvocationOutput):
class BooleanCollectionOutput(BaseInvocationOutput): class BooleanCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of booleans""" """Base class for nodes that output a collection of booleans"""
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection) collection: list[bool] = OutputField(
description="The output boolean collection",
)
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives") @invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
@ -62,9 +63,7 @@ class BooleanInvocation(BaseInvocation):
class BooleanCollectionInvocation(BaseInvocation): class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values""" """A collection of boolean primitive values"""
collection: list[bool] = InputField( collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
)
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection) return BooleanCollectionOutput(collection=self.collection)
@ -86,7 +85,9 @@ class IntegerOutput(BaseInvocationOutput):
class IntegerCollectionOutput(BaseInvocationOutput): class IntegerCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of integers""" """Base class for nodes that output a collection of integers"""
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection) collection: list[int] = OutputField(
description="The int collection",
)
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives") @invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
@ -108,9 +109,7 @@ class IntegerInvocation(BaseInvocation):
class IntegerCollectionInvocation(BaseInvocation): class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values""" """A collection of integer primitive values"""
collection: list[int] = InputField( collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
)
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection) return IntegerCollectionOutput(collection=self.collection)
@ -132,7 +131,9 @@ class FloatOutput(BaseInvocationOutput):
class FloatCollectionOutput(BaseInvocationOutput): class FloatCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats""" """Base class for nodes that output a collection of floats"""
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection) collection: list[float] = OutputField(
description="The float collection",
)
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives") @invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
@ -154,9 +155,7 @@ class FloatInvocation(BaseInvocation):
class FloatCollectionInvocation(BaseInvocation): class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values""" """A collection of float primitive values"""
collection: list[float] = InputField( collection: list[float] = InputField(default_factory=list, description="The collection of float values")
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
)
def invoke(self, context: InvocationContext) -> FloatCollectionOutput: def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection) return FloatCollectionOutput(collection=self.collection)
@ -178,7 +177,9 @@ class StringOutput(BaseInvocationOutput):
class StringCollectionOutput(BaseInvocationOutput): class StringCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings""" """Base class for nodes that output a collection of strings"""
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection) collection: list[str] = OutputField(
description="The output strings",
)
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives") @invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
@ -200,9 +201,7 @@ class StringInvocation(BaseInvocation):
class StringCollectionInvocation(BaseInvocation): class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values""" """A collection of string primitive values"""
collection: list[str] = InputField( collection: list[str] = InputField(default_factory=list, description="The collection of string values")
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
)
def invoke(self, context: InvocationContext) -> StringCollectionOutput: def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection) return StringCollectionOutput(collection=self.collection)
@ -232,7 +231,9 @@ class ImageOutput(BaseInvocationOutput):
class ImageCollectionOutput(BaseInvocationOutput): class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images""" """Base class for nodes that output a collection of images"""
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection) collection: list[ImageField] = OutputField(
description="The output images",
)
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives") @invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
@ -260,9 +261,7 @@ class ImageInvocation(BaseInvocation):
class ImageCollectionInvocation(BaseInvocation): class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values""" """A collection of image primitive values"""
collection: list[ImageField] = InputField( collection: list[ImageField] = InputField(description="The collection of image values")
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
)
def invoke(self, context: InvocationContext) -> ImageCollectionOutput: def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection) return ImageCollectionOutput(collection=self.collection)
@ -316,7 +315,6 @@ class LatentsCollectionOutput(BaseInvocationOutput):
collection: list[LatentsField] = OutputField( collection: list[LatentsField] = OutputField(
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
ui_type=UIType.LatentsCollection,
) )
@ -342,7 +340,7 @@ class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values""" """A collection of latents tensor primitive values"""
collection: list[LatentsField] = InputField( collection: list[LatentsField] = InputField(
description="The collection of latents tensors", ui_type=UIType.LatentsCollection description="The collection of latents tensors",
) )
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
@ -385,7 +383,9 @@ class ColorOutput(BaseInvocationOutput):
class ColorCollectionOutput(BaseInvocationOutput): class ColorCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of colors""" """Base class for nodes that output a collection of colors"""
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection) collection: list[ColorField] = OutputField(
description="The output colors",
)
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives") @invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
@ -422,7 +422,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
collection: list[ConditioningField] = OutputField( collection: list[ConditioningField] = OutputField(
description="The output conditioning tensors", description="The output conditioning tensors",
ui_type=UIType.ConditioningCollection,
) )
@ -453,7 +452,6 @@ class ConditioningCollectionInvocation(BaseInvocation):
collection: list[ConditioningField] = InputField( collection: list[ConditioningField] = InputField(
default_factory=list, default_factory=list,
description="The collection of conditioning tensors", description="The collection of conditioning tensors",
ui_type=UIType.ConditioningCollection,
) )
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:

View File

@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type): if to_type in get_args(from_type):
return True return True
# allow int -> float, pydantic will cast for us
if from_type is int and to_type is float:
return True
# if not issubclass(from_type, to_type): # if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type): if not is_union_subtype(from_type, to_type):
return False return False

View File

@ -63,7 +63,11 @@ const selector = createSelector(
return; return;
} }
if (fieldTemplate.required && !field.value && !hasConnection) { if (
fieldTemplate.required &&
field.value === undefined &&
!hasConnection
) {
reasons.push( reasons.push(
`${node.data.label || nodeTemplate.title} -> ${ `${node.data.label || nodeTemplate.title} -> ${
field.label || fieldTemplate.title field.label || fieldTemplate.title

View File

@ -1,2 +1,2 @@
export const colorTokenToCssVar = (colorToken: string) => export const colorTokenToCssVar = (colorToken: string) =>
`var(--invokeai-colors-${colorToken.split('.').join('-')}`; `var(--invokeai-colors-${colorToken.split('.').join('-')})`;

View File

@ -1,8 +1,11 @@
import { Tooltip } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { import {
COLLECTION_TYPES,
FIELDS, FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY, HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants'; } from 'features/nodes/types/constants';
import { import {
InputFieldTemplate, InputFieldTemplate,
@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
borderWidth: 0, borderWidth: 0,
zIndex: 1, zIndex: 1,
}; };
``;
export const inputHandleStyles: CSSProperties = { export const inputHandleStyles: CSSProperties = {
left: '-1rem', left: '-1rem',
@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
connectionError, connectionError,
} = props; } = props;
const { name, type } = fieldTemplate; const { name, type } = fieldTemplate;
const { color, title } = FIELDS[type]; const { color: typeColor, title } = FIELDS[type];
const styles: CSSProperties = useMemo(() => { const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const s: CSSProperties = { const s: CSSProperties = {
backgroundColor: colorTokenToCssVar(color), backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
: color,
position: 'absolute', position: 'absolute',
width: '1rem', width: '1rem',
height: '1rem', height: '1rem',
borderWidth: 0, borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
zIndex: 1, zIndex: 1,
}; };
@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
return s; return s;
}, [ }, [
color,
connectionError, connectionError,
handleType, handleType,
isConnectionInProgress, isConnectionInProgress,
isConnectionStartField, isConnectionStartField,
type,
typeColor,
]); ]);
const tooltip = useMemo(() => { const tooltip = useMemo(() => {

View File

@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
sx={{ sx={{
display: 'flex', display: 'flex',
alignItems: 'center', alignItems: 'center',
h: 'full',
mb: 0, mb: 0,
px: 1, px: 1,
gap: 2, gap: 2,

View File

@ -3,18 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { memo } from 'react'; import { memo } from 'react';
import BooleanInputField from './inputs/BooleanInputField'; import BooleanInputField from './inputs/BooleanInputField';
import ClipInputField from './inputs/ClipInputField';
import CollectionInputField from './inputs/CollectionInputField';
import CollectionItemInputField from './inputs/CollectionItemInputField';
import ColorInputField from './inputs/ColorInputField'; import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField'; import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
import EnumInputField from './inputs/EnumInputField'; import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField'; import ImageInputField from './inputs/ImageInputField';
import LatentsInputField from './inputs/LatentsInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField'; import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField'; import MainModelInputField from './inputs/MainModelInputField';
import NumberInputField from './inputs/NumberInputField'; import NumberInputField from './inputs/NumberInputField';
@ -22,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField'; import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField'; import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField'; import StringInputField from './inputs/StringInputField';
import UnetInputField from './inputs/UnetInputField';
import VaeInputField from './inputs/VaeInputField';
import VaeModelInputField from './inputs/VaeModelInputField'; import VaeModelInputField from './inputs/VaeModelInputField';
type InputFieldProps = { type InputFieldProps = {
@ -31,7 +21,6 @@ type InputFieldProps = {
fieldName: string; fieldName: string;
}; };
// build an individual input element based on the schema
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const field = useFieldData(nodeId, fieldName); const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
@ -93,88 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (
field?.type === 'LatentsField' &&
fieldTemplate?.type === 'LatentsField'
) {
return (
<LatentsInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'DenoiseMaskField' &&
fieldTemplate?.type === 'DenoiseMaskField'
) {
return (
<DenoiseMaskInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
) {
return (
<ConditioningInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
return (
<UnetInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
return (
<ClipInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
return (
<VaeInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlField' &&
fieldTemplate?.type === 'ControlField'
) {
return (
<ControlInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if ( if (
field?.type === 'MainModelField' && field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField' fieldTemplate?.type === 'MainModelField'
@ -240,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
return (
<CollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'CollectionItem' &&
fieldTemplate?.type === 'CollectionItem'
) {
return (
<CollectionItemInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') { if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return ( return (
<ColorInputField <ColorInputField
@ -273,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
return (
<ImageCollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if ( if (
field?.type === 'SDXLMainModelField' && field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField' fieldTemplate?.type === 'SDXLMainModelField'
@ -309,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (field && fieldTemplate) {
// Fallback for when there is no component for the type
return null;
}
return ( return (
<Box p={1}> <Box p={1}>
<Text <Text

View File

@ -1,12 +1,17 @@
import { import {
ControlInputFieldTemplate, ControlInputFieldTemplate,
ControlInputFieldValue, ControlInputFieldValue,
ControlPolymorphicInputFieldTemplate,
ControlPolymorphicInputFieldValue,
FieldComponentProps, FieldComponentProps,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
const ControlInputFieldComponent = ( const ControlInputFieldComponent = (
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate> _props: FieldComponentProps<
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
>
) => { ) => {
return null; return null;
}; };

View File

@ -9,9 +9,9 @@ import {
} from 'features/dnd/types'; } from 'features/dnd/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
FieldComponentProps,
ImageInputFieldTemplate, ImageInputFieldTemplate,
ImageInputFieldValue, ImageInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { FaUndo } from 'react-icons/fa'; import { FaUndo } from 'react-icons/fa';

View File

@ -2,11 +2,16 @@ import {
LatentsInputFieldTemplate, LatentsInputFieldTemplate,
LatentsInputFieldValue, LatentsInputFieldValue,
FieldComponentProps, FieldComponentProps,
LatentsPolymorphicInputFieldValue,
LatentsPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
const LatentsInputFieldComponent = ( const LatentsInputFieldComponent = (
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate> _props: FieldComponentProps<
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
>
) => { ) => {
return null; return null;
}; };

View File

@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput'; import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
FieldComponentProps,
FloatInputFieldTemplate, FloatInputFieldTemplate,
FloatInputFieldValue, FloatInputFieldValue,
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
IntegerInputFieldValue, IntegerInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useEffect, useMemo, useState } from 'react'; import { memo, useEffect, useMemo, useState } from 'react';

View File

@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }
return Boolean(node?.data.inputs[fieldName]?.value); return node?.data.inputs[fieldName]?.value !== undefined;
}, },
defaultSelectorOptions defaultSelectorOptions
), ),

View File

@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { Connection, Edge, Node, useReactFlow } from 'reactflow'; import { Connection, Edge, Node, useReactFlow } from 'reactflow';
import { COLLECTION_TYPES } from '../types/constants'; import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from '../types/constants';
import { InvocationNodeData } from '../types/types'; import { InvocationNodeData } from '../types/types';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const useIsValidConnection = () => { export const useIsValidConnection = () => {
const flow = useReactFlow(); const flow = useReactFlow();
const shouldValidateGraph = useAppSelector( const shouldValidateGraph = useAppSelector(
@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
return false; return false;
} }
if (
edges
.filter((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
})
.find((edge) => {
edge.source === source && edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return false;
}
// Connection is invalid if target already has a connection // Connection is invalid if target already has a connection
if ( if (
edges.find((edge) => { edges.find((edge) => {
@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
return false; return false;
} }
// Connection types must be the same for a connection /**
if ( * Connection types must be the same for a connection, with exceptions:
sourceType !== targetType && * - CollectionItem can connect to any non-Collection
sourceType !== 'CollectionItem' && * - Non-Collections can connect to CollectionItem
targetType !== 'CollectionItem' * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
) { * - Generic Collection can connect to any other Collection or Polymorphic
if ( * - Any Collection can connect to a Generic Collection
!( */
COLLECTION_TYPES.includes(targetType) &&
COLLECTION_TYPES.includes(sourceType) if (sourceType !== targetType) {
) const isCollectionItemToNonCollection =
) { sourceType === 'CollectionItem' &&
return false; !COLLECTION_TYPES.includes(targetType);
}
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat
);
} }
// Graphs much be acyclic (no loops!) // Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges); return getIsGraphAcyclic(source, target, nodes, edges);
}, },

View File

@ -1,10 +1,20 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection'; import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
import { COLLECTION_TYPES } from 'features/nodes/types/constants'; import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types'; import { FieldType } from 'features/nodes/types/types';
import { HandleType } from 'reactflow'; import { HandleType } from 'reactflow';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const makeConnectionErrorSelector = ( export const makeConnectionErrorSelector = (
nodeId: string, nodeId: string,
fieldName: string, fieldName: string,
@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
const { currentConnectionFieldType, connectionStartParams, nodes, edges } = const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes; state.nodes;
if (!state.nodes.shouldValidateGraph) {
// manual override!
return null;
}
if (!connectionStartParams || !currentConnectionFieldType) { if (!connectionStartParams || !currentConnectionFieldType) {
return 'No connection in progress'; return 'No connection in progress';
} }
@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
return 'No connection data'; return 'No connection data';
} }
const targetFieldType = const targetType =
handleType === 'target' ? fieldType : currentConnectionFieldType; handleType === 'target' ? fieldType : currentConnectionFieldType;
const sourceFieldType = const sourceType =
handleType === 'source' ? fieldType : currentConnectionFieldType; handleType === 'source' ? fieldType : currentConnectionFieldType;
if (nodeId === connectionNodeId) { if (nodeId === connectionNodeId) {
@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
} }
if ( if (
fieldType !== currentConnectionFieldType &&
fieldType !== 'CollectionItem' &&
currentConnectionFieldType !== 'CollectionItem'
) {
if (
!(
COLLECTION_TYPES.includes(targetFieldType) &&
COLLECTION_TYPES.includes(sourceFieldType)
)
) {
// except for collection items, field types must match
return 'Field types must match';
}
}
if (
handleType === 'target' &&
edges.find((edge) => { edges.find((edge) => {
return edge.target === nodeId && edge.targetHandle === fieldName; return edge.target === nodeId && edge.targetHandle === fieldName;
}) && }) &&
// except CollectionItem inputs can have multiples // except CollectionItem inputs can have multiples
targetFieldType !== 'CollectionItem' targetType !== 'CollectionItem'
) { ) {
return 'Inputs may only have one connection'; return 'Input may only have one connection';
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
* - Generic Collection can connect to any other Collection or Polymorphic
* - Any Collection can connect to a Generic Collection
*/
if (sourceType !== targetType) {
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(targetType);
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
if (
!(
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat
)
) {
return 'Field types must match';
}
} }
const isGraphAcyclic = getIsGraphAcyclic( const isGraphAcyclic = getIsGraphAcyclic(

View File

@ -17,176 +17,297 @@ export const KIND_MAP = {
export const COLLECTION_TYPES: FieldType[] = [ export const COLLECTION_TYPES: FieldType[] = [
'Collection', 'Collection',
'IntegerCollection', 'IntegerCollection',
'BooleanCollection',
'FloatCollection', 'FloatCollection',
'StringCollection', 'StringCollection',
'BooleanCollection',
'ImageCollection', 'ImageCollection',
'LatentsCollection',
'ConditioningCollection',
'ControlCollection',
'ColorCollection',
]; ];
export const POLYMORPHIC_TYPES = [
'IntegerPolymorphic',
'BooleanPolymorphic',
'FloatPolymorphic',
'StringPolymorphic',
'ImagePolymorphic',
'LatentsPolymorphic',
'ConditioningPolymorphic',
'ControlPolymorphic',
'ColorPolymorphic',
];
export const MODEL_TYPES = [
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'ONNXModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'VaeField',
'ClipField',
];
export const COLLECTION_MAP = {
integer: 'IntegerCollection',
boolean: 'BooleanCollection',
number: 'FloatCollection',
float: 'FloatCollection',
string: 'StringCollection',
ImageField: 'ImageCollection',
LatentsField: 'LatentsCollection',
ConditioningField: 'ConditioningCollection',
ControlField: 'ControlCollection',
ColorField: 'ColorCollection',
};
export const isCollectionItemType = (
itemType: string | undefined
): itemType is keyof typeof COLLECTION_MAP =>
Boolean(itemType && itemType in COLLECTION_MAP);
export const SINGLE_TO_POLYMORPHIC_MAP = {
integer: 'IntegerPolymorphic',
boolean: 'BooleanPolymorphic',
number: 'FloatPolymorphic',
float: 'FloatPolymorphic',
string: 'StringPolymorphic',
ImageField: 'ImagePolymorphic',
LatentsField: 'LatentsPolymorphic',
ConditioningField: 'ConditioningPolymorphic',
ControlField: 'ControlPolymorphic',
ColorField: 'ColorPolymorphic',
};
export const POLYMORPHIC_TO_SINGLE_MAP = {
IntegerPolymorphic: 'integer',
BooleanPolymorphic: 'boolean',
FloatPolymorphic: 'float',
StringPolymorphic: 'string',
ImagePolymorphic: 'ImageField',
LatentsPolymorphic: 'LatentsField',
ConditioningPolymorphic: 'ConditioningField',
ControlPolymorphic: 'ControlField',
ColorPolymorphic: 'ColorField',
};
export const isPolymorphicItemType = (
itemType: string | undefined
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
export const FIELDS: Record<FieldType, FieldUIConfig> = { export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
color: 'red.500',
},
float: {
title: 'Float',
description: 'Floats are numbers with a decimal point.',
color: 'orange.500',
},
string: {
title: 'String',
description: 'Strings are text.',
color: 'yellow.500',
},
boolean: { boolean: {
title: 'Boolean',
color: 'green.500', color: 'green.500',
description: 'Booleans are true or false.', description: 'Booleans are true or false.',
title: 'Boolean',
}, },
enum: { BooleanCollection: {
title: 'Enum', color: 'green.500',
description: 'Enums are values that may be one of a number of options.', description: 'A collection of booleans.',
color: 'blue.500', title: 'Boolean Collection',
}, },
array: { BooleanPolymorphic: {
title: 'Array', color: 'green.500',
description: 'Enums are values that may be one of a number of options.', description: 'A collection of booleans.',
color: 'base.500', title: 'Boolean Polymorphic',
},
ImageField: {
title: 'Image',
description: 'Images may be passed between nodes.',
color: 'purple.500',
},
DenoiseMaskField: {
title: 'Denoise Mask',
description: 'Denoise Mask may be passed between nodes',
color: 'base.500',
},
LatentsField: {
title: 'Latents',
description: 'Latents may be passed between nodes.',
color: 'pink.500',
},
LatentsCollection: {
title: 'Latents Collection',
description: 'Latents may be passed between nodes.',
color: 'pink.500',
},
ConditioningField: {
color: 'cyan.500',
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
ConditioningCollection: {
color: 'cyan.500',
title: 'Conditioning Collection',
description: 'Conditioning may be passed between nodes.',
},
ImageCollection: {
title: 'Image Collection',
description: 'A collection of images.',
color: 'base.300',
},
UNetField: {
color: 'red.500',
title: 'UNet',
description: 'UNet submodel.',
}, },
ClipField: { ClipField: {
color: 'green.500', color: 'green.500',
title: 'Clip',
description: 'Tokenizer and text_encoder submodels.', description: 'Tokenizer and text_encoder submodels.',
}, title: 'Clip',
VaeField: {
color: 'blue.500',
title: 'Vae',
description: 'Vae submodel.',
},
ControlField: {
color: 'cyan.500',
title: 'Control',
description: 'Control info passed between nodes.',
},
MainModelField: {
color: 'teal.500',
title: 'Model',
description: 'TODO',
},
SDXLRefinerModelField: {
color: 'teal.500',
title: 'Refiner Model',
description: 'TODO',
},
VaeModelField: {
color: 'teal.500',
title: 'VAE',
description: 'TODO',
},
LoRAModelField: {
color: 'teal.500',
title: 'LoRA',
description: 'TODO',
},
ControlNetModelField: {
color: 'teal.500',
title: 'ControlNet',
description: 'TODO',
},
Scheduler: {
color: 'base.500',
title: 'Scheduler',
description: 'TODO',
}, },
Collection: { Collection: {
color: 'base.500', color: 'base.500',
title: 'Collection',
description: 'TODO', description: 'TODO',
title: 'Collection',
}, },
CollectionItem: { CollectionItem: {
color: 'base.500', color: 'base.500',
title: 'Collection Item',
description: 'TODO', description: 'TODO',
title: 'Collection Item',
},
ColorCollection: {
color: 'pink.300',
description: 'A collection of colors.',
title: 'Color Collection',
}, },
ColorField: { ColorField: {
title: 'Color', color: 'pink.300',
description: 'A RGBA color.', description: 'A RGBA color.',
color: 'base.500', title: 'Color',
}, },
BooleanCollection: { ColorPolymorphic: {
title: 'Boolean Collection', color: 'pink.300',
description: 'A collection of booleans.', description: 'A collection of colors.',
color: 'green.500', title: 'Color Polymorphic',
}, },
IntegerCollection: { ConditioningCollection: {
title: 'Integer Collection', color: 'cyan.500',
description: 'A collection of integers.', description: 'Conditioning may be passed between nodes.',
color: 'red.500', title: 'Conditioning Collection',
},
ConditioningField: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning',
},
ConditioningPolymorphic: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning Polymorphic',
},
ControlCollection: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Collection',
},
ControlField: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control',
},
ControlNetModelField: {
color: 'teal.500',
description: 'TODO',
title: 'ControlNet',
},
ControlPolymorphic: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Polymorphic',
},
DenoiseMaskField: {
color: 'blue.300',
description: 'Denoise Mask may be passed between nodes',
title: 'Denoise Mask',
},
enum: {
color: 'blue.500',
description: 'Enums are values that may be one of a number of options.',
title: 'Enum',
},
float: {
color: 'orange.500',
description: 'Floats are numbers with a decimal point.',
title: 'Float',
}, },
FloatCollection: { FloatCollection: {
color: 'orange.500', color: 'orange.500',
title: 'Float Collection',
description: 'A collection of floats.', description: 'A collection of floats.',
title: 'Float Collection',
}, },
ColorCollection: { FloatPolymorphic: {
color: 'base.500', color: 'orange.500',
title: 'Color Collection', description: 'A collection of floats.',
description: 'A collection of colors.', title: 'Float Polymorphic',
},
ImageCollection: {
color: 'purple.500',
description: 'A collection of images.',
title: 'Image Collection',
},
ImageField: {
color: 'purple.500',
description: 'Images may be passed between nodes.',
title: 'Image',
},
ImagePolymorphic: {
color: 'purple.500',
description: 'A collection of images.',
title: 'Image Polymorphic',
},
integer: {
color: 'red.500',
description: 'Integers are whole numbers, without a decimal point.',
title: 'Integer',
},
IntegerCollection: {
color: 'red.500',
description: 'A collection of integers.',
title: 'Integer Collection',
},
IntegerPolymorphic: {
color: 'red.500',
description: 'A collection of integers.',
title: 'Integer Polymorphic',
},
LatentsCollection: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents Collection',
},
LatentsField: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents',
},
LatentsPolymorphic: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents Polymorphic',
},
LoRAModelField: {
color: 'teal.500',
description: 'TODO',
title: 'LoRA',
},
MainModelField: {
color: 'teal.500',
description: 'TODO',
title: 'Model',
}, },
ONNXModelField: { ONNXModelField: {
color: 'base.500', color: 'teal.500',
title: 'ONNX Model',
description: 'ONNX model field.', description: 'ONNX model field.',
title: 'ONNX Model',
},
Scheduler: {
color: 'base.500',
description: 'TODO',
title: 'Scheduler',
}, },
SDXLMainModelField: { SDXLMainModelField: {
color: 'base.500', color: 'teal.500',
title: 'SDXL Model',
description: 'SDXL model field.', description: 'SDXL model field.',
title: 'SDXL Model',
},
SDXLRefinerModelField: {
color: 'teal.500',
description: 'TODO',
title: 'Refiner Model',
},
string: {
color: 'yellow.500',
description: 'Strings are text.',
title: 'String',
}, },
StringCollection: { StringCollection: {
color: 'yellow.500', color: 'yellow.500',
title: 'String Collection',
description: 'A collection of strings.', description: 'A collection of strings.',
title: 'String Collection',
},
StringPolymorphic: {
color: 'yellow.500',
description: 'A collection of strings.',
title: 'String Polymorphic',
},
UNetField: {
color: 'red.500',
description: 'UNet submodel.',
title: 'UNet',
},
VaeField: {
color: 'blue.500',
description: 'Vae submodel.',
title: 'Vae',
},
VaeModelField: {
color: 'teal.500',
description: 'TODO',
title: 'VAE',
}, },
}; };

View File

@ -11,7 +11,7 @@ import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow'; import { Node } from 'reactflow';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types'; import { Graph, _InputField, _OutputField } from 'services/api/types';
import { import {
AnyInvocationType, AnyInvocationType,
AnyResult, AnyResult,
@ -62,50 +62,48 @@ export type FieldUIConfig = {
// TODO: Get this from the OpenAPI schema? may be tricky... // TODO: Get this from the OpenAPI schema? may be tricky...
export const zFieldType = z.enum([ export const zFieldType = z.enum([
// region Primitives
'integer',
'float',
'boolean', 'boolean',
'string',
'array',
'ImageField',
'DenoiseMaskField',
'LatentsField',
'ConditioningField',
'ControlField',
'ColorField',
'ImageCollection',
'ConditioningCollection',
'ColorCollection',
'LatentsCollection',
'IntegerCollection',
'FloatCollection',
'StringCollection',
'BooleanCollection', 'BooleanCollection',
// endregion 'BooleanPolymorphic',
// region Models
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'ONNXModelField',
'VaeModelField',
'LoRAModelField',
'ControlNetModelField',
'UNetField',
'VaeField',
'ClipField', 'ClipField',
// endregion
// region Iterate/Collect
'Collection', 'Collection',
'CollectionItem', 'CollectionItem',
// endregion 'ColorCollection',
'ColorField',
// region Misc 'ColorPolymorphic',
'ConditioningCollection',
'ConditioningField',
'ConditioningPolymorphic',
'ControlCollection',
'ControlField',
'ControlNetModelField',
'ControlPolymorphic',
'DenoiseMaskField',
'enum', 'enum',
'float',
'FloatCollection',
'FloatPolymorphic',
'ImageCollection',
'ImageField',
'ImagePolymorphic',
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
'LoRAModelField',
'MainModelField',
'ONNXModelField',
'Scheduler', 'Scheduler',
// endregion 'SDXLMainModelField',
'SDXLRefinerModelField',
'string',
'StringCollection',
'StringPolymorphic',
'UNetField',
'VaeField',
'VaeModelField',
]); ]);
export type FieldType = z.infer<typeof zFieldType>; export type FieldType = z.infer<typeof zFieldType>;
@ -122,38 +120,6 @@ export const isFieldType = (value: unknown): value is FieldType =>
zFieldType.safeParse(value).success || zFieldType.safeParse(value).success ||
zReservedFieldType.safeParse(value).success; zReservedFieldType.safeParse(value).success;
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| IntegerInputFieldTemplate
| FloatInputFieldTemplate
| StringInputFieldTemplate
| BooleanInputFieldTemplate
| ImageInputFieldTemplate
| DenoiseMaskInputFieldTemplate
| LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
| UNetInputFieldTemplate
| ClipInputFieldTemplate
| VaeInputFieldTemplate
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| MainModelInputFieldTemplate
| SDXLMainModelInputFieldTemplate
| SDXLRefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ImageCollectionInputFieldTemplate
| SchedulerInputFieldTemplate;
/** /**
* Indicates the kind of input(s) this field may have. * Indicates the kind of input(s) this field may have.
*/ */
@ -232,24 +198,88 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
}); });
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>; export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerCollection'),
value: z.array(z.number().int()).optional(),
});
export type IntegerCollectionInputFieldValue = z.infer<
typeof zIntegerCollectionInputFieldValue
>;
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerPolymorphic'),
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
});
export type IntegerPolymorphicInputFieldValue = z.infer<
typeof zIntegerPolymorphicInputFieldValue
>;
export const zFloatInputFieldValue = zInputFieldValueBase.extend({ export const zFloatInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('float'), type: z.literal('float'),
value: z.number().optional(), value: z.number().optional(),
}); });
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>; export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatCollection'),
value: z.array(z.number()).optional(),
});
export type FloatCollectionInputFieldValue = z.infer<
typeof zFloatCollectionInputFieldValue
>;
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatPolymorphic'),
value: z.union([z.number(), z.array(z.number())]).optional(),
});
export type FloatPolymorphicInputFieldValue = z.infer<
typeof zFloatPolymorphicInputFieldValue
>;
export const zStringInputFieldValue = zInputFieldValueBase.extend({ export const zStringInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('string'), type: z.literal('string'),
value: z.string().optional(), value: z.string().optional(),
}); });
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>; export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringCollection'),
value: z.array(z.string()).optional(),
});
export type StringCollectionInputFieldValue = z.infer<
typeof zStringCollectionInputFieldValue
>;
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringPolymorphic'),
value: z.union([z.string(), z.array(z.string())]).optional(),
});
export type StringPolymorphicInputFieldValue = z.infer<
typeof zStringPolymorphicInputFieldValue
>;
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({ export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('boolean'), type: z.literal('boolean'),
value: z.boolean().optional(), value: z.boolean().optional(),
}); });
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>; export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanCollection'),
value: z.array(z.boolean()).optional(),
});
export type BooleanCollectionInputFieldValue = z.infer<
typeof zBooleanCollectionInputFieldValue
>;
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanPolymorphic'),
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
});
export type BooleanPolymorphicInputFieldValue = z.infer<
typeof zBooleanPolymorphicInputFieldValue
>;
export const zEnumInputFieldValue = zInputFieldValueBase.extend({ export const zEnumInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('enum'), type: z.literal('enum'),
value: z.union([z.string(), z.number()]).optional(), value: z.union([z.string(), z.number()]).optional(),
@ -262,6 +292,22 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
}); });
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>; export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('LatentsCollection'),
value: z.array(zLatentsField).optional(),
});
export type LatentsCollectionInputFieldValue = z.infer<
typeof zLatentsCollectionInputFieldValue
>;
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('LatentsPolymorphic'),
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
});
export type LatentsPolymorphicInputFieldValue = z.infer<
typeof zLatentsPolymorphicInputFieldValue
>;
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({ export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('DenoiseMaskField'), type: z.literal('DenoiseMaskField'),
value: zDenoiseMaskField.optional(), value: zDenoiseMaskField.optional(),
@ -278,6 +324,26 @@ export type ConditioningInputFieldValue = z.infer<
typeof zConditioningInputFieldValue typeof zConditioningInputFieldValue
>; >;
export const zConditioningCollectionInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('ConditioningCollection'),
value: z.array(zConditioningField).optional(),
});
export type ConditioningCollectionInputFieldValue = z.infer<
typeof zConditioningCollectionInputFieldValue
>;
export const zConditioningPolymorphicInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('ConditioningPolymorphic'),
value: z
.union([zConditioningField, z.array(zConditioningField)])
.optional(),
});
export type ConditioningPolymorphicInputFieldValue = z.infer<
typeof zConditioningPolymorphicInputFieldValue
>;
export const zControlNetModel = zModelIdentifier; export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>; export type ControlNetModel = z.infer<typeof zControlNetModel>;
@ -302,6 +368,22 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({
}); });
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>; export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlPolymorphic'),
value: z.union([zControlField, z.array(zControlField)]).optional(),
});
export type ControlPolymorphicInputFieldValue = z.infer<
typeof zControlPolymorphicInputFieldValue
>;
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlCollection'),
value: z.array(zControlField).optional(),
});
export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;
export const zModelType = z.enum([ export const zModelType = z.enum([
'onnx', 'onnx',
'main', 'main',
@ -381,6 +463,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
}); });
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>; export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImagePolymorphic'),
value: z.union([zImageField, z.array(zImageField)]).optional(),
});
export type ImagePolymorphicInputFieldValue = z.infer<
typeof zImagePolymorphicInputFieldValue
>;
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({ export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImageCollection'), type: z.literal('ImageCollection'),
value: z.array(zImageField).optional(), value: z.array(zImageField).optional(),
@ -473,6 +563,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
}); });
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>; export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ColorCollection'),
value: z.array(zColorField).optional(),
});
export type ColorCollectionInputFieldValue = z.infer<
typeof zColorCollectionInputFieldValue
>;
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ColorPolymorphic'),
value: z.union([zColorField, z.array(zColorField)]).optional(),
});
export type ColorPolymorphicInputFieldValue = z.infer<
typeof zColorPolymorphicInputFieldValue
>;
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({ export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Scheduler'), type: z.literal('Scheduler'),
value: zScheduler.optional(), value: zScheduler.optional(),
@ -482,30 +588,47 @@ export type SchedulerInputFieldValue = z.infer<
>; >;
export const zInputFieldValue = z.discriminatedUnion('type', [ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerInputFieldValue, zBooleanCollectionInputFieldValue,
zFloatInputFieldValue,
zStringInputFieldValue,
zBooleanInputFieldValue, zBooleanInputFieldValue,
zImageInputFieldValue, zBooleanPolymorphicInputFieldValue,
zLatentsInputFieldValue,
zDenoiseMaskInputFieldValue,
zConditioningInputFieldValue,
zUNetInputFieldValue,
zClipInputFieldValue, zClipInputFieldValue,
zVaeInputFieldValue,
zControlInputFieldValue,
zEnumInputFieldValue,
zMainModelInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zVaeModelInputFieldValue,
zLoRAModelInputFieldValue,
zControlNetModelInputFieldValue,
zCollectionInputFieldValue, zCollectionInputFieldValue,
zCollectionItemInputFieldValue, zCollectionItemInputFieldValue,
zColorInputFieldValue, zColorInputFieldValue,
zColorCollectionInputFieldValue,
zColorPolymorphicInputFieldValue,
zConditioningInputFieldValue,
zConditioningCollectionInputFieldValue,
zConditioningPolymorphicInputFieldValue,
zControlInputFieldValue,
zControlNetModelInputFieldValue,
zControlCollectionInputFieldValue,
zControlPolymorphicInputFieldValue,
zDenoiseMaskInputFieldValue,
zEnumInputFieldValue,
zFloatCollectionInputFieldValue,
zFloatInputFieldValue,
zFloatPolymorphicInputFieldValue,
zImageCollectionInputFieldValue, zImageCollectionInputFieldValue,
zImagePolymorphicInputFieldValue,
zImageInputFieldValue,
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
zLoRAModelInputFieldValue,
zMainModelInputFieldValue,
zSchedulerInputFieldValue, zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,
zStringInputFieldValue,
zUNetInputFieldValue,
zVaeInputFieldValue,
zVaeModelInputFieldValue,
]); ]);
export type InputFieldValue = z.infer<typeof zInputFieldValue>; export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@ -514,7 +637,6 @@ export type InputFieldTemplateBase = {
name: string; name: string;
title: string; title: string;
description: string; description: string;
type: FieldType;
required: boolean; required: boolean;
fieldKind: 'input'; fieldKind: 'input';
} & _InputField; } & _InputField;
@ -529,6 +651,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean; exclusiveMinimum?: boolean;
}; };
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'IntegerCollection';
default: number[];
item_default?: number;
};
export type IntegerPolymorphicInputFieldTemplate = Omit<
IntegerInputFieldTemplate,
'type'
> & {
type: 'IntegerPolymorphic';
};
export type FloatInputFieldTemplate = InputFieldTemplateBase & { export type FloatInputFieldTemplate = InputFieldTemplateBase & {
type: 'float'; type: 'float';
default: number; default: number;
@ -539,6 +674,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean; exclusiveMinimum?: boolean;
}; };
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'FloatCollection';
default: number[];
item_default?: number;
};
export type FloatPolymorphicInputFieldTemplate = Omit<
FloatInputFieldTemplate,
'type'
> & {
type: 'FloatPolymorphic';
};
export type StringInputFieldTemplate = InputFieldTemplateBase & { export type StringInputFieldTemplate = InputFieldTemplateBase & {
type: 'string'; type: 'string';
default: string; default: string;
@ -547,19 +695,53 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
pattern?: string; pattern?: string;
}; };
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'StringCollection';
default: string[];
item_default?: string;
};
export type StringPolymorphicInputFieldTemplate = Omit<
StringInputFieldTemplate,
'type'
> & {
type: 'StringPolymorphic';
};
export type BooleanInputFieldTemplate = InputFieldTemplateBase & { export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
default: boolean; default: boolean;
type: 'boolean'; type: 'boolean';
}; };
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'BooleanCollection';
default: boolean[];
item_default?: boolean;
};
export type BooleanPolymorphicInputFieldTemplate = Omit<
BooleanInputFieldTemplate,
'type'
> & {
type: 'BooleanPolymorphic';
};
export type ImageInputFieldTemplate = InputFieldTemplateBase & { export type ImageInputFieldTemplate = InputFieldTemplateBase & {
default: ImageDTO; default: ImageField;
type: 'ImageField'; type: 'ImageField';
}; };
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & { export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: ImageField[]; default: ImageField[];
type: 'ImageCollection'; type: 'ImageCollection';
item_default?: ImageField;
};
export type ImagePolymorphicInputFieldTemplate = Omit<
ImageInputFieldTemplate,
'type'
> & {
type: 'ImagePolymorphic';
}; };
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & { export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
@ -568,15 +750,40 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
}; };
export type LatentsInputFieldTemplate = InputFieldTemplateBase & { export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: string; default: LatentsField;
type: 'LatentsField'; type: 'LatentsField';
}; };
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: LatentsField[];
type: 'LatentsCollection';
item_default?: LatentsField;
};
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
default: LatentsField;
type: 'LatentsPolymorphic';
};
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
default: undefined; default: undefined;
type: 'ConditioningField'; type: 'ConditioningField';
}; };
export type ConditioningCollectionInputFieldTemplate =
InputFieldTemplateBase & {
default: ConditioningField[];
type: 'ConditioningCollection';
item_default?: ConditioningField;
};
export type ConditioningPolymorphicInputFieldTemplate = Omit<
ConditioningInputFieldTemplate,
'type'
> & {
type: 'ConditioningPolymorphic';
};
export type UNetInputFieldTemplate = InputFieldTemplateBase & { export type UNetInputFieldTemplate = InputFieldTemplateBase & {
default: undefined; default: undefined;
type: 'UNetField'; type: 'UNetField';
@ -597,6 +804,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
type: 'ControlField'; type: 'ControlField';
}; };
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'ControlCollection';
item_default?: ControlField;
};
export type ControlPolymorphicInputFieldTemplate = Omit<
ControlInputFieldTemplate,
'type'
> & {
type: 'ControlPolymorphic';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & { export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number; default: string | number;
type: 'enum'; type: 'enum';
@ -649,6 +869,18 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
type: 'ColorField'; type: 'ColorField';
}; };
export type ColorPolymorphicInputFieldTemplate = Omit<
ColorInputFieldTemplate,
'type'
> & {
type: 'ColorPolymorphic';
};
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'ColorCollection';
};
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & { export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
default: SchedulerParam; default: SchedulerParam;
type: 'Scheduler'; type: 'Scheduler';
@ -659,6 +891,55 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
type: 'WorkflowField'; type: 'WorkflowField';
}; };
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| BooleanCollectionInputFieldTemplate
| BooleanPolymorphicInputFieldTemplate
| BooleanInputFieldTemplate
| ClipInputFieldTemplate
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ColorCollectionInputFieldTemplate
| ColorPolymorphicInputFieldTemplate
| ConditioningInputFieldTemplate
| ConditioningCollectionInputFieldTemplate
| ConditioningPolymorphicInputFieldTemplate
| ControlInputFieldTemplate
| ControlCollectionInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ControlPolymorphicInputFieldTemplate
| DenoiseMaskInputFieldTemplate
| EnumInputFieldTemplate
| FloatCollectionInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
| ImageCollectionInputFieldTemplate
| ImagePolymorphicInputFieldTemplate
| ImageInputFieldTemplate
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate
| LoRAModelInputFieldTemplate
| MainModelInputFieldTemplate
| SchedulerInputFieldTemplate
| SDXLMainModelInputFieldTemplate
| SDXLRefinerModelInputFieldTemplate
| StringCollectionInputFieldTemplate
| StringPolymorphicInputFieldTemplate
| StringInputFieldTemplate
| UNetInputFieldTemplate
| VaeInputFieldTemplate
| VaeModelInputFieldTemplate;
export const isInputFieldValue = ( export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input'); ): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
@ -731,8 +1012,22 @@ export type InvocationSchemaObject = (
) & { class: 'invocation' }; ) & { class: 'invocation' };
export const isSchemaObject = ( export const isSchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj); ): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
export const isArraySchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.ArraySchemaObject =>
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
export const isNonArraySchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.NonArraySchemaObject =>
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
export const isRefObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
export const isInvocationSchemaObject = ( export const isInvocationSchemaObject = (
obj: obj:

View File

@ -1,5 +1,14 @@
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { import {
COLLECTION_MAP,
POLYMORPHIC_TYPES,
SINGLE_TO_POLYMORPHIC_MAP,
isCollectionItemType,
isPolymorphicItemType,
} from '../types/constants';
import {
BooleanCollectionInputFieldTemplate,
BooleanInputFieldTemplate, BooleanInputFieldTemplate,
ClipInputFieldTemplate, ClipInputFieldTemplate,
CollectionInputFieldTemplate, CollectionInputFieldTemplate,
@ -11,10 +20,13 @@ import {
DenoiseMaskInputFieldTemplate, DenoiseMaskInputFieldTemplate,
EnumInputFieldTemplate, EnumInputFieldTemplate,
FieldType, FieldType,
FloatCollectionInputFieldTemplate,
FloatPolymorphicInputFieldTemplate,
FloatInputFieldTemplate, FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate, ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate, ImageInputFieldTemplate,
InputFieldTemplateBase, InputFieldTemplateBase,
IntegerCollectionInputFieldTemplate,
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
InvocationFieldSchema, InvocationFieldSchema,
InvocationSchemaObject, InvocationSchemaObject,
@ -24,11 +36,32 @@ import {
SDXLMainModelInputFieldTemplate, SDXLMainModelInputFieldTemplate,
SDXLRefinerModelInputFieldTemplate, SDXLRefinerModelInputFieldTemplate,
SchedulerInputFieldTemplate, SchedulerInputFieldTemplate,
StringCollectionInputFieldTemplate,
StringInputFieldTemplate, StringInputFieldTemplate,
UNetInputFieldTemplate, UNetInputFieldTemplate,
VaeInputFieldTemplate, VaeInputFieldTemplate,
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
isArraySchemaObject,
isNonArraySchemaObject,
isRefObject,
isSchemaObject,
ControlPolymorphicInputFieldTemplate,
ColorPolymorphicInputFieldTemplate,
ColorCollectionInputFieldTemplate,
IntegerPolymorphicInputFieldTemplate,
StringPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldTemplate,
LatentsPolymorphicInputFieldTemplate,
LatentsCollectionInputFieldTemplate,
ConditioningPolymorphicInputFieldTemplate,
ConditioningCollectionInputFieldTemplate,
ControlCollectionInputFieldTemplate,
ImageField,
LatentsField,
ConditioningField,
} from '../types/types'; } from '../types/types';
import { ControlField } from 'services/api/types';
export type BaseFieldProperties = 'name' | 'title' | 'description'; export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
* @example * @example
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField' * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
*/ */
export const refObjectToFieldType = ( export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
refObject: OpenAPIV3.ReferenceObject refObject.$ref.split('/').slice(-1)[0];
): FieldType => {
const name = refObject.$ref.split('/').slice(-1)[0];
if (!name) {
throw `Unknown field type: ${name}`;
}
return name as FieldType;
};
const buildIntegerInputFieldTemplate = ({ const buildIntegerInputFieldTemplate = ({
schemaObject, schemaObject,
@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
return template; return template;
}; };
const buildIntegerPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
const template: IntegerPolymorphicInputFieldTemplate = {
...baseField,
type: 'IntegerPolymorphic',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildIntegerCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
const item_default =
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
? schemaObject.item_default
: 0;
const template: IntegerCollectionInputFieldTemplate = {
...baseField,
type: 'IntegerCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildFloatInputFieldTemplate = ({ const buildFloatInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
return template; return template;
}; };
const buildFloatPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
const template: FloatPolymorphicInputFieldTemplate = {
...baseField,
type: 'FloatPolymorphic',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
const item_default = isNumber(schemaObject.item_default)
? schemaObject.item_default
: 0;
const template: FloatCollectionInputFieldTemplate = {
...baseField,
type: 'FloatCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildStringInputFieldTemplate = ({ const buildStringInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
return template; return template;
}; };
const buildStringPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
const template: StringPolymorphicInputFieldTemplate = {
...baseField,
type: 'StringPolymorphic',
default: schemaObject.default ?? '',
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.pattern !== undefined) {
template.pattern = schemaObject.pattern;
}
return template;
};
const buildStringCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
const item_default = isString(schemaObject.item_default)
? schemaObject.item_default
: '';
const template: StringCollectionInputFieldTemplate = {
...baseField,
type: 'StringCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildBooleanInputFieldTemplate = ({ const buildBooleanInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
return template; return template;
}; };
const buildBooleanPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
const template: BooleanPolymorphicInputFieldTemplate = {
...baseField,
type: 'BooleanPolymorphic',
default: schemaObject.default ?? false,
};
return template;
};
const buildBooleanCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
const item_default =
schemaObject.item_default && isBoolean(schemaObject.item_default)
? schemaObject.item_default
: false;
const template: BooleanCollectionInputFieldTemplate = {
...baseField,
type: 'BooleanCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildMainModelInputFieldTemplate = ({ const buildMainModelInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
return template; return template;
}; };
const buildImagePolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
const template: ImagePolymorphicInputFieldTemplate = {
...baseField,
type: 'ImagePolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageCollectionInputFieldTemplate = ({ const buildImageCollectionInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -257,7 +468,8 @@ const buildImageCollectionInputFieldTemplate = ({
const template: ImageCollectionInputFieldTemplate = { const template: ImageCollectionInputFieldTemplate = {
...baseField, ...baseField,
type: 'ImageCollection', type: 'ImageCollection',
default: schemaObject.default ?? undefined, default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ImageField) ?? undefined,
}; };
return template; return template;
@ -289,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
return template; return template;
}; };
const buildLatentsPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
const template: LatentsPolymorphicInputFieldTemplate = {
...baseField,
type: 'LatentsPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLatentsCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
const template: LatentsCollectionInputFieldTemplate = {
...baseField,
type: 'LatentsCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
};
return template;
};
const buildConditioningInputFieldTemplate = ({ const buildConditioningInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -302,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
return template; return template;
}; };
const buildConditioningPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
const template: ConditioningPolymorphicInputFieldTemplate = {
...baseField,
type: 'ConditioningPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildConditioningCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
const template: ConditioningCollectionInputFieldTemplate = {
...baseField,
type: 'ConditioningCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
};
return template;
};
const buildUNetInputFieldTemplate = ({ const buildUNetInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -355,6 +621,33 @@ const buildControlInputFieldTemplate = ({
return template; return template;
}; };
const buildControlPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
const template: ControlPolymorphicInputFieldTemplate = {
...baseField,
type: 'ControlPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildControlCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
const template: ControlCollectionInputFieldTemplate = {
...baseField,
type: 'ControlCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ControlField) ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({ const buildEnumInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -408,6 +701,32 @@ const buildColorInputFieldTemplate = ({
return template; return template;
}; };
const buildColorPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
const template: ColorPolymorphicInputFieldTemplate = {
...baseField,
type: 'ColorPolymorphic',
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
return template;
};
const buildColorCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
const template: ColorCollectionInputFieldTemplate = {
...baseField,
type: 'ColorCollection',
default: schemaObject.default ?? [],
};
return template;
};
const buildSchedulerInputFieldTemplate = ({ const buildSchedulerInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -421,45 +740,138 @@ const buildSchedulerInputFieldTemplate = ({
return template; return template;
}; };
export const getFieldType = (schemaObject: InvocationFieldSchema): string => { export const getFieldType = (
let fieldType = ''; schemaObject: InvocationFieldSchema
): string | undefined => {
const { ui_type } = schemaObject; if (schemaObject?.ui_type) {
if (ui_type) { return schemaObject.ui_type;
fieldType = ui_type;
} else if (!schemaObject.type) { } else if (!schemaObject.type) {
// console.log('refObject', schemaObject);
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) { if (schemaObject.allOf) {
fieldType = refObjectToFieldType( const allOf = schemaObject.allOf;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion if (allOf && allOf[0] && isRefObject(allOf[0])) {
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject return refObjectToSchemaName(allOf[0]);
); }
} else if (schemaObject.anyOf) { } else if (schemaObject.anyOf) {
fieldType = refObjectToFieldType( const anyOf = schemaObject.anyOf;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion /**
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
); * - an `anyOf` with two items
} else if (schemaObject.oneOf) { * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
fieldType = refObjectToFieldType( * - the other is a `SchemaObject` or `ReferenceObject` of type T
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion *
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject * Any other cases we ignore.
); */
let firstType: string | undefined;
let secondType: string | undefined;
if (isArraySchemaObject(anyOf[0])) {
// first is array, second is not
const first = anyOf[0].items;
const second = anyOf[1];
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
} else if (isArraySchemaObject(anyOf[1])) {
// first is not array, second is
const first = anyOf[0];
const second = anyOf[1].items;
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
}
if (firstType === secondType && isPolymorphicItemType(firstType)) {
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
}
} }
} else if (schemaObject.enum) { } else if (schemaObject.enum) {
fieldType = 'enum'; return 'enum';
} else if (schemaObject.type) { } else if (schemaObject.type) {
if (schemaObject.type === 'number') { if (schemaObject.type === 'number') {
// floats are "number" in OpenAPI, while ints are "integer" // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
fieldType = 'float'; return 'float';
} else if (schemaObject.type === 'array') {
const itemType = isSchemaObject(schemaObject.items)
? schemaObject.items.type
: refObjectToSchemaName(schemaObject.items);
if (isCollectionItemType(itemType)) {
return COLLECTION_MAP[itemType];
}
return;
} else { } else {
fieldType = schemaObject.type; return schemaObject.type;
} }
} }
return;
return fieldType;
}; };
const TEMPLATE_BUILDER_MAP = {
boolean: buildBooleanInputFieldTemplate,
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
ClipField: buildClipInputFieldTemplate,
Collection: buildCollectionInputFieldTemplate,
CollectionItem: buildCollectionItemInputFieldTemplate,
ColorCollection: buildColorCollectionInputFieldTemplate,
ColorField: buildColorInputFieldTemplate,
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
ConditioningField: buildConditioningInputFieldTemplate,
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
ControlCollection: buildControlCollectionInputFieldTemplate,
ControlField: buildControlInputFieldTemplate,
ControlNetModelField: buildControlNetModelInputFieldTemplate,
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
enum: buildEnumInputFieldTemplate,
float: buildFloatInputFieldTemplate,
FloatCollection: buildFloatCollectionInputFieldTemplate,
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
ImageCollection: buildImageCollectionInputFieldTemplate,
ImageField: buildImageInputFieldTemplate,
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
LoRAModelField: buildLoRAModelInputFieldTemplate,
MainModelField: buildMainModelInputFieldTemplate,
Scheduler: buildSchedulerInputFieldTemplate,
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
string: buildStringInputFieldTemplate,
StringCollection: buildStringCollectionInputFieldTemplate,
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
UNetField: buildUNetInputFieldTemplate,
VaeField: buildVaeInputFieldTemplate,
VaeModelField: buildVaeModelInputFieldTemplate,
};
const isTemplatedFieldType = (
fieldType: string | undefined
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
/** /**
* Builds an input field from an invocation schema property. * Builds an input field from an invocation schema property.
* @param fieldSchema The schema object * @param fieldSchema The schema object
@ -474,7 +886,8 @@ export const buildInputFieldTemplate = (
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema; const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
const extra = { const extra = {
input, // TODO: Can we support polymorphic inputs in the UI?
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
ui_hidden, ui_hidden,
ui_component, ui_component,
ui_type, ui_type,
@ -490,146 +903,12 @@ export const buildInputFieldTemplate = (
...extra, ...extra,
}; };
if (fieldType === 'ImageField') { if (!isTemplatedFieldType(fieldType)) {
return buildImageInputFieldTemplate({ return;
schemaObject: fieldSchema,
baseField,
});
} }
if (fieldType === 'ImageCollection') {
return buildImageCollectionInputFieldTemplate({ return TEMPLATE_BUILDER_MAP[fieldType]({
schemaObject: fieldSchema, schemaObject: fieldSchema,
baseField, baseField,
}); });
}
if (fieldType === 'DenoiseMaskField') {
return buildDenoiseMaskInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'LatentsField') {
return buildLatentsInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ConditioningField') {
return buildConditioningInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'UNetField') {
return buildUNetInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ClipField') {
return buildClipInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'VaeField') {
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
}
if (fieldType === 'ControlField') {
return buildControlInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'MainModelField') {
return buildMainModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'SDXLRefinerModelField') {
return buildRefinerModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'SDXLMainModelField') {
return buildSDXLMainModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'VaeModelField') {
return buildVaeModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'LoRAModelField') {
return buildLoRAModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ControlNetModelField') {
return buildControlNetModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'enum') {
return buildEnumInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'integer') {
return buildIntegerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'float') {
return buildFloatInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'string') {
return buildStringInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'boolean') {
return buildBooleanInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'Collection') {
return buildCollectionInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'CollectionItem') {
return buildCollectionItemInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ColorField') {
return buildColorInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'Scheduler') {
return buildSchedulerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
return;
}; };

View File

@ -1,104 +1,79 @@
import { InputFieldTemplate, InputFieldValue } from '../types/types'; import { InputFieldTemplate, InputFieldValue } from '../types/types';
const FIELD_VALUE_FALLBACK_MAP = {
'enum.number': 0,
'enum.string': '',
boolean: false,
BooleanCollection: [],
BooleanPolymorphic: false,
ClipField: undefined,
Collection: [],
CollectionItem: undefined,
ColorCollection: [],
ColorField: undefined,
ColorPolymorphic: undefined,
ConditioningCollection: [],
ConditioningField: undefined,
ConditioningPolymorphic: undefined,
ControlCollection: [],
ControlField: undefined,
ControlNetModelField: undefined,
ControlPolymorphic: undefined,
DenoiseMaskField: undefined,
float: 0,
FloatCollection: [],
FloatPolymorphic: 0,
ImageCollection: [],
ImageField: undefined,
ImagePolymorphic: undefined,
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,
LoRAModelField: undefined,
MainModelField: undefined,
ONNXModelField: undefined,
Scheduler: 'euler',
SDXLMainModelField: undefined,
SDXLRefinerModelField: undefined,
string: '',
StringCollection: [],
StringPolymorphic: '',
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,
};
export const buildInputFieldValue = ( export const buildInputFieldValue = (
id: string, id: string,
template: InputFieldTemplate template: InputFieldTemplate
): InputFieldValue => { ): InputFieldValue => {
const fieldValue: InputFieldValue = { // TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
// `InputFieldValue` union, but TS doesn't seem to like it...
const fieldValue = {
id, id,
name: template.name, name: template.name,
type: template.type, type: template.type,
label: '', label: '',
fieldKind: 'input', fieldKind: 'input',
}; } as InputFieldValue;
if (template.type === 'string') {
fieldValue.value = template.default ?? '';
}
if (template.type === 'integer') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'float') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'boolean') {
fieldValue.value = template.default ?? false;
}
if (template.type === 'enum') { if (template.type === 'enum') {
if (template.enumType === 'number') { if (template.enumType === 'number') {
fieldValue.value = template.default ?? 0; fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
} }
if (template.enumType === 'string') { if (template.enumType === 'string') {
fieldValue.value = template.default ?? ''; fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
} }
} } else {
fieldValue.value =
if (template.type === 'Collection') { template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
fieldValue.value = template.default ?? 1;
}
if (template.type === 'ImageField') {
fieldValue.value = undefined;
}
if (template.type === 'ImageCollection') {
fieldValue.value = [];
}
if (template.type === 'DenoiseMaskField') {
fieldValue.value = undefined;
}
if (template.type === 'LatentsField') {
fieldValue.value = undefined;
}
if (template.type === 'ConditioningField') {
fieldValue.value = undefined;
}
if (template.type === 'UNetField') {
fieldValue.value = undefined;
}
if (template.type === 'ClipField') {
fieldValue.value = undefined;
}
if (template.type === 'VaeField') {
fieldValue.value = undefined;
}
if (template.type === 'ControlField') {
fieldValue.value = undefined;
}
if (template.type === 'MainModelField') {
fieldValue.value = undefined;
}
if (template.type === 'SDXLRefinerModelField') {
fieldValue.value = undefined;
}
if (template.type === 'VaeModelField') {
fieldValue.value = undefined;
}
if (template.type === 'LoRAModelField') {
fieldValue.value = undefined;
}
if (template.type === 'ControlNetModelField') {
fieldValue.value = undefined;
}
if (template.type === 'Scheduler') {
fieldValue.value = 'euler';
} }
return fieldValue; return fieldValue;

View File

@ -1397,9 +1397,8 @@ export type components = {
/** /**
* Control Model * Control Model
* @description ControlNet model to load * @description ControlNet model to load
* @default lllyasviel/sd-controlnet-canny
*/ */
control_model?: components["schemas"]["ControlNetModelField"]; control_model: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -5806,12 +5805,12 @@ export type components = {
*/ */
target_height?: number; target_height?: number;
/** /**
* Clip * CLIP 1
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/ */
clip?: components["schemas"]["ClipField"]; clip?: components["schemas"]["ClipField"];
/** /**
* Clip2 * CLIP 2
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/ */
clip2?: components["schemas"]["ClipField"]; clip2?: components["schemas"]["ClipField"];
@ -5855,7 +5854,7 @@ export type components = {
*/ */
weight?: number; weight?: number;
/** /**
* UNET * UNet
* @description UNet (scheduler, LoRAs) * @description UNet (scheduler, LoRAs)
*/ */
unet?: components["schemas"]["UNetField"]; unet?: components["schemas"]["UNetField"];
@ -6998,7 +6997,7 @@ export type components = {
* If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes. * If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes.
* @enum {string} * @enum {string}
*/ */
UIType: "integer" | "float" | "boolean" | "string" | "array" | "ImageField" | "LatentsField" | "ConditioningField" | "ControlField" | "ColorField" | "ImageCollection" | "ConditioningCollection" | "ColorCollection" | "LatentsCollection" | "IntegerCollection" | "FloatCollection" | "StringCollection" | "BooleanCollection" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "MetadataField"; UIType: "boolean" | "ColorField" | "ConditioningField" | "ControlField" | "float" | "ImageField" | "integer" | "LatentsField" | "string" | "BooleanCollection" | "ColorCollection" | "ConditioningCollection" | "ControlCollection" | "FloatCollection" | "ImageCollection" | "IntegerCollection" | "LatentsCollection" | "StringCollection" | "BooleanPolymorphic" | "ColorPolymorphic" | "ConditioningPolymorphic" | "ControlPolymorphic" | "FloatPolymorphic" | "ImagePolymorphic" | "IntegerPolymorphic" | "LatentsPolymorphic" | "StringPolymorphic" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "MetadataField";
/** /**
* UIComponent * UIComponent
* @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type. * @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type.
@ -7020,6 +7019,8 @@ export type components = {
ui_component?: components["schemas"]["UIComponent"]; ui_component?: components["schemas"]["UIComponent"];
/** Ui Order */ /** Ui Order */
ui_order?: number; ui_order?: number;
/** Item Default */
item_default?: unknown;
}; };
/** /**
* _OutputField * _OutputField
@ -7035,6 +7036,12 @@ export type components = {
/** Ui Order */ /** Ui Order */
ui_order?: number; ui_order?: number;
}; };
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
@ -7059,12 +7066,6 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -1,3 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .test_nodes import ( from .test_nodes import (
ImageToImageTestInvocation, ImageToImageTestInvocation,
TextToImageTestInvocation, TextToImageTestInvocation,
@ -20,7 +21,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import ShowImageInvocation from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import IntegerInvocation from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image from invokeai.app.services.default_graphs import create_text_to_image
import pytest import pytest
@ -610,6 +611,59 @@ def test_graph_can_deserialize():
assert g2.edges[0].destination.field == "image" assert g2.edges[0].destination.field == "image"
def test_invocation_decorator():
invocation_type = "test_invocation"
title = "Test Invocation"
tags = ["first", "second", "third"]
category = "category"
@invocation(invocation_type, title=title, tags=tags, category=category)
class Test(BaseInvocation):
def invoke(self):
pass
schema = Test.schema()
assert schema.get("title") == title
assert schema.get("tags") == tags
assert schema.get("category") == category
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
def test_invocation_output_decorator():
output_type = "test_output"
@invocation_output(output_type)
class TestOutput(BaseInvocationOutput):
pass
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
def test_floats_accept_ints():
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = FloatInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
# Not throwing on this line is sufficient
g.add_edge(e)
def test_ints_do_not_accept_floats():
g = Graph()
n1 = FloatInvocation(id="1", value=1.0)
n2 = IntegerInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
with pytest.raises(InvalidEdgeError):
g.add_edge(e)
def test_graph_can_generate_schema(): def test_graph_can_generate_schema():
# Not throwing on this line is sufficient # Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation