From 1062fc4796bbcd661ffbe0e96b480e3deddedbb4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 1 Sep 2023 19:40:27 +1000 Subject: [PATCH 01/13] feat: polymorphic fields Initial support for polymorphic field types. Polymorphic types are a single of or list of a specific type. For example, `Union[str, list[str]]`. Polymorphics do not yet have support for direct input in the UI (will come in the future). They will be forcibly set as Connection-only fields, in which case users will not be able to provide direct input to the field. If a polymorphic should present as a singleton type - which would allow direct input - the node must provide an explicit type hint. For example, `DenoiseLatents`' `CFG Scale` is polymorphic, but in the node editor, we want to present this as a number input. In the node definition, the field is given `ui_type=UIType.Float`, which tells the UI to treat this as a `float` field. The connection validation logic will prevent connecting a collection to `CFG Scale` in this situation, because it is typed as `float`. The workaround is to disable validation from the settings to make this specific connection. A future improvement will resolve this. This also introduces better support for collection field types. Like polymorphics, collection types are parsed automatically by the client and do not need any specific type hints. Also like polymorphics, there is no support yet for direct input of collection types in the UI. - Disabling validation in workflow editor now displays the visual hints for valid connections, but lets you connect to anything. - Added `ui_order: int` to `InputField` and `OutputField`. The UI will use this, if present, to order fields in a node UI. See usage in `DenoiseLatents` for an example. - Updated the field colors - duplicate colors have just been lightened a bit. It's not perfect but it was a quick fix. - Field handles for collections are the same color as their single counterparts, but have a dark dot in the center of them. - Field handles for polymorphics are a rounded square with dot in the middle. - Removed all fields that just render `null` from `InputFieldRenderer`, replaced with a single fallback - Removed logic in `zValidatedWorkflow`, which checked for existence of node templates for each node in a workflow. This logic introduced a circular dependency, due to importing the global redux `store` in order to get the node templates within a zod schema. It's actually fine to just leave this out entirely; The case of a missing node template is handled by the UI. Fixing it otherwise would introduce a substantial headache. - Fixed the `ControlNetInvocation.control_model` field default, which was a string when it shouldn't have one. --- invokeai/app/invocations/baseinvocation.py | 53 +- .../controlnet_image_processors.py | 4 +- invokeai/app/invocations/latent.py | 7 +- invokeai/app/invocations/primitives.py | 50 +- .../web/src/common/util/colorTokenToCssVar.ts | 2 +- .../nodes/Invocation/fields/FieldHandle.tsx | 21 +- .../nodes/Invocation/fields/InputField.tsx | 1 + .../Invocation/fields/InputFieldRenderer.tsx | 134 +--- .../fields/inputs/ControlInputField.tsx | 7 +- .../fields/inputs/ImageInputField.tsx | 2 +- .../fields/inputs/LatentsInputField.tsx | 7 +- .../fields/inputs/NumberInputField.tsx | 2 +- .../nodes/hooks/useIsValidConnection.ts | 91 ++- .../util/makeIsConnectionValidSelector.ts | 99 ++- .../web/src/features/nodes/types/constants.ts | 366 ++++++---- .../web/src/features/nodes/types/types.ts | 479 ++++++++++--- .../nodes/util/fieldTemplateBuilders.ts | 635 +++++++++++++----- .../features/nodes/util/fieldValueBuilders.ts | 143 ++-- .../frontend/web/src/services/api/schema.d.ts | 25 +- 19 files changed, 1408 insertions(+), 720 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 87c1d65113..ccc2b4d05f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -105,24 +105,39 @@ class UIType(str, Enum): """ # region Primitives - Integer = "integer" - Float = "float" Boolean = "boolean" - String = "string" - Array = "array" - Image = "ImageField" - Latents = "LatentsField" + Color = "ColorField" Conditioning = "ConditioningField" Control = "ControlField" - Color = "ColorField" - ImageCollection = "ImageCollection" - ConditioningCollection = "ConditioningCollection" - ColorCollection = "ColorCollection" - LatentsCollection = "LatentsCollection" - IntegerCollection = "IntegerCollection" - FloatCollection = "FloatCollection" - StringCollection = "StringCollection" + Float = "float" + Image = "ImageField" + Integer = "integer" + Latents = "LatentsField" + String = "string" + # endregion + + # region Collection Primitives BooleanCollection = "BooleanCollection" + ColorCollection = "ColorCollection" + ConditioningCollection = "ConditioningCollection" + ControlCollection = "ControlCollection" + FloatCollection = "FloatCollection" + ImageCollection = "ImageCollection" + IntegerCollection = "IntegerCollection" + LatentsCollection = "LatentsCollection" + StringCollection = "StringCollection" + # endregion + + # region Polymorphic Primitives + BooleanPolymorphic = "BooleanPolymorphic" + ColorPolymorphic = "ColorPolymorphic" + ConditioningPolymorphic = "ConditioningPolymorphic" + ControlPolymorphic = "ControlPolymorphic" + FloatPolymorphic = "FloatPolymorphic" + ImagePolymorphic = "ImagePolymorphic" + IntegerPolymorphic = "IntegerPolymorphic" + LatentsPolymorphic = "LatentsPolymorphic" + StringPolymorphic = "StringPolymorphic" # endregion # region Models @@ -176,6 +191,7 @@ class _InputField(BaseModel): ui_type: Optional[UIType] ui_component: Optional[UIComponent] ui_order: Optional[int] + item_default: Optional[Any] class _OutputField(BaseModel): @@ -223,6 +239,7 @@ def InputField( ui_component: Optional[UIComponent] = None, ui_hidden: bool = False, ui_order: Optional[int] = None, + item_default: Optional[Any] = None, **kwargs: Any, ) -> Any: """ @@ -249,6 +266,11 @@ def InputField( For this case, you could provide `UIComponent.Textarea`. : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + + : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + + : param bool item_default: [None] Specifies the default item value, if this is a collection input. \ + Ignored for non-collection fields.. """ return Field( *args, @@ -282,6 +304,7 @@ def InputField( ui_component=ui_component, ui_hidden=ui_hidden, ui_order=ui_order, + item_default=item_default, **kwargs, ) @@ -332,6 +355,8 @@ def OutputField( `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + + : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ """ return Field( *args, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index a666c5d6f4..272afb3a4c 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -100,9 +100,7 @@ class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" image: ImageField = InputField(description="The control image") - control_model: ControlNetModelField = InputField( - default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct - ) + control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct) control_weight: Union[float, List[float]] = InputField( default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 6357c1ac7b..96ad49165a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -208,7 +208,10 @@ class DenoiseLatentsInvocation(BaseInvocation): ) unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2) control: Union[ControlField, list[ControlField]] = InputField( - default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5 + default=None, + description=FieldDescriptions.control, + input=Input.Connection, + ui_order=5, ) latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) denoise_mask: Optional[DenoiseMaskField] = InputField( @@ -317,7 +320,7 @@ class DenoiseLatentsInvocation(BaseInvocation): context: InvocationContext, # really only need model for dtype and device model: StableDiffusionGeneratorPipeline, - control_input: List[ControlField], + control_input: Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d002ba8ddc..81914ab2af 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -14,7 +14,6 @@ from .baseinvocation import ( InvocationContext, OutputField, UIComponent, - UIType, invocation, invocation_output, ) @@ -40,7 +39,9 @@ class BooleanOutput(BaseInvocationOutput): class BooleanCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of booleans""" - collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection) + collection: list[bool] = OutputField( + description="The output boolean collection", + ) @invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives") @@ -62,9 +63,7 @@ class BooleanInvocation(BaseInvocation): class BooleanCollectionInvocation(BaseInvocation): """A collection of boolean primitive values""" - collection: list[bool] = InputField( - default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection - ) + collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values") def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -86,7 +85,9 @@ class IntegerOutput(BaseInvocationOutput): class IntegerCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of integers""" - collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection) + collection: list[int] = OutputField( + description="The int collection", + ) @invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives") @@ -108,9 +109,7 @@ class IntegerInvocation(BaseInvocation): class IntegerCollectionInvocation(BaseInvocation): """A collection of integer primitive values""" - collection: list[int] = InputField( - default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection - ) + collection: list[int] = InputField(default_factory=list, description="The collection of integer values") def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -132,7 +131,9 @@ class FloatOutput(BaseInvocationOutput): class FloatCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of floats""" - collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection) + collection: list[float] = OutputField( + description="The float collection", + ) @invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives") @@ -154,9 +155,7 @@ class FloatInvocation(BaseInvocation): class FloatCollectionInvocation(BaseInvocation): """A collection of float primitive values""" - collection: list[float] = InputField( - default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection - ) + collection: list[float] = InputField(default_factory=list, description="The collection of float values") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -178,7 +177,9 @@ class StringOutput(BaseInvocationOutput): class StringCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of strings""" - collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection) + collection: list[str] = OutputField( + description="The output strings", + ) @invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives") @@ -200,9 +201,7 @@ class StringInvocation(BaseInvocation): class StringCollectionInvocation(BaseInvocation): """A collection of string primitive values""" - collection: list[str] = InputField( - default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection - ) + collection: list[str] = InputField(default_factory=list, description="The collection of string values") def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -232,7 +231,9 @@ class ImageOutput(BaseInvocationOutput): class ImageCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of images""" - collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection) + collection: list[ImageField] = OutputField( + description="The output images", + ) @invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives") @@ -260,9 +261,7 @@ class ImageInvocation(BaseInvocation): class ImageCollectionInvocation(BaseInvocation): """A collection of image primitive values""" - collection: list[ImageField] = InputField( - default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection - ) + collection: list[ImageField] = InputField(default_factory=list, description="The collection of image values") def invoke(self, context: InvocationContext) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -316,7 +315,6 @@ class LatentsCollectionOutput(BaseInvocationOutput): collection: list[LatentsField] = OutputField( description=FieldDescriptions.latents, - ui_type=UIType.LatentsCollection, ) @@ -342,7 +340,7 @@ class LatentsCollectionInvocation(BaseInvocation): """A collection of latents tensor primitive values""" collection: list[LatentsField] = InputField( - description="The collection of latents tensors", ui_type=UIType.LatentsCollection + description="The collection of latents tensors", ) def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: @@ -385,7 +383,9 @@ class ColorOutput(BaseInvocationOutput): class ColorCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of colors""" - collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection) + collection: list[ColorField] = OutputField( + description="The output colors", + ) @invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives") @@ -422,7 +422,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput): collection: list[ConditioningField] = OutputField( description="The output conditioning tensors", - ui_type=UIType.ConditioningCollection, ) @@ -453,7 +452,6 @@ class ConditioningCollectionInvocation(BaseInvocation): collection: list[ConditioningField] = InputField( default_factory=list, description="The collection of conditioning tensors", - ui_type=UIType.ConditioningCollection, ) def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: diff --git a/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts b/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts index e29005186f..87724d9c9b 100644 --- a/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts +++ b/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts @@ -1,2 +1,2 @@ export const colorTokenToCssVar = (colorToken: string) => - `var(--invokeai-colors-${colorToken.split('.').join('-')}`; + `var(--invokeai-colors-${colorToken.split('.').join('-')})`; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 14924a16fe..02b18e7178 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,8 +1,10 @@ import { Tooltip } from '@chakra-ui/react'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { + COLLECTION_TYPES, FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, + POLYMORPHIC_TYPES, } from 'features/nodes/types/constants'; import { InputFieldTemplate, @@ -18,6 +20,7 @@ export const handleBaseStyles: CSSProperties = { borderWidth: 0, zIndex: 1, }; +``; export const inputHandleStyles: CSSProperties = { left: '-1rem', @@ -44,15 +47,24 @@ const FieldHandle = (props: FieldHandleProps) => { connectionError, } = props; const { name, type } = fieldTemplate; - const { color, title } = FIELDS[type]; + const { color: typeColor, title } = FIELDS[type]; const styles: CSSProperties = useMemo(() => { + const isCollectionType = COLLECTION_TYPES.includes(type); + const isPolymorphicType = POLYMORPHIC_TYPES.includes(type); + const color = colorTokenToCssVar(typeColor); const s: CSSProperties = { - backgroundColor: colorTokenToCssVar(color), + backgroundColor: + isCollectionType || isPolymorphicType + ? 'var(--invokeai-colors-base-900)' + : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: 0, + borderWidth: isCollectionType || isPolymorphicType ? 4 : 0, + borderStyle: 'solid', + borderColor: color, + borderRadius: isPolymorphicType ? 4 : '100%', zIndex: 1, }; @@ -78,11 +90,12 @@ const FieldHandle = (props: FieldHandleProps) => { return s; }, [ - color, connectionError, handleType, isConnectionInProgress, isConnectionStartField, + type, + typeColor, ]); const tooltip = useMemo(() => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 6ad8e14bc2..bee5264e00 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { sx={{ display: 'flex', alignItems: 'center', + h: 'full', mb: 0, px: 1, gap: 2, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index bb9637cd73..fa5c4533c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -3,18 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData'; import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; import { memo } from 'react'; import BooleanInputField from './inputs/BooleanInputField'; -import ClipInputField from './inputs/ClipInputField'; -import CollectionInputField from './inputs/CollectionInputField'; -import CollectionItemInputField from './inputs/CollectionItemInputField'; import ColorInputField from './inputs/ColorInputField'; -import ConditioningInputField from './inputs/ConditioningInputField'; -import ControlInputField from './inputs/ControlInputField'; import ControlNetModelInputField from './inputs/ControlNetModelInputField'; -import DenoiseMaskInputField from './inputs/DenoiseMaskInputField'; import EnumInputField from './inputs/EnumInputField'; -import ImageCollectionInputField from './inputs/ImageCollectionInputField'; import ImageInputField from './inputs/ImageInputField'; -import LatentsInputField from './inputs/LatentsInputField'; import LoRAModelInputField from './inputs/LoRAModelInputField'; import MainModelInputField from './inputs/MainModelInputField'; import NumberInputField from './inputs/NumberInputField'; @@ -22,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField'; import SDXLMainModelInputField from './inputs/SDXLMainModelInputField'; import SchedulerInputField from './inputs/SchedulerInputField'; import StringInputField from './inputs/StringInputField'; -import UnetInputField from './inputs/UnetInputField'; -import VaeInputField from './inputs/VaeInputField'; import VaeModelInputField from './inputs/VaeModelInputField'; type InputFieldProps = { @@ -31,7 +21,6 @@ type InputFieldProps = { fieldName: string; }; -// build an individual input element based on the schema const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { const field = useFieldData(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); @@ -93,88 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { ); } - if ( - field?.type === 'LatentsField' && - fieldTemplate?.type === 'LatentsField' - ) { - return ( - - ); - } - - if ( - field?.type === 'DenoiseMaskField' && - fieldTemplate?.type === 'DenoiseMaskField' - ) { - return ( - - ); - } - - if ( - field?.type === 'ConditioningField' && - fieldTemplate?.type === 'ConditioningField' - ) { - return ( - - ); - } - - if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') { - return ( - - ); - } - - if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') { - return ( - - ); - } - - if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') { - return ( - - ); - } - - if ( - field?.type === 'ControlField' && - fieldTemplate?.type === 'ControlField' - ) { - return ( - - ); - } - if ( field?.type === 'MainModelField' && fieldTemplate?.type === 'MainModelField' @@ -240,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { ); } - if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') { - return ( - - ); - } - - if ( - field?.type === 'CollectionItem' && - fieldTemplate?.type === 'CollectionItem' - ) { - return ( - - ); - } - if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') { return ( { ); } - if ( - field?.type === 'ImageCollection' && - fieldTemplate?.type === 'ImageCollection' - ) { - return ( - - ); - } - if ( field?.type === 'SDXLMainModelField' && fieldTemplate?.type === 'SDXLMainModelField' @@ -309,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { ); } + if (field && fieldTemplate) { + // Fallback for when there is no component for the type + return null; + } + return ( + _props: FieldComponentProps< + ControlInputFieldValue | ControlPolymorphicInputFieldValue, + ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate + > ) => { return null; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx index e04d0d1edc..7f96675792 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx @@ -9,9 +9,9 @@ import { } from 'features/dnd/types'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { + FieldComponentProps, ImageInputFieldTemplate, ImageInputFieldValue, - FieldComponentProps, } from 'features/nodes/types/types'; import { memo, useCallback, useMemo } from 'react'; import { FaUndo } from 'react-icons/fa'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LatentsInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LatentsInputField.tsx index 099314654f..a5065be0ee 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LatentsInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LatentsInputField.tsx @@ -2,11 +2,16 @@ import { LatentsInputFieldTemplate, LatentsInputFieldValue, FieldComponentProps, + LatentsPolymorphicInputFieldValue, + LatentsPolymorphicInputFieldTemplate, } from 'features/nodes/types/types'; import { memo } from 'react'; const LatentsInputFieldComponent = ( - _props: FieldComponentProps + _props: FieldComponentProps< + LatentsInputFieldValue | LatentsPolymorphicInputFieldValue, + LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate + > ) => { return null; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx index 1e569d5005..61387d751b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx @@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { numberStringRegex } from 'common/components/IAINumberInput'; import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice'; import { + FieldComponentProps, FloatInputFieldTemplate, FloatInputFieldValue, IntegerInputFieldTemplate, IntegerInputFieldValue, - FieldComponentProps, } from 'features/nodes/types/types'; import { memo, useEffect, useMemo, useState } from 'react'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 3a63d75bb0..1c372e551d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib'; import { useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { Connection, Edge, Node, useReactFlow } from 'reactflow'; -import { COLLECTION_TYPES } from '../types/constants'; +import { + COLLECTION_MAP, + COLLECTION_TYPES, + POLYMORPHIC_TO_SINGLE_MAP, + POLYMORPHIC_TYPES, +} from '../types/constants'; import { InvocationNodeData } from '../types/types'; +/** + * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` + * TODO: Figure out how to do this without duplicating all the logic + */ + export const useIsValidConnection = () => { const flow = useReactFlow(); const shouldValidateGraph = useAppSelector( @@ -42,6 +52,19 @@ export const useIsValidConnection = () => { return false; } + if ( + edges + .filter((edge) => { + return edge.target === target && edge.targetHandle === targetHandle; + }) + .find((edge) => { + edge.source === source && edge.sourceHandle === sourceHandle; + }) + ) { + // We already have a connection from this source to this target + return false; + } + // Connection is invalid if target already has a connection if ( edges.find((edge) => { @@ -53,21 +76,59 @@ export const useIsValidConnection = () => { return false; } - // Connection types must be the same for a connection - if ( - sourceType !== targetType && - sourceType !== 'CollectionItem' && - targetType !== 'CollectionItem' - ) { - if ( - !( - COLLECTION_TYPES.includes(targetType) && - COLLECTION_TYPES.includes(sourceType) - ) - ) { - return false; - } + /** + * Connection types must be the same for a connection, with exceptions: + * - CollectionItem can connect to any non-Collection + * - Non-Collections can connect to CollectionItem + * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type + * - Generic Collection can connect to any other Collection or Polymorphic + * - Any Collection can connect to a Generic Collection + */ + + if (sourceType !== targetType) { + const isCollectionItemToNonCollection = + sourceType === 'CollectionItem' && + !COLLECTION_TYPES.includes(targetType); + + const isNonCollectionToCollectionItem = + targetType === 'CollectionItem' && + !COLLECTION_TYPES.includes(sourceType) && + !POLYMORPHIC_TYPES.includes(sourceType); + + const isAnythingToPolymorphicOfSameBaseType = + POLYMORPHIC_TYPES.includes(targetType) && + (() => { + if (!POLYMORPHIC_TYPES.includes(targetType)) { + return false; + } + const baseType = + POLYMORPHIC_TO_SINGLE_MAP[ + targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP + ]; + + const collectionType = + COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; + + return sourceType === baseType || sourceType === collectionType; + })(); + + const isGenericCollectionToAnyCollectionOrPolymorphic = + sourceType === 'Collection' && + (COLLECTION_TYPES.includes(targetType) || + POLYMORPHIC_TYPES.includes(targetType)); + + const isCollectionToGenericCollection = + targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + + return ( + isCollectionItemToNonCollection || + isNonCollectionToCollectionItem || + isAnythingToPolymorphicOfSameBaseType || + isGenericCollectionToAnyCollectionOrPolymorphic || + isCollectionToGenericCollection + ); } + // Graphs much be acyclic (no loops!) return getIsGraphAcyclic(source, target, nodes, edges); }, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 29603036ab..0c5fee509c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,10 +1,20 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection'; -import { COLLECTION_TYPES } from 'features/nodes/types/constants'; +import { + COLLECTION_MAP, + COLLECTION_TYPES, + POLYMORPHIC_TO_SINGLE_MAP, + POLYMORPHIC_TYPES, +} from 'features/nodes/types/constants'; import { FieldType } from 'features/nodes/types/types'; import { HandleType } from 'reactflow'; +/** + * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` + * TODO: Figure out how to do this without duplicating all the logic + */ + export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, @@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = ( const { currentConnectionFieldType, connectionStartParams, nodes, edges } = state.nodes; - if (!state.nodes.shouldValidateGraph) { - // manual override! - return null; - } - if (!connectionStartParams || !currentConnectionFieldType) { return 'No connection in progress'; } @@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = ( return 'No connection data'; } - const targetFieldType = + const targetType = handleType === 'target' ? fieldType : currentConnectionFieldType; - const sourceFieldType = + const sourceType = handleType === 'source' ? fieldType : currentConnectionFieldType; if (nodeId === connectionNodeId) { @@ -55,30 +60,70 @@ export const makeConnectionErrorSelector = ( } if ( - fieldType !== currentConnectionFieldType && - fieldType !== 'CollectionItem' && - currentConnectionFieldType !== 'CollectionItem' - ) { - if ( - !( - COLLECTION_TYPES.includes(targetFieldType) && - COLLECTION_TYPES.includes(sourceFieldType) - ) - ) { - // except for collection items, field types must match - return 'Field types must match'; - } - } - - if ( - handleType === 'target' && edges.find((edge) => { return edge.target === nodeId && edge.targetHandle === fieldName; }) && // except CollectionItem inputs can have multiples - targetFieldType !== 'CollectionItem' + targetType !== 'CollectionItem' ) { - return 'Inputs may only have one connection'; + return 'Input may only have one connection'; + } + + /** + * Connection types must be the same for a connection, with exceptions: + * - CollectionItem can connect to any non-Collection + * - Non-Collections can connect to CollectionItem + * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type + * - Generic Collection can connect to any other Collection or Polymorphic + * - Any Collection can connect to a Generic Collection + */ + + if (sourceType !== targetType) { + const isCollectionItemToNonCollection = + sourceType === 'CollectionItem' && + !COLLECTION_TYPES.includes(targetType); + + const isNonCollectionToCollectionItem = + targetType === 'CollectionItem' && + !COLLECTION_TYPES.includes(sourceType) && + !POLYMORPHIC_TYPES.includes(sourceType); + + const isAnythingToPolymorphicOfSameBaseType = + POLYMORPHIC_TYPES.includes(targetType) && + (() => { + if (!POLYMORPHIC_TYPES.includes(targetType)) { + return false; + } + const baseType = + POLYMORPHIC_TO_SINGLE_MAP[ + targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP + ]; + + const collectionType = + COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; + + return sourceType === baseType || sourceType === collectionType; + })(); + + const isGenericCollectionToAnyCollectionOrPolymorphic = + sourceType === 'Collection' && + (COLLECTION_TYPES.includes(targetType) || + POLYMORPHIC_TYPES.includes(targetType)); + + const isCollectionToGenericCollection = + targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + + if ( + !( + isCollectionItemToNonCollection || + isNonCollectionToCollectionItem || + isAnythingToPolymorphicOfSameBaseType || + isGenericCollectionToAnyCollectionOrPolymorphic || + isCollectionToGenericCollection + ) + ) { + return 'Field types must match'; + } } const isGraphAcyclic = getIsGraphAcyclic( diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index c611ca9976..dcd579d912 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -17,176 +17,284 @@ export const KIND_MAP = { export const COLLECTION_TYPES: FieldType[] = [ 'Collection', 'IntegerCollection', + 'BooleanCollection', 'FloatCollection', 'StringCollection', - 'BooleanCollection', 'ImageCollection', + 'LatentsCollection', + 'ConditioningCollection', + 'ControlCollection', + 'ColorCollection', ]; +export const POLYMORPHIC_TYPES = [ + 'IntegerPolymorphic', + 'BooleanPolymorphic', + 'FloatPolymorphic', + 'StringPolymorphic', + 'ImagePolymorphic', + 'LatentsPolymorphic', + 'ConditioningPolymorphic', + 'ControlPolymorphic', + 'ColorPolymorphic', +]; + +export const COLLECTION_MAP = { + integer: 'IntegerCollection', + boolean: 'BooleanCollection', + number: 'FloatCollection', + float: 'FloatCollection', + string: 'StringCollection', + ImageField: 'ImageCollection', + LatentsField: 'LatentsCollection', + ConditioningField: 'ConditioningCollection', + ControlField: 'ControlCollection', + ColorField: 'ColorCollection', +}; +export const isCollectionItemType = ( + itemType: string | undefined +): itemType is keyof typeof COLLECTION_MAP => + Boolean(itemType && itemType in COLLECTION_MAP); + +export const SINGLE_TO_POLYMORPHIC_MAP = { + integer: 'IntegerPolymorphic', + boolean: 'BooleanPolymorphic', + number: 'FloatPolymorphic', + float: 'FloatPolymorphic', + string: 'StringPolymorphic', + ImageField: 'ImagePolymorphic', + LatentsField: 'LatentsPolymorphic', + ConditioningField: 'ConditioningPolymorphic', + ControlField: 'ControlPolymorphic', + ColorField: 'ColorPolymorphic', +}; + +export const POLYMORPHIC_TO_SINGLE_MAP = { + IntegerPolymorphic: 'integer', + BooleanPolymorphic: 'boolean', + FloatPolymorphic: 'float', + StringPolymorphic: 'string', + ImagePolymorphic: 'ImageField', + LatentsPolymorphic: 'LatentsField', + ConditioningPolymorphic: 'ConditioningField', + ControlPolymorphic: 'ControlField', + ColorPolymorphic: 'ColorField', +}; + +export const isPolymorphicItemType = ( + itemType: string | undefined +): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => + Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP); + export const FIELDS: Record = { - integer: { - title: 'Integer', - description: 'Integers are whole numbers, without a decimal point.', - color: 'red.500', - }, - float: { - title: 'Float', - description: 'Floats are numbers with a decimal point.', - color: 'orange.500', - }, - string: { - title: 'String', - description: 'Strings are text.', - color: 'yellow.500', - }, boolean: { - title: 'Boolean', color: 'green.500', description: 'Booleans are true or false.', + title: 'Boolean', }, - enum: { - title: 'Enum', - description: 'Enums are values that may be one of a number of options.', - color: 'blue.500', + BooleanCollection: { + color: 'green.500', + description: 'A collection of booleans.', + title: 'Boolean Collection', }, - array: { - title: 'Array', - description: 'Enums are values that may be one of a number of options.', - color: 'base.500', - }, - ImageField: { - title: 'Image', - description: 'Images may be passed between nodes.', - color: 'purple.500', - }, - DenoiseMaskField: { - title: 'Denoise Mask', - description: 'Denoise Mask may be passed between nodes', - color: 'base.500', - }, - LatentsField: { - title: 'Latents', - description: 'Latents may be passed between nodes.', - color: 'pink.500', - }, - LatentsCollection: { - title: 'Latents Collection', - description: 'Latents may be passed between nodes.', - color: 'pink.500', - }, - ConditioningField: { - color: 'cyan.500', - title: 'Conditioning', - description: 'Conditioning may be passed between nodes.', - }, - ConditioningCollection: { - color: 'cyan.500', - title: 'Conditioning Collection', - description: 'Conditioning may be passed between nodes.', - }, - ImageCollection: { - title: 'Image Collection', - description: 'A collection of images.', - color: 'base.300', - }, - UNetField: { - color: 'red.500', - title: 'UNet', - description: 'UNet submodel.', + BooleanPolymorphic: { + color: 'green.500', + description: 'A collection of booleans.', + title: 'Boolean Polymorphic', }, ClipField: { - color: 'green.500', - title: 'Clip', + color: 'green.300', description: 'Tokenizer and text_encoder submodels.', - }, - VaeField: { - color: 'blue.500', - title: 'Vae', - description: 'Vae submodel.', - }, - ControlField: { - color: 'cyan.500', - title: 'Control', - description: 'Control info passed between nodes.', - }, - MainModelField: { - color: 'teal.500', - title: 'Model', - description: 'TODO', - }, - SDXLRefinerModelField: { - color: 'teal.500', - title: 'Refiner Model', - description: 'TODO', - }, - VaeModelField: { - color: 'teal.500', - title: 'VAE', - description: 'TODO', - }, - LoRAModelField: { - color: 'teal.500', - title: 'LoRA', - description: 'TODO', - }, - ControlNetModelField: { - color: 'teal.500', - title: 'ControlNet', - description: 'TODO', - }, - Scheduler: { - color: 'base.500', - title: 'Scheduler', - description: 'TODO', + title: 'Clip', }, Collection: { color: 'base.500', - title: 'Collection', description: 'TODO', + title: 'Collection', }, CollectionItem: { color: 'base.500', - title: 'Collection Item', description: 'TODO', + title: 'Collection Item', + }, + ColorCollection: { + color: 'pink.300', + description: 'A collection of colors.', + title: 'Color Collection', }, ColorField: { - title: 'Color', + color: 'pink.300', description: 'A RGBA color.', - color: 'base.500', + title: 'Color', }, - BooleanCollection: { - title: 'Boolean Collection', - description: 'A collection of booleans.', - color: 'green.500', + ColorPolymorphic: { + color: 'pink.300', + description: 'A collection of colors.', + title: 'Color Polymorphic', }, - IntegerCollection: { - title: 'Integer Collection', - description: 'A collection of integers.', - color: 'red.500', + ConditioningCollection: { + color: 'cyan.500', + description: 'Conditioning may be passed between nodes.', + title: 'Conditioning Collection', + }, + ConditioningField: { + color: 'cyan.500', + description: 'Conditioning may be passed between nodes.', + title: 'Conditioning', + }, + ConditioningPolymorphic: { + color: 'cyan.500', + description: 'Conditioning may be passed between nodes.', + title: 'Conditioning Polymorphic', + }, + ControlCollection: { + color: 'teal.500', + description: 'Control info passed between nodes.', + title: 'Control Collection', + }, + ControlField: { + color: 'teal.500', + description: 'Control info passed between nodes.', + title: 'Control', + }, + ControlNetModelField: { + color: 'teal.500', + description: 'TODO', + title: 'ControlNet', + }, + ControlPolymorphic: { + color: 'teal.500', + description: 'Control info passed between nodes.', + title: 'Control Polymorphic', + }, + DenoiseMaskField: { + color: 'blue.300', + description: 'Denoise Mask may be passed between nodes', + title: 'Denoise Mask', + }, + enum: { + color: 'blue.500', + description: 'Enums are values that may be one of a number of options.', + title: 'Enum', + }, + float: { + color: 'orange.500', + description: 'Floats are numbers with a decimal point.', + title: 'Float', }, FloatCollection: { color: 'orange.500', - title: 'Float Collection', description: 'A collection of floats.', + title: 'Float Collection', }, - ColorCollection: { - color: 'base.500', - title: 'Color Collection', - description: 'A collection of colors.', + FloatPolymorphic: { + color: 'orange.500', + description: 'A collection of floats.', + title: 'Float Polymorphic', + }, + ImageCollection: { + color: 'purple.500', + description: 'A collection of images.', + title: 'Image Collection', + }, + ImageField: { + color: 'purple.500', + description: 'Images may be passed between nodes.', + title: 'Image', + }, + ImagePolymorphic: { + color: 'purple.500', + description: 'A collection of images.', + title: 'Image Polymorphic', + }, + integer: { + color: 'red.500', + description: 'Integers are whole numbers, without a decimal point.', + title: 'Integer', + }, + IntegerCollection: { + color: 'red.500', + description: 'A collection of integers.', + title: 'Integer Collection', + }, + IntegerPolymorphic: { + color: 'red.500', + description: 'A collection of integers.', + title: 'Integer Polymorphic', + }, + LatentsCollection: { + color: 'pink.500', + description: 'Latents may be passed between nodes.', + title: 'Latents Collection', + }, + LatentsField: { + color: 'pink.500', + description: 'Latents may be passed between nodes.', + title: 'Latents', + }, + LatentsPolymorphic: { + color: 'pink.500', + description: 'Latents may be passed between nodes.', + title: 'Latents Polymorphic', + }, + LoRAModelField: { + color: 'teal.300', + description: 'TODO', + title: 'LoRA', + }, + MainModelField: { + color: 'teal.300', + description: 'TODO', + title: 'Model', }, ONNXModelField: { - color: 'base.500', - title: 'ONNX Model', + color: 'teal.300', description: 'ONNX model field.', + title: 'ONNX Model', + }, + Scheduler: { + color: 'base.500', + description: 'TODO', + title: 'Scheduler', }, SDXLMainModelField: { - color: 'base.500', - title: 'SDXL Model', + color: 'teal.300', description: 'SDXL model field.', + title: 'SDXL Model', + }, + SDXLRefinerModelField: { + color: 'teal.300', + description: 'TODO', + title: 'Refiner Model', + }, + string: { + color: 'yellow.500', + description: 'Strings are text.', + title: 'String', }, StringCollection: { color: 'yellow.500', - title: 'String Collection', description: 'A collection of strings.', + title: 'String Collection', + }, + StringPolymorphic: { + color: 'yellow.500', + description: 'A collection of strings.', + title: 'String Polymorphic', + }, + UNetField: { + color: 'red.300', + description: 'UNet submodel.', + title: 'UNet', + }, + VaeField: { + color: 'blue.300', + description: 'Vae submodel.', + title: 'Vae', + }, + VaeModelField: { + color: 'teal.300', + description: 'TODO', + title: 'VAE', }, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index aee5c69705..f7986a5028 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -11,7 +11,7 @@ import { keyBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; import { RgbaColor } from 'react-colorful'; import { Node } from 'reactflow'; -import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types'; +import { Graph, _InputField, _OutputField } from 'services/api/types'; import { AnyInvocationType, AnyResult, @@ -62,50 +62,48 @@ export type FieldUIConfig = { // TODO: Get this from the OpenAPI schema? may be tricky... export const zFieldType = z.enum([ - // region Primitives - 'integer', - 'float', 'boolean', - 'string', - 'array', - 'ImageField', - 'DenoiseMaskField', - 'LatentsField', - 'ConditioningField', - 'ControlField', - 'ColorField', - 'ImageCollection', - 'ConditioningCollection', - 'ColorCollection', - 'LatentsCollection', - 'IntegerCollection', - 'FloatCollection', - 'StringCollection', 'BooleanCollection', - // endregion - - // region Models - 'MainModelField', - 'SDXLMainModelField', - 'SDXLRefinerModelField', - 'ONNXModelField', - 'VaeModelField', - 'LoRAModelField', - 'ControlNetModelField', - 'UNetField', - 'VaeField', + 'BooleanPolymorphic', 'ClipField', - // endregion - - // region Iterate/Collect 'Collection', 'CollectionItem', - // endregion - - // region Misc + 'ColorCollection', + 'ColorField', + 'ColorPolymorphic', + 'ConditioningCollection', + 'ConditioningField', + 'ConditioningPolymorphic', + 'ControlCollection', + 'ControlField', + 'ControlNetModelField', + 'ControlPolymorphic', + 'DenoiseMaskField', 'enum', + 'float', + 'FloatCollection', + 'FloatPolymorphic', + 'ImageCollection', + 'ImageField', + 'ImagePolymorphic', + 'integer', + 'IntegerCollection', + 'IntegerPolymorphic', + 'LatentsCollection', + 'LatentsField', + 'LatentsPolymorphic', + 'LoRAModelField', + 'MainModelField', + 'ONNXModelField', 'Scheduler', - // endregion + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'string', + 'StringCollection', + 'StringPolymorphic', + 'UNetField', + 'VaeField', + 'VaeModelField', ]); export type FieldType = z.infer; @@ -122,38 +120,6 @@ export const isFieldType = (value: unknown): value is FieldType => zFieldType.safeParse(value).success || zReservedFieldType.safeParse(value).success; -/** - * An input field template is generated on each page load from the OpenAPI schema. - * - * The template provides the field type and other field metadata (e.g. title, description, - * maximum length, pattern to match, etc). - */ -export type InputFieldTemplate = - | IntegerInputFieldTemplate - | FloatInputFieldTemplate - | StringInputFieldTemplate - | BooleanInputFieldTemplate - | ImageInputFieldTemplate - | DenoiseMaskInputFieldTemplate - | LatentsInputFieldTemplate - | ConditioningInputFieldTemplate - | UNetInputFieldTemplate - | ClipInputFieldTemplate - | VaeInputFieldTemplate - | ControlInputFieldTemplate - | EnumInputFieldTemplate - | MainModelInputFieldTemplate - | SDXLMainModelInputFieldTemplate - | SDXLRefinerModelInputFieldTemplate - | VaeModelInputFieldTemplate - | LoRAModelInputFieldTemplate - | ControlNetModelInputFieldTemplate - | CollectionInputFieldTemplate - | CollectionItemInputFieldTemplate - | ColorInputFieldTemplate - | ImageCollectionInputFieldTemplate - | SchedulerInputFieldTemplate; - /** * Indicates the kind of input(s) this field may have. */ @@ -232,24 +198,88 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({ }); export type IntegerInputFieldValue = z.infer; +export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IntegerCollection'), + value: z.array(z.number().int()).optional(), +}); +export type IntegerCollectionInputFieldValue = z.infer< + typeof zIntegerCollectionInputFieldValue +>; + +export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IntegerPolymorphic'), + value: z.union([z.number().int(), z.array(z.number().int())]).optional(), +}); +export type IntegerPolymorphicInputFieldValue = z.infer< + typeof zIntegerPolymorphicInputFieldValue +>; + export const zFloatInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('float'), value: z.number().optional(), }); export type FloatInputFieldValue = z.infer; +export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FloatCollection'), + value: z.array(z.number()).optional(), +}); +export type FloatCollectionInputFieldValue = z.infer< + typeof zFloatCollectionInputFieldValue +>; + +export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FloatPolymorphic'), + value: z.union([z.number(), z.array(z.number())]).optional(), +}); +export type FloatPolymorphicInputFieldValue = z.infer< + typeof zFloatPolymorphicInputFieldValue +>; + export const zStringInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('string'), value: z.string().optional(), }); export type StringInputFieldValue = z.infer; +export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('StringCollection'), + value: z.array(z.string()).optional(), +}); +export type StringCollectionInputFieldValue = z.infer< + typeof zStringCollectionInputFieldValue +>; + +export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('StringPolymorphic'), + value: z.union([z.string(), z.array(z.string())]).optional(), +}); +export type StringPolymorphicInputFieldValue = z.infer< + typeof zStringPolymorphicInputFieldValue +>; + export const zBooleanInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('boolean'), value: z.boolean().optional(), }); export type BooleanInputFieldValue = z.infer; +export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BooleanCollection'), + value: z.array(z.boolean()).optional(), +}); +export type BooleanCollectionInputFieldValue = z.infer< + typeof zBooleanCollectionInputFieldValue +>; + +export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BooleanPolymorphic'), + value: z.union([z.boolean(), z.array(z.boolean())]).optional(), +}); +export type BooleanPolymorphicInputFieldValue = z.infer< + typeof zBooleanPolymorphicInputFieldValue +>; + export const zEnumInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('enum'), value: z.union([z.string(), z.number()]).optional(), @@ -262,6 +292,22 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({ }); export type LatentsInputFieldValue = z.infer; +export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsCollection'), + value: z.array(zLatentsField).optional(), +}); +export type LatentsCollectionInputFieldValue = z.infer< + typeof zLatentsCollectionInputFieldValue +>; + +export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsPolymorphic'), + value: z.union([zLatentsField, z.array(zLatentsField)]).optional(), +}); +export type LatentsPolymorphicInputFieldValue = z.infer< + typeof zLatentsPolymorphicInputFieldValue +>; + export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('DenoiseMaskField'), value: zDenoiseMaskField.optional(), @@ -278,6 +324,26 @@ export type ConditioningInputFieldValue = z.infer< typeof zConditioningInputFieldValue >; +export const zConditioningCollectionInputFieldValue = + zInputFieldValueBase.extend({ + type: z.literal('ConditioningCollection'), + value: z.array(zConditioningField).optional(), + }); +export type ConditioningCollectionInputFieldValue = z.infer< + typeof zConditioningCollectionInputFieldValue +>; + +export const zConditioningPolymorphicInputFieldValue = + zInputFieldValueBase.extend({ + type: z.literal('ConditioningPolymorphic'), + value: z + .union([zConditioningField, z.array(zConditioningField)]) + .optional(), + }); +export type ConditioningPolymorphicInputFieldValue = z.infer< + typeof zConditioningPolymorphicInputFieldValue +>; + export const zControlNetModel = zModelIdentifier; export type ControlNetModel = z.infer; @@ -302,6 +368,22 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({ }); export type ControlInputFieldValue = z.infer; +export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlPolymorphic'), + value: z.union([zControlField, z.array(zControlField)]).optional(), +}); +export type ControlPolymorphicInputFieldValue = z.infer< + typeof zControlPolymorphicInputFieldValue +>; + +export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlCollection'), + value: z.array(zControlField).optional(), +}); +export type ControlCollectionInputFieldValue = z.infer< + typeof zControlCollectionInputFieldValue +>; + export const zModelType = z.enum([ 'onnx', 'main', @@ -381,6 +463,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({ }); export type ImageInputFieldValue = z.infer; +export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImagePolymorphic'), + value: z.union([zImageField, z.array(zImageField)]).optional(), +}); +export type ImagePolymorphicInputFieldValue = z.infer< + typeof zImagePolymorphicInputFieldValue +>; + export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('ImageCollection'), value: z.array(zImageField).optional(), @@ -473,6 +563,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({ }); export type ColorInputFieldValue = z.infer; +export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorCollection'), + value: z.array(zColorField).optional(), +}); +export type ColorCollectionInputFieldValue = z.infer< + typeof zColorCollectionInputFieldValue +>; + +export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorPolymorphic'), + value: z.union([zColorField, z.array(zColorField)]).optional(), +}); +export type ColorPolymorphicInputFieldValue = z.infer< + typeof zColorPolymorphicInputFieldValue +>; + export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('Scheduler'), value: zScheduler.optional(), @@ -482,30 +588,47 @@ export type SchedulerInputFieldValue = z.infer< >; export const zInputFieldValue = z.discriminatedUnion('type', [ - zIntegerInputFieldValue, - zFloatInputFieldValue, - zStringInputFieldValue, + zBooleanCollectionInputFieldValue, zBooleanInputFieldValue, - zImageInputFieldValue, - zLatentsInputFieldValue, - zDenoiseMaskInputFieldValue, - zConditioningInputFieldValue, - zUNetInputFieldValue, + zBooleanPolymorphicInputFieldValue, zClipInputFieldValue, - zVaeInputFieldValue, - zControlInputFieldValue, - zEnumInputFieldValue, - zMainModelInputFieldValue, - zSDXLMainModelInputFieldValue, - zSDXLRefinerModelInputFieldValue, - zVaeModelInputFieldValue, - zLoRAModelInputFieldValue, - zControlNetModelInputFieldValue, zCollectionInputFieldValue, zCollectionItemInputFieldValue, zColorInputFieldValue, + zColorCollectionInputFieldValue, + zColorPolymorphicInputFieldValue, + zConditioningInputFieldValue, + zConditioningCollectionInputFieldValue, + zConditioningPolymorphicInputFieldValue, + zControlInputFieldValue, + zControlNetModelInputFieldValue, + zControlCollectionInputFieldValue, + zControlPolymorphicInputFieldValue, + zDenoiseMaskInputFieldValue, + zEnumInputFieldValue, + zFloatCollectionInputFieldValue, + zFloatInputFieldValue, + zFloatPolymorphicInputFieldValue, zImageCollectionInputFieldValue, + zImagePolymorphicInputFieldValue, + zImageInputFieldValue, + zIntegerCollectionInputFieldValue, + zIntegerPolymorphicInputFieldValue, + zIntegerInputFieldValue, + zLatentsInputFieldValue, + zLatentsCollectionInputFieldValue, + zLatentsPolymorphicInputFieldValue, + zLoRAModelInputFieldValue, + zMainModelInputFieldValue, zSchedulerInputFieldValue, + zSDXLMainModelInputFieldValue, + zSDXLRefinerModelInputFieldValue, + zStringCollectionInputFieldValue, + zStringPolymorphicInputFieldValue, + zStringInputFieldValue, + zUNetInputFieldValue, + zVaeInputFieldValue, + zVaeModelInputFieldValue, ]); export type InputFieldValue = z.infer; @@ -514,7 +637,6 @@ export type InputFieldTemplateBase = { name: string; title: string; description: string; - type: FieldType; required: boolean; fieldKind: 'input'; } & _InputField; @@ -529,6 +651,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & { exclusiveMinimum?: boolean; }; +export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & { + type: 'IntegerCollection'; + default: number[]; + item_default?: number; +}; + +export type IntegerPolymorphicInputFieldTemplate = Omit< + IntegerInputFieldTemplate, + 'type' +> & { + type: 'IntegerPolymorphic'; +}; + export type FloatInputFieldTemplate = InputFieldTemplateBase & { type: 'float'; default: number; @@ -539,6 +674,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & { exclusiveMinimum?: boolean; }; +export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & { + type: 'FloatCollection'; + default: number[]; + item_default?: number; +}; + +export type FloatPolymorphicInputFieldTemplate = Omit< + FloatInputFieldTemplate, + 'type' +> & { + type: 'FloatPolymorphic'; +}; + export type StringInputFieldTemplate = InputFieldTemplateBase & { type: 'string'; default: string; @@ -547,19 +695,53 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & { pattern?: string; }; +export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & { + type: 'StringCollection'; + default: string[]; + item_default?: string; +}; + +export type StringPolymorphicInputFieldTemplate = Omit< + StringInputFieldTemplate, + 'type' +> & { + type: 'StringPolymorphic'; +}; + export type BooleanInputFieldTemplate = InputFieldTemplateBase & { default: boolean; type: 'boolean'; }; +export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & { + type: 'BooleanCollection'; + default: boolean[]; + item_default?: boolean; +}; + +export type BooleanPolymorphicInputFieldTemplate = Omit< + BooleanInputFieldTemplate, + 'type' +> & { + type: 'BooleanPolymorphic'; +}; + export type ImageInputFieldTemplate = InputFieldTemplateBase & { - default: ImageDTO; + default: ImageField; type: 'ImageField'; }; export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & { default: ImageField[]; type: 'ImageCollection'; + item_default?: ImageField; +}; + +export type ImagePolymorphicInputFieldTemplate = Omit< + ImageInputFieldTemplate, + 'type' +> & { + type: 'ImagePolymorphic'; }; export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & { @@ -568,15 +750,40 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & { }; export type LatentsInputFieldTemplate = InputFieldTemplateBase & { - default: string; + default: LatentsField; type: 'LatentsField'; }; +export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & { + default: LatentsField[]; + type: 'LatentsCollection'; + item_default?: LatentsField; +}; + +export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & { + default: LatentsField; + type: 'LatentsPolymorphic'; +}; + export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { default: undefined; type: 'ConditioningField'; }; +export type ConditioningCollectionInputFieldTemplate = + InputFieldTemplateBase & { + default: ConditioningField[]; + type: 'ConditioningCollection'; + item_default?: ConditioningField; + }; + +export type ConditioningPolymorphicInputFieldTemplate = Omit< + ConditioningInputFieldTemplate, + 'type' +> & { + type: 'ConditioningPolymorphic'; +}; + export type UNetInputFieldTemplate = InputFieldTemplateBase & { default: undefined; type: 'UNetField'; @@ -597,6 +804,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & { type: 'ControlField'; }; +export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'ControlCollection'; + item_default?: ControlField; +}; + +export type ControlPolymorphicInputFieldTemplate = Omit< + ControlInputFieldTemplate, + 'type' +> & { + type: 'ControlPolymorphic'; +}; + export type EnumInputFieldTemplate = InputFieldTemplateBase & { default: string | number; type: 'enum'; @@ -649,6 +869,18 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & { type: 'ColorField'; }; +export type ColorPolymorphicInputFieldTemplate = Omit< + ColorInputFieldTemplate, + 'type' +> & { + type: 'ColorPolymorphic'; +}; + +export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & { + default: []; + type: 'ColorCollection'; +}; + export type SchedulerInputFieldTemplate = InputFieldTemplateBase & { default: SchedulerParam; type: 'Scheduler'; @@ -659,6 +891,55 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & { type: 'WorkflowField'; }; +/** + * An input field template is generated on each page load from the OpenAPI schema. + * + * The template provides the field type and other field metadata (e.g. title, description, + * maximum length, pattern to match, etc). + */ +export type InputFieldTemplate = + | BooleanCollectionInputFieldTemplate + | BooleanPolymorphicInputFieldTemplate + | BooleanInputFieldTemplate + | ClipInputFieldTemplate + | CollectionInputFieldTemplate + | CollectionItemInputFieldTemplate + | ColorInputFieldTemplate + | ColorCollectionInputFieldTemplate + | ColorPolymorphicInputFieldTemplate + | ConditioningInputFieldTemplate + | ConditioningCollectionInputFieldTemplate + | ConditioningPolymorphicInputFieldTemplate + | ControlInputFieldTemplate + | ControlCollectionInputFieldTemplate + | ControlNetModelInputFieldTemplate + | ControlPolymorphicInputFieldTemplate + | DenoiseMaskInputFieldTemplate + | EnumInputFieldTemplate + | FloatCollectionInputFieldTemplate + | FloatInputFieldTemplate + | FloatPolymorphicInputFieldTemplate + | ImageCollectionInputFieldTemplate + | ImagePolymorphicInputFieldTemplate + | ImageInputFieldTemplate + | IntegerCollectionInputFieldTemplate + | IntegerPolymorphicInputFieldTemplate + | IntegerInputFieldTemplate + | LatentsInputFieldTemplate + | LatentsCollectionInputFieldTemplate + | LatentsPolymorphicInputFieldTemplate + | LoRAModelInputFieldTemplate + | MainModelInputFieldTemplate + | SchedulerInputFieldTemplate + | SDXLMainModelInputFieldTemplate + | SDXLRefinerModelInputFieldTemplate + | StringCollectionInputFieldTemplate + | StringPolymorphicInputFieldTemplate + | StringInputFieldTemplate + | UNetInputFieldTemplate + | VaeInputFieldTemplate + | VaeModelInputFieldTemplate; + export const isInputFieldValue = ( field?: InputFieldValue | OutputFieldValue ): field is InputFieldValue => Boolean(field && field.fieldKind === 'input'); @@ -731,8 +1012,22 @@ export type InvocationSchemaObject = ( ) & { class: 'invocation' }; export const isSchemaObject = ( - obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject -): obj is OpenAPIV3.SchemaObject => !('$ref' in obj); + obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined +): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined +): obj is OpenAPIV3.ArraySchemaObject => + Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined +): obj is OpenAPIV3.NonArraySchemaObject => + Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined +): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj); export const isInvocationSchemaObject = ( obj: diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 3c4ec7e089..20463f37f6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -1,5 +1,14 @@ +import { isBoolean, isInteger, isNumber, isString } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; import { + COLLECTION_MAP, + POLYMORPHIC_TYPES, + SINGLE_TO_POLYMORPHIC_MAP, + isCollectionItemType, + isPolymorphicItemType, +} from '../types/constants'; +import { + BooleanCollectionInputFieldTemplate, BooleanInputFieldTemplate, ClipInputFieldTemplate, CollectionInputFieldTemplate, @@ -11,10 +20,13 @@ import { DenoiseMaskInputFieldTemplate, EnumInputFieldTemplate, FieldType, + FloatCollectionInputFieldTemplate, + FloatPolymorphicInputFieldTemplate, FloatInputFieldTemplate, ImageCollectionInputFieldTemplate, ImageInputFieldTemplate, InputFieldTemplateBase, + IntegerCollectionInputFieldTemplate, IntegerInputFieldTemplate, InvocationFieldSchema, InvocationSchemaObject, @@ -24,11 +36,32 @@ import { SDXLMainModelInputFieldTemplate, SDXLRefinerModelInputFieldTemplate, SchedulerInputFieldTemplate, + StringCollectionInputFieldTemplate, StringInputFieldTemplate, UNetInputFieldTemplate, VaeInputFieldTemplate, VaeModelInputFieldTemplate, + isArraySchemaObject, + isNonArraySchemaObject, + isRefObject, + isSchemaObject, + ControlPolymorphicInputFieldTemplate, + ColorPolymorphicInputFieldTemplate, + ColorCollectionInputFieldTemplate, + IntegerPolymorphicInputFieldTemplate, + StringPolymorphicInputFieldTemplate, + BooleanPolymorphicInputFieldTemplate, + ImagePolymorphicInputFieldTemplate, + LatentsPolymorphicInputFieldTemplate, + LatentsCollectionInputFieldTemplate, + ConditioningPolymorphicInputFieldTemplate, + ConditioningCollectionInputFieldTemplate, + ControlCollectionInputFieldTemplate, + ImageField, + LatentsField, + ConditioningField, } from '../types/types'; +import { ControlField } from 'services/api/types'; export type BaseFieldProperties = 'name' | 'title' | 'description'; @@ -45,15 +78,8 @@ export type BuildInputFieldArg = { * @example * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField' */ -export const refObjectToFieldType = ( - refObject: OpenAPIV3.ReferenceObject -): FieldType => { - const name = refObject.$ref.split('/').slice(-1)[0]; - if (!name) { - throw `Unknown field type: ${name}`; - } - return name as FieldType; -}; +export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) => + refObject.$ref.split('/').slice(-1)[0]; const buildIntegerInputFieldTemplate = ({ schemaObject, @@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({ return template; }; +const buildIntegerPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => { + const template: IntegerPolymorphicInputFieldTemplate = { + ...baseField, + type: 'IntegerPolymorphic', + default: schemaObject.default ?? 0, + }; + + if (schemaObject.multipleOf !== undefined) { + template.multipleOf = schemaObject.multipleOf; + } + + if (schemaObject.maximum !== undefined) { + template.maximum = schemaObject.maximum; + } + + if (schemaObject.exclusiveMaximum !== undefined) { + template.exclusiveMaximum = schemaObject.exclusiveMaximum; + } + + if (schemaObject.minimum !== undefined) { + template.minimum = schemaObject.minimum; + } + + if (schemaObject.exclusiveMinimum !== undefined) { + template.exclusiveMinimum = schemaObject.exclusiveMinimum; + } + + return template; +}; + +const buildIntegerCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => { + const item_default = + isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default) + ? schemaObject.item_default + : 0; + const template: IntegerCollectionInputFieldTemplate = { + ...baseField, + type: 'IntegerCollection', + default: schemaObject.default ?? [], + item_default, + }; + + return template; +}; + const buildFloatInputFieldTemplate = ({ schemaObject, baseField, @@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({ return template; }; +const buildFloatPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => { + const template: FloatPolymorphicInputFieldTemplate = { + ...baseField, + type: 'FloatPolymorphic', + default: schemaObject.default ?? 0, + }; + if (schemaObject.multipleOf !== undefined) { + template.multipleOf = schemaObject.multipleOf; + } + + if (schemaObject.maximum !== undefined) { + template.maximum = schemaObject.maximum; + } + + if (schemaObject.exclusiveMaximum !== undefined) { + template.exclusiveMaximum = schemaObject.exclusiveMaximum; + } + + if (schemaObject.minimum !== undefined) { + template.minimum = schemaObject.minimum; + } + + if (schemaObject.exclusiveMinimum !== undefined) { + template.exclusiveMinimum = schemaObject.exclusiveMinimum; + } + return template; +}; + +const buildFloatCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => { + const item_default = isNumber(schemaObject.item_default) + ? schemaObject.item_default + : 0; + const template: FloatCollectionInputFieldTemplate = { + ...baseField, + type: 'FloatCollection', + default: schemaObject.default ?? [], + item_default, + }; + + return template; +}; + const buildStringInputFieldTemplate = ({ schemaObject, baseField, @@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({ return template; }; +const buildStringPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => { + const template: StringPolymorphicInputFieldTemplate = { + ...baseField, + type: 'StringPolymorphic', + default: schemaObject.default ?? '', + }; + + if (schemaObject.minLength !== undefined) { + template.minLength = schemaObject.minLength; + } + + if (schemaObject.maxLength !== undefined) { + template.maxLength = schemaObject.maxLength; + } + + if (schemaObject.pattern !== undefined) { + template.pattern = schemaObject.pattern; + } + + return template; +}; + +const buildStringCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): StringCollectionInputFieldTemplate => { + const item_default = isString(schemaObject.item_default) + ? schemaObject.item_default + : ''; + const template: StringCollectionInputFieldTemplate = { + ...baseField, + type: 'StringCollection', + default: schemaObject.default ?? [], + item_default, + }; + + return template; +}; + const buildBooleanInputFieldTemplate = ({ schemaObject, baseField, @@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({ return template; }; +const buildBooleanPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => { + const template: BooleanPolymorphicInputFieldTemplate = { + ...baseField, + type: 'BooleanPolymorphic', + default: schemaObject.default ?? false, + }; + + return template; +}; + +const buildBooleanCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => { + const item_default = + schemaObject.item_default && isBoolean(schemaObject.item_default) + ? schemaObject.item_default + : false; + const template: BooleanCollectionInputFieldTemplate = { + ...baseField, + type: 'BooleanCollection', + default: schemaObject.default ?? [], + item_default, + }; + + return template; +}; + const buildMainModelInputFieldTemplate = ({ schemaObject, baseField, @@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({ return template; }; +const buildImagePolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => { + const template: ImagePolymorphicInputFieldTemplate = { + ...baseField, + type: 'ImagePolymorphic', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageCollectionInputFieldTemplate = ({ schemaObject, baseField, @@ -257,7 +468,8 @@ const buildImageCollectionInputFieldTemplate = ({ const template: ImageCollectionInputFieldTemplate = { ...baseField, type: 'ImageCollection', - default: schemaObject.default ?? undefined, + default: schemaObject.default ?? [], + item_default: (schemaObject.item_default as ImageField) ?? undefined, }; return template; @@ -289,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({ return template; }; +const buildLatentsPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => { + const template: LatentsPolymorphicInputFieldTemplate = { + ...baseField, + type: 'LatentsPolymorphic', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildLatentsCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => { + const template: LatentsCollectionInputFieldTemplate = { + ...baseField, + type: 'LatentsCollection', + default: schemaObject.default ?? [], + item_default: (schemaObject.item_default as LatentsField) ?? undefined, + }; + + return template; +}; + const buildConditioningInputFieldTemplate = ({ schemaObject, baseField, @@ -302,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({ return template; }; +const buildConditioningPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => { + const template: ConditioningPolymorphicInputFieldTemplate = { + ...baseField, + type: 'ConditioningPolymorphic', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildConditioningCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => { + const template: ConditioningCollectionInputFieldTemplate = { + ...baseField, + type: 'ConditioningCollection', + default: schemaObject.default ?? [], + item_default: (schemaObject.item_default as ConditioningField) ?? undefined, + }; + + return template; +}; + const buildUNetInputFieldTemplate = ({ schemaObject, baseField, @@ -355,6 +621,33 @@ const buildControlInputFieldTemplate = ({ return template; }; +const buildControlPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => { + const template: ControlPolymorphicInputFieldTemplate = { + ...baseField, + type: 'ControlPolymorphic', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildControlCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => { + const template: ControlCollectionInputFieldTemplate = { + ...baseField, + type: 'ControlCollection', + default: schemaObject.default ?? [], + item_default: (schemaObject.item_default as ControlField) ?? undefined, + }; + + return template; +}; + const buildEnumInputFieldTemplate = ({ schemaObject, baseField, @@ -408,6 +701,32 @@ const buildColorInputFieldTemplate = ({ return template; }; +const buildColorPolymorphicInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => { + const template: ColorPolymorphicInputFieldTemplate = { + ...baseField, + type: 'ColorPolymorphic', + default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, + }; + + return template; +}; + +const buildColorCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => { + const template: ColorCollectionInputFieldTemplate = { + ...baseField, + type: 'ColorCollection', + default: schemaObject.default ?? [], + }; + + return template; +}; + const buildSchedulerInputFieldTemplate = ({ schemaObject, baseField, @@ -421,45 +740,138 @@ const buildSchedulerInputFieldTemplate = ({ return template; }; -export const getFieldType = (schemaObject: InvocationFieldSchema): string => { - let fieldType = ''; - - const { ui_type } = schemaObject; - if (ui_type) { - fieldType = ui_type; +export const getFieldType = ( + schemaObject: InvocationFieldSchema +): string | undefined => { + if (schemaObject?.ui_type) { + return schemaObject.ui_type; } else if (!schemaObject.type) { - // console.log('refObject', schemaObject); // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf + if (schemaObject.allOf) { - fieldType = refObjectToFieldType( - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - schemaObject.allOf![0] as OpenAPIV3.ReferenceObject - ); + const allOf = schemaObject.allOf; + if (allOf && allOf[0] && isRefObject(allOf[0])) { + return refObjectToSchemaName(allOf[0]); + } } else if (schemaObject.anyOf) { - fieldType = refObjectToFieldType( - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject - ); - } else if (schemaObject.oneOf) { - fieldType = refObjectToFieldType( - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject - ); + const anyOf = schemaObject.anyOf; + /** + * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is: + * - an `anyOf` with two items + * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` + * - the other is a `SchemaObject` or `ReferenceObject` of type T + * + * Any other cases we ignore. + */ + + let firstType: string | undefined; + let secondType: string | undefined; + + if (isArraySchemaObject(anyOf[0])) { + // first is array, second is not + const first = anyOf[0].items; + const second = anyOf[1]; + if (isRefObject(first) && isRefObject(second)) { + firstType = refObjectToSchemaName(first); + secondType = refObjectToSchemaName(second); + } else if ( + isNonArraySchemaObject(first) && + isNonArraySchemaObject(second) + ) { + firstType = first.type; + secondType = second.type; + } + } else if (isArraySchemaObject(anyOf[1])) { + // first is not array, second is + const first = anyOf[0]; + const second = anyOf[1].items; + if (isRefObject(first) && isRefObject(second)) { + firstType = refObjectToSchemaName(first); + secondType = refObjectToSchemaName(second); + } else if ( + isNonArraySchemaObject(first) && + isNonArraySchemaObject(second) + ) { + firstType = first.type; + secondType = second.type; + } + } + if (firstType === secondType && isPolymorphicItemType(firstType)) { + return SINGLE_TO_POLYMORPHIC_MAP[firstType]; + } } } else if (schemaObject.enum) { - fieldType = 'enum'; + return 'enum'; } else if (schemaObject.type) { if (schemaObject.type === 'number') { - // floats are "number" in OpenAPI, while ints are "integer" - fieldType = 'float'; + // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them + return 'float'; + } else if (schemaObject.type === 'array') { + const itemType = isSchemaObject(schemaObject.items) + ? schemaObject.items.type + : refObjectToSchemaName(schemaObject.items); + + if (isCollectionItemType(itemType)) { + return COLLECTION_MAP[itemType]; + } + + return; } else { - fieldType = schemaObject.type; + return schemaObject.type; } } - - return fieldType; + return; }; +const TEMPLATE_BUILDER_MAP = { + boolean: buildBooleanInputFieldTemplate, + BooleanCollection: buildBooleanCollectionInputFieldTemplate, + BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate, + ClipField: buildClipInputFieldTemplate, + Collection: buildCollectionInputFieldTemplate, + CollectionItem: buildCollectionItemInputFieldTemplate, + ColorCollection: buildColorCollectionInputFieldTemplate, + ColorField: buildColorInputFieldTemplate, + ColorPolymorphic: buildColorPolymorphicInputFieldTemplate, + ConditioningCollection: buildConditioningCollectionInputFieldTemplate, + ConditioningField: buildConditioningInputFieldTemplate, + ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate, + ControlCollection: buildControlCollectionInputFieldTemplate, + ControlField: buildControlInputFieldTemplate, + ControlNetModelField: buildControlNetModelInputFieldTemplate, + ControlPolymorphic: buildControlPolymorphicInputFieldTemplate, + DenoiseMaskField: buildDenoiseMaskInputFieldTemplate, + enum: buildEnumInputFieldTemplate, + float: buildFloatInputFieldTemplate, + FloatCollection: buildFloatCollectionInputFieldTemplate, + FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate, + ImageCollection: buildImageCollectionInputFieldTemplate, + ImageField: buildImageInputFieldTemplate, + ImagePolymorphic: buildImagePolymorphicInputFieldTemplate, + integer: buildIntegerInputFieldTemplate, + IntegerCollection: buildIntegerCollectionInputFieldTemplate, + IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate, + LatentsCollection: buildLatentsCollectionInputFieldTemplate, + LatentsField: buildLatentsInputFieldTemplate, + LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate, + LoRAModelField: buildLoRAModelInputFieldTemplate, + MainModelField: buildMainModelInputFieldTemplate, + Scheduler: buildSchedulerInputFieldTemplate, + SDXLMainModelField: buildSDXLMainModelInputFieldTemplate, + SDXLRefinerModelField: buildRefinerModelInputFieldTemplate, + string: buildStringInputFieldTemplate, + StringCollection: buildStringCollectionInputFieldTemplate, + StringPolymorphic: buildStringPolymorphicInputFieldTemplate, + UNetField: buildUNetInputFieldTemplate, + VaeField: buildVaeInputFieldTemplate, + VaeModelField: buildVaeModelInputFieldTemplate, +}; + +const isTemplatedFieldType = ( + fieldType: string | undefined +): fieldType is keyof typeof TEMPLATE_BUILDER_MAP => + Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP); + /** * Builds an input field from an invocation schema property. * @param fieldSchema The schema object @@ -474,7 +886,8 @@ export const buildInputFieldTemplate = ( const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema; const extra = { - input, + // TODO: Can we support polymorphic inputs in the UI? + input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input, ui_hidden, ui_component, ui_type, @@ -490,146 +903,12 @@ export const buildInputFieldTemplate = ( ...extra, }; - if (fieldType === 'ImageField') { - return buildImageInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); + if (!isTemplatedFieldType(fieldType)) { + return; } - if (fieldType === 'ImageCollection') { - return buildImageCollectionInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'DenoiseMaskField') { - return buildDenoiseMaskInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'LatentsField') { - return buildLatentsInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'ConditioningField') { - return buildConditioningInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'UNetField') { - return buildUNetInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'ClipField') { - return buildClipInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'VaeField') { - return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField }); - } - if (fieldType === 'ControlField') { - return buildControlInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'MainModelField') { - return buildMainModelInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'SDXLRefinerModelField') { - return buildRefinerModelInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'SDXLMainModelField') { - return buildSDXLMainModelInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'VaeModelField') { - return buildVaeModelInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'LoRAModelField') { - return buildLoRAModelInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'ControlNetModelField') { - return buildControlNetModelInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'enum') { - return buildEnumInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'integer') { - return buildIntegerInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'float') { - return buildFloatInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'string') { - return buildStringInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'boolean') { - return buildBooleanInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'Collection') { - return buildCollectionInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'CollectionItem') { - return buildCollectionItemInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'ColorField') { - return buildColorInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - if (fieldType === 'Scheduler') { - return buildSchedulerInputFieldTemplate({ - schemaObject: fieldSchema, - baseField, - }); - } - return; + + return TEMPLATE_BUILDER_MAP[fieldType]({ + schemaObject: fieldSchema, + baseField, + }); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 1d06d644d1..a3046feee7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -1,104 +1,79 @@ import { InputFieldTemplate, InputFieldValue } from '../types/types'; +const FIELD_VALUE_FALLBACK_MAP = { + 'enum.number': 0, + 'enum.string': '', + boolean: false, + BooleanCollection: [], + BooleanPolymorphic: false, + ClipField: undefined, + Collection: [], + CollectionItem: undefined, + ColorCollection: [], + ColorField: undefined, + ColorPolymorphic: undefined, + ConditioningCollection: [], + ConditioningField: undefined, + ConditioningPolymorphic: undefined, + ControlCollection: [], + ControlField: undefined, + ControlNetModelField: undefined, + ControlPolymorphic: undefined, + DenoiseMaskField: undefined, + float: 0, + FloatCollection: [], + FloatPolymorphic: 0, + ImageCollection: [], + ImageField: undefined, + ImagePolymorphic: undefined, + integer: 0, + IntegerCollection: [], + IntegerPolymorphic: 0, + LatentsCollection: [], + LatentsField: undefined, + LatentsPolymorphic: undefined, + LoRAModelField: undefined, + MainModelField: undefined, + ONNXModelField: undefined, + Scheduler: 'euler', + SDXLMainModelField: undefined, + SDXLRefinerModelField: undefined, + string: '', + StringCollection: [], + StringPolymorphic: '', + UNetField: undefined, + VaeField: undefined, + VaeModelField: undefined, +}; + export const buildInputFieldValue = ( id: string, template: InputFieldTemplate ): InputFieldValue => { - const fieldValue: InputFieldValue = { + // TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't + // resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both + // `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the + // `InputFieldValue` union, but TS doesn't seem to like it... + const fieldValue = { id, name: template.name, type: template.type, label: '', fieldKind: 'input', - }; - - if (template.type === 'string') { - fieldValue.value = template.default ?? ''; - } - - if (template.type === 'integer') { - fieldValue.value = template.default ?? 0; - } - - if (template.type === 'float') { - fieldValue.value = template.default ?? 0; - } - - if (template.type === 'boolean') { - fieldValue.value = template.default ?? false; - } + } as InputFieldValue; if (template.type === 'enum') { if (template.enumType === 'number') { - fieldValue.value = template.default ?? 0; + fieldValue.value = + template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number']; } if (template.enumType === 'string') { - fieldValue.value = template.default ?? ''; + fieldValue.value = + template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string']; } - } - - if (template.type === 'Collection') { - fieldValue.value = template.default ?? 1; - } - - if (template.type === 'ImageField') { - fieldValue.value = undefined; - } - - if (template.type === 'ImageCollection') { - fieldValue.value = []; - } - - if (template.type === 'DenoiseMaskField') { - fieldValue.value = undefined; - } - - if (template.type === 'LatentsField') { - fieldValue.value = undefined; - } - - if (template.type === 'ConditioningField') { - fieldValue.value = undefined; - } - - if (template.type === 'UNetField') { - fieldValue.value = undefined; - } - - if (template.type === 'ClipField') { - fieldValue.value = undefined; - } - - if (template.type === 'VaeField') { - fieldValue.value = undefined; - } - - if (template.type === 'ControlField') { - fieldValue.value = undefined; - } - - if (template.type === 'MainModelField') { - fieldValue.value = undefined; - } - - if (template.type === 'SDXLRefinerModelField') { - fieldValue.value = undefined; - } - - if (template.type === 'VaeModelField') { - fieldValue.value = undefined; - } - - if (template.type === 'LoRAModelField') { - fieldValue.value = undefined; - } - - if (template.type === 'ControlNetModelField') { - fieldValue.value = undefined; - } - - if (template.type === 'Scheduler') { - fieldValue.value = 'euler'; + } else { + fieldValue.value = + template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type]; } return fieldValue; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index c38aa95fe5..8f17985bbd 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -1397,9 +1397,8 @@ export type components = { /** * Control Model * @description ControlNet model to load - * @default lllyasviel/sd-controlnet-canny */ - control_model?: components["schemas"]["ControlNetModelField"]; + control_model: components["schemas"]["ControlNetModelField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -5806,12 +5805,12 @@ export type components = { */ target_height?: number; /** - * Clip + * CLIP 1 * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ clip?: components["schemas"]["ClipField"]; /** - * Clip2 + * CLIP 2 * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ clip2?: components["schemas"]["ClipField"]; @@ -5855,7 +5854,7 @@ export type components = { */ weight?: number; /** - * UNET + * UNet * @description UNet (scheduler, LoRAs) */ unet?: components["schemas"]["UNetField"]; @@ -6998,7 +6997,7 @@ export type components = { * If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes. * @enum {string} */ - UIType: "integer" | "float" | "boolean" | "string" | "array" | "ImageField" | "LatentsField" | "ConditioningField" | "ControlField" | "ColorField" | "ImageCollection" | "ConditioningCollection" | "ColorCollection" | "LatentsCollection" | "IntegerCollection" | "FloatCollection" | "StringCollection" | "BooleanCollection" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "MetadataField"; + UIType: "boolean" | "ColorField" | "ConditioningField" | "ControlField" | "float" | "ImageField" | "integer" | "LatentsField" | "string" | "BooleanCollection" | "ColorCollection" | "ConditioningCollection" | "ControlCollection" | "FloatCollection" | "ImageCollection" | "IntegerCollection" | "LatentsCollection" | "StringCollection" | "BooleanPolymorphic" | "ColorPolymorphic" | "ConditioningPolymorphic" | "ControlPolymorphic" | "FloatPolymorphic" | "ImagePolymorphic" | "IntegerPolymorphic" | "LatentsPolymorphic" | "StringPolymorphic" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "MetadataField"; /** * UIComponent * @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type. @@ -7020,6 +7019,8 @@ export type components = { ui_component?: components["schemas"]["UIComponent"]; /** Ui Order */ ui_order?: number; + /** Item Default */ + item_default?: unknown; }; /** * _OutputField @@ -7041,6 +7042,12 @@ export type components = { * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusionOnnxModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** * ControlNetModelFormat * @description An enumeration. @@ -7059,12 +7066,6 @@ export type components = { * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusionOnnxModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; }; responses: never; parameters: never; From 09803b075d62cb091bc406a96afc5d56cea96718 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 10:22:36 +1000 Subject: [PATCH 02/13] fix(ui): fix node value checks to compare to undefined existing checks would fail if falsy values --- .../frontend/web/src/common/hooks/useIsReadyToInvoke.ts | 6 +++++- .../web/src/features/nodes/hooks/useDoesInputHaveValue.ts | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index e06a1106c1..3820b07daf 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -63,7 +63,11 @@ const selector = createSelector( return; } - if (fieldTemplate.required && !field.value && !hasConnection) { + if ( + fieldTemplate.required && + field.value === undefined && + !hasConnection + ) { reasons.push( `${node.data.label || nodeTemplate.title} -> ${ field.label || fieldTemplate.title diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index f56099ed2b..83bf6b8af0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { if (!isInvocationNode(node)) { return; } - return Boolean(node?.data.inputs[fieldName]?.value); + return node?.data.inputs[fieldName]?.value !== undefined; }, defaultSelectorOptions ), From a765f01c08cb27f8559808ab9b00f98d8f5c755a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 10:23:48 +1000 Subject: [PATCH 03/13] chore(ui): typegen --- .../frontend/web/src/services/api/schema.d.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 8f17985bbd..a895e6a230 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -6804,7 +6804,7 @@ export type components = { * Seamless Axes * @description Axes("x" and "y") to which apply seamless */ - seamless_axes?: string[]; + seamless_axes: string[]; }; /** Upscaler */ Upscaler: { @@ -6843,7 +6843,7 @@ export type components = { * Seamless Axes * @description Axes("x" and "y") to which apply seamless */ - seamless_axes?: string[]; + seamless_axes: string[]; }; /** * VAE @@ -7042,12 +7042,6 @@ export type components = { * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusionOnnxModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** * ControlNetModelFormat * @description An enumeration. @@ -7066,6 +7060,12 @@ export type components = { * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusionOnnxModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionOnnxModelFormat: "olive" | "onnx"; }; responses: never; parameters: never; From 92975130bd686e8f88a3e19bd3e3efa64b2f9bdc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 11:22:31 +1000 Subject: [PATCH 04/13] feat: allow float inputs to accept integers Pydantic automatically casts ints to floats. --- invokeai/app/services/graph.py | 4 ++++ .../web/src/features/nodes/hooks/useIsValidConnection.ts | 5 ++++- .../nodes/store/util/makeIsConnectionValidSelector.ts | 5 ++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 18c99fafc1..0e40636280 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if to_type in get_args(from_type): return True + # allow int -> float, pydantic will cast for us + if from_type is int and to_type is float: + return True + # if not issubclass(from_type, to_type): if not is_union_subtype(from_type, to_type): return False diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 1c372e551d..d1d10bb7e7 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -120,12 +120,15 @@ export const useIsValidConnection = () => { const isCollectionToGenericCollection = targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || isAnythingToPolymorphicOfSameBaseType || isGenericCollectionToAnyCollectionOrPolymorphic || - isCollectionToGenericCollection + isCollectionToGenericCollection || + isIntToFloat ); } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 0c5fee509c..5cb6d557e8 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -113,13 +113,16 @@ export const makeConnectionErrorSelector = ( const isCollectionToGenericCollection = targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + if ( !( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || isAnythingToPolymorphicOfSameBaseType || isGenericCollectionToAnyCollectionOrPolymorphic || - isCollectionToGenericCollection + isCollectionToGenericCollection || + isIntToFloat ) ) { return 'Field types must match'; From 446dc6bea1928255cb3c96f3a607f16d9dee999e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 15:23:10 +1000 Subject: [PATCH 05/13] fix(nodes): denoise_mask is connection-only, ui_order=6 --- invokeai/app/invocations/latent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 96ad49165a..c0e53e4e12 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -215,8 +215,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) denoise_mask: Optional[DenoiseMaskField] = InputField( - default=None, - description=FieldDescriptions.mask, + default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6 ) @validator("cfg_scale") From d65553841e9181461bb028009f745a865348f7f4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 15:24:44 +1000 Subject: [PATCH 06/13] fix: remove default_factory for ImageCollectionInvocation --- invokeai/app/invocations/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 81914ab2af..fdadc4b31b 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -261,7 +261,7 @@ class ImageInvocation(BaseInvocation): class ImageCollectionInvocation(BaseInvocation): """A collection of image primitive values""" - collection: list[ImageField] = InputField(default_factory=list, description="The collection of image values") + collection: list[ImageField] = InputField(description="The collection of image values") def invoke(self, context: InvocationContext) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) From 34e3c2e000d5f5d756fcfb75e7a024d13b78c709 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 15:24:53 +1000 Subject: [PATCH 07/13] feat(ui): style handles --- .../nodes/Invocation/fields/FieldHandle.tsx | 4 ++- .../web/src/features/nodes/types/constants.ts | 31 +++++++++++++------ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 02b18e7178..3166590254 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -4,6 +4,7 @@ import { COLLECTION_TYPES, FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, + MODEL_TYPES, POLYMORPHIC_TYPES, } from 'features/nodes/types/constants'; import { @@ -52,6 +53,7 @@ const FieldHandle = (props: FieldHandleProps) => { const styles: CSSProperties = useMemo(() => { const isCollectionType = COLLECTION_TYPES.includes(type); const isPolymorphicType = POLYMORPHIC_TYPES.includes(type); + const isModelType = MODEL_TYPES.includes(type); const color = colorTokenToCssVar(typeColor); const s: CSSProperties = { backgroundColor: @@ -64,7 +66,7 @@ const FieldHandle = (props: FieldHandleProps) => { borderWidth: isCollectionType || isPolymorphicType ? 4 : 0, borderStyle: 'solid', borderColor: color, - borderRadius: isPolymorphicType ? 4 : '100%', + borderRadius: isModelType ? 4 : '100%', zIndex: 1, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index dcd579d912..a12c1fbddc 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -39,6 +39,19 @@ export const POLYMORPHIC_TYPES = [ 'ColorPolymorphic', ]; +export const MODEL_TYPES = [ + 'ControlNetModelField', + 'LoRAModelField', + 'MainModelField', + 'ONNXModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'UNetField', + 'VaeField', + 'ClipField', +]; + export const COLLECTION_MAP = { integer: 'IntegerCollection', boolean: 'BooleanCollection', @@ -103,7 +116,7 @@ export const FIELDS: Record = { title: 'Boolean Polymorphic', }, ClipField: { - color: 'green.300', + color: 'green.500', description: 'Tokenizer and text_encoder submodels.', title: 'Clip', }, @@ -238,17 +251,17 @@ export const FIELDS: Record = { title: 'Latents Polymorphic', }, LoRAModelField: { - color: 'teal.300', + color: 'teal.500', description: 'TODO', title: 'LoRA', }, MainModelField: { - color: 'teal.300', + color: 'teal.500', description: 'TODO', title: 'Model', }, ONNXModelField: { - color: 'teal.300', + color: 'teal.500', description: 'ONNX model field.', title: 'ONNX Model', }, @@ -258,12 +271,12 @@ export const FIELDS: Record = { title: 'Scheduler', }, SDXLMainModelField: { - color: 'teal.300', + color: 'teal.500', description: 'SDXL model field.', title: 'SDXL Model', }, SDXLRefinerModelField: { - color: 'teal.300', + color: 'teal.500', description: 'TODO', title: 'Refiner Model', }, @@ -283,17 +296,17 @@ export const FIELDS: Record = { title: 'String Polymorphic', }, UNetField: { - color: 'red.300', + color: 'red.500', description: 'UNet submodel.', title: 'UNet', }, VaeField: { - color: 'blue.300', + color: 'blue.500', description: 'Vae submodel.', title: 'Vae', }, VaeModelField: { - color: 'teal.300', + color: 'teal.500', description: 'TODO', title: 'VAE', }, From 920fc0e75190d3a90cef1e01fae96ac066183140 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 15:25:58 +1000 Subject: [PATCH 08/13] chore(ui): typegen --- .../frontend/web/src/services/api/schema.d.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index a895e6a230..f48892113d 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -6804,7 +6804,7 @@ export type components = { * Seamless Axes * @description Axes("x" and "y") to which apply seamless */ - seamless_axes: string[]; + seamless_axes?: string[]; }; /** Upscaler */ Upscaler: { @@ -6843,7 +6843,7 @@ export type components = { * Seamless Axes * @description Axes("x" and "y") to which apply seamless */ - seamless_axes: string[]; + seamless_axes?: string[]; }; /** * VAE @@ -7036,6 +7036,12 @@ export type components = { /** Ui Order */ ui_order?: number; }; + /** + * StableDiffusionOnnxModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** * StableDiffusion1ModelFormat * @description An enumeration. @@ -7060,12 +7066,6 @@ export type components = { * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusionOnnxModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; }; responses: never; parameters: never; From 59cb6305b92f843138a7b216bddf393a272ef52e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:07:41 +1000 Subject: [PATCH 09/13] feat(tests): add tests for decorator and int -> float --- tests/nodes/test_node_graph.py | 56 +++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index fe6709827f..56bf823d14 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,3 +1,4 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .test_nodes import ( ImageToImageTestInvocation, TextToImageTestInvocation, @@ -20,7 +21,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation from invokeai.app.invocations.image import ShowImageInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation -from invokeai.app.invocations.primitives import IntegerInvocation +from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation from invokeai.app.services.default_graphs import create_text_to_image import pytest @@ -610,6 +611,59 @@ def test_graph_can_deserialize(): assert g2.edges[0].destination.field == "image" +def test_invocation_decorator(): + invocation_type = "test_invocation" + title = "Test Invocation" + tags = ["first", "second", "third"] + category = "category" + + @invocation(invocation_type, title=title, tags=tags, category=category) + class Test(BaseInvocation): + def invoke(self): + pass + + schema = Test.schema() + + assert schema.get("title") == title + assert schema.get("tags") == tags + assert schema.get("category") == category + assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added) + + +def test_invocation_output_decorator(): + output_type = "test_output" + + @invocation_output(output_type) + class TestOutput(BaseInvocationOutput): + pass + + assert TestOutput().type == output_type # type: ignore (type is dynamically added) + + +def test_floats_accept_ints(): + g = Graph() + n1 = IntegerInvocation(id="1", value=1) + n2 = FloatInvocation(id="2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id, "value", n2.id, "value") + + # Not throwing on this line is sufficient + g.add_edge(e) + + +def test_ints_do_not_accept_floats(): + g = Graph() + n1 = FloatInvocation(id="1", value=1.0) + n2 = IntegerInvocation(id="2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id, "value", n2.id, "value") + + with pytest.raises(InvalidEdgeError): + g.add_edge(e) + + def test_graph_can_generate_schema(): # Not throwing on this line is sufficient # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation From d9148fb619864231c2948034647710dbfe65a0ad Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 18:11:56 +1000 Subject: [PATCH 10/13] feat(nodes): add version to node schemas The `@invocation` decorator is extended with an optional `version` arg. On execution of the decorator, the version string is parsed using the `semver` package (this was an indirect dependency and has been added to `pyproject.toml`). All built-in nodes are set with `version="1.0.0"`. The version is added to the OpenAPI Schema for consumption by the client. --- invokeai/app/invocations/baseinvocation.py | 15 ++++++- invokeai/app/invocations/collections.py | 6 ++- invokeai/app/invocations/compel.py | 6 ++- .../controlnet_image_processors.py | 27 ++++++++++-- invokeai/app/invocations/cv.py | 7 +-- invokeai/app/invocations/image.py | 44 +++++++++++-------- invokeai/app/invocations/infill.py | 10 +++-- invokeai/app/invocations/latent.py | 21 ++++++--- invokeai/app/invocations/math.py | 10 ++--- invokeai/app/invocations/metadata.py | 4 +- invokeai/app/invocations/model.py | 10 ++--- invokeai/app/invocations/noise.py | 2 +- invokeai/app/invocations/onnx.py | 6 ++- invokeai/app/invocations/param_easing.py | 4 +- invokeai/app/invocations/primitives.py | 28 +++++++++--- invokeai/app/invocations/prompt.py | 4 +- invokeai/app/invocations/sdxl.py | 3 +- invokeai/app/invocations/upscale.py | 2 +- pyproject.toml | 1 + 19 files changed, 139 insertions(+), 71 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index ccc2b4d05f..540571762f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -26,6 +26,7 @@ from typing import ( from pydantic import BaseModel, Field, validator from pydantic.fields import Undefined, ModelField from pydantic.typing import NoArgAnyCallable +import semver if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -401,6 +402,9 @@ class UIConfigBase(BaseModel): tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags") title: Optional[str] = Field(default=None, description="The node's display name") category: Optional[str] = Field(default=None, description="The node's category") + version: Optional[str] = Field( + default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".' + ) class InvocationContext: @@ -499,6 +503,8 @@ class BaseInvocation(ABC, BaseModel): schema["tags"] = uiconfig.tags if uiconfig and hasattr(uiconfig, "category"): schema["category"] = uiconfig.category + if uiconfig and hasattr(uiconfig, "version"): + schema["version"] = uiconfig.version if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = list() schema["required"].extend(["type", "id"]) @@ -567,7 +573,11 @@ GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation) def invocation( - invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None + invocation_type: str, + title: Optional[str] = None, + tags: Optional[list[str]] = None, + category: Optional[str] = None, + version: Optional[str] = None, ) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]: """ Adds metadata to an invocation. @@ -594,6 +604,9 @@ def invocation( cls.UIConfig.tags = tags if category is not None: cls.UIConfig.category = category + if version is not None: + semver.Version.parse(version) # raises ValueError if invalid semver + cls.UIConfig.version = version # Add the invocation type to the pydantic model of the invocation invocation_type_annotation = Literal[invocation_type] # type: ignore diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 979f2e43b7..2814a9c3ca 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -10,7 +10,9 @@ from invokeai.app.util.misc import SEED_MAX, get_random_seed from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation -@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections") +@invocation( + "range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0" +) class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -33,6 +35,7 @@ class RangeInvocation(BaseInvocation): title="Integer Range of Size", tags=["collection", "integer", "size", "range"], category="collections", + version="1.0.0", ) class RangeOfSizeInvocation(BaseInvocation): """Creates a range from start to start + size with step""" @@ -50,6 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation): title="Random Range", tags=["range", "integer", "random", "collection"], category="collections", + version="1.0.0", ) class RandomRangeInvocation(BaseInvocation): """Creates a collection of random numbers""" diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 563d8d97fd..4557c57820 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -44,7 +44,7 @@ class ConditioningFieldData: # PerpNeg = "perp_neg" -@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning") +@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0") class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -267,6 +267,7 @@ class SDXLPromptInvocationBase: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", + version="1.0.0", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -351,6 +352,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): title="SDXL Refiner Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", + version="1.0.0", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -403,7 +405,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput): clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning") +@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0") class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 272afb3a4c..2c2eab9155 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -95,7 +95,7 @@ class ControlOutput(BaseInvocationOutput): control: ControlField = OutputField(description=FieldDescriptions.control) -@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet") +@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" @@ -127,7 +127,9 @@ class ControlNetInvocation(BaseInvocation): ) -@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet") +@invocation( + "image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0" +) class ImageProcessorInvocation(BaseInvocation): """Base class for invocations that preprocess images for ControlNet""" @@ -171,6 +173,7 @@ class ImageProcessorInvocation(BaseInvocation): title="Canny Processor", tags=["controlnet", "canny"], category="controlnet", + version="1.0.0", ) class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -193,6 +196,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation): title="HED (softedge) Processor", tags=["controlnet", "hed", "softedge"], category="controlnet", + version="1.0.0", ) class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -221,6 +225,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): title="Lineart Processor", tags=["controlnet", "lineart"], category="controlnet", + version="1.0.0", ) class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -242,6 +247,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): title="Lineart Anime Processor", tags=["controlnet", "lineart", "anime"], category="controlnet", + version="1.0.0", ) class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -264,6 +270,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): title="Openpose Processor", tags=["controlnet", "openpose", "pose"], category="controlnet", + version="1.0.0", ) class OpenposeImageProcessorInvocation(ImageProcessorInvocation): """Applies Openpose processing to image""" @@ -288,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation): title="Midas Depth Processor", tags=["controlnet", "midas"], category="controlnet", + version="1.0.0", ) class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -314,6 +322,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): title="Normal BAE Processor", tags=["controlnet"], category="controlnet", + version="1.0.0", ) class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -329,7 +338,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): return processed_image -@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet") +@invocation( + "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0" +) class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -350,7 +361,9 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): return processed_image -@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet") +@invocation( + "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0" +) class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -376,6 +389,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): title="Content Shuffle Processor", tags=["controlnet", "contentshuffle"], category="controlnet", + version="1.0.0", ) class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle processing to image""" @@ -405,6 +419,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): title="Zoe (Depth) Processor", tags=["controlnet", "zoe", "depth"], category="controlnet", + version="1.0.0", ) class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -420,6 +435,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): title="Mediapipe Face Processor", tags=["controlnet", "mediapipe", "face"], category="controlnet", + version="1.0.0", ) class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -442,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): title="Leres (Depth) Processor", tags=["controlnet", "leres", "depth"], category="controlnet", + version="1.0.0", ) class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -470,6 +487,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): title="Tile Resample Processor", tags=["controlnet", "tile"], category="controlnet", + version="1.0.0", ) class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -509,6 +527,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation): title="Segment Anything Processor", tags=["controlnet", "segmentanything"], category="controlnet", + version="1.0.0", ) class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 40d8867aa1..411aff8234 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -10,12 +10,7 @@ from invokeai.app.models.image import ImageCategory, ResourceOrigin from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation -@invocation( - "cv_inpaint", - title="OpenCV Inpaint", - tags=["opencv", "inpaint"], - category="inpaint", -) +@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0") class CvInpaintInvocation(BaseInvocation): """Simple inpaint using opencv.""" diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 5eeead7db2..b6f7cc405b 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -16,7 +16,7 @@ from ..models.image import ImageCategory, ResourceOrigin from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation -@invocation("show_image", title="Show Image", tags=["image"], category="image") +@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") class ShowImageInvocation(BaseInvocation): """Displays a provided image using the OS image viewer, and passes it forward in the pipeline.""" @@ -36,7 +36,7 @@ class ShowImageInvocation(BaseInvocation): ) -@invocation("blank_image", title="Blank Image", tags=["image"], category="image") +@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0") class BlankImageInvocation(BaseInvocation): """Creates a blank image and forwards it to the pipeline""" @@ -65,7 +65,7 @@ class BlankImageInvocation(BaseInvocation): ) -@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image") +@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0") class ImageCropInvocation(BaseInvocation): """Crops an image to a specified box. The box can be outside of the image.""" @@ -98,7 +98,7 @@ class ImageCropInvocation(BaseInvocation): ) -@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image") +@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.0") class ImagePasteInvocation(BaseInvocation): """Pastes an image into another image.""" @@ -146,7 +146,7 @@ class ImagePasteInvocation(BaseInvocation): ) -@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image") +@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0") class MaskFromAlphaInvocation(BaseInvocation): """Extracts the alpha channel of an image as a mask.""" @@ -177,7 +177,7 @@ class MaskFromAlphaInvocation(BaseInvocation): ) -@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image") +@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0") class ImageMultiplyInvocation(BaseInvocation): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" @@ -210,7 +210,7 @@ class ImageMultiplyInvocation(BaseInvocation): IMAGE_CHANNELS = Literal["A", "R", "G", "B"] -@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image") +@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0") class ImageChannelInvocation(BaseInvocation): """Gets a channel from an image.""" @@ -242,7 +242,7 @@ class ImageChannelInvocation(BaseInvocation): IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] -@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image") +@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0") class ImageConvertInvocation(BaseInvocation): """Converts an image to a different mode.""" @@ -271,7 +271,7 @@ class ImageConvertInvocation(BaseInvocation): ) -@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image") +@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0") class ImageBlurInvocation(BaseInvocation): """Blurs an image""" @@ -325,7 +325,7 @@ PIL_RESAMPLING_MAP = { } -@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image") +@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0") class ImageResizeInvocation(BaseInvocation): """Resizes an image to specific dimensions""" @@ -365,7 +365,7 @@ class ImageResizeInvocation(BaseInvocation): ) -@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image") +@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0") class ImageScaleInvocation(BaseInvocation): """Scales an image by a factor""" @@ -406,7 +406,7 @@ class ImageScaleInvocation(BaseInvocation): ) -@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image") +@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0") class ImageLerpInvocation(BaseInvocation): """Linear interpolation of all pixels of an image""" @@ -439,7 +439,7 @@ class ImageLerpInvocation(BaseInvocation): ) -@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image") +@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0") class ImageInverseLerpInvocation(BaseInvocation): """Inverse linear interpolation of all pixels of an image""" @@ -472,7 +472,7 @@ class ImageInverseLerpInvocation(BaseInvocation): ) -@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image") +@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0") class ImageNSFWBlurInvocation(BaseInvocation): """Add blur to NSFW-flagged images""" @@ -517,7 +517,9 @@ class ImageNSFWBlurInvocation(BaseInvocation): return caution.resize((caution.width // 2, caution.height // 2)) -@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image") +@invocation( + "img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0" +) class ImageWatermarkInvocation(BaseInvocation): """Add an invisible watermark to an image""" @@ -548,7 +550,7 @@ class ImageWatermarkInvocation(BaseInvocation): ) -@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image") +@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0") class MaskEdgeInvocation(BaseInvocation): """Applies an edge mask to an image""" @@ -593,7 +595,9 @@ class MaskEdgeInvocation(BaseInvocation): ) -@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image") +@invocation( + "mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0" +) class MaskCombineInvocation(BaseInvocation): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" @@ -623,7 +627,7 @@ class MaskCombineInvocation(BaseInvocation): ) -@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image") +@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0") class ColorCorrectInvocation(BaseInvocation): """ Shifts the colors of a target image to match the reference image, optionally @@ -728,7 +732,7 @@ class ColorCorrectInvocation(BaseInvocation): ) -@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image") +@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0") class ImageHueAdjustmentInvocation(BaseInvocation): """Adjusts the Hue of an image.""" @@ -774,6 +778,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation): title="Adjust Image Luminosity", tags=["image", "luminosity", "hsl"], category="image", + version="1.0.0", ) class ImageLuminosityAdjustmentInvocation(BaseInvocation): """Adjusts the Luminosity (Value) of an image.""" @@ -826,6 +831,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation): title="Adjust Image Saturation", tags=["image", "saturation", "hsl"], category="image", + version="1.0.0", ) class ImageSaturationAdjustmentInvocation(BaseInvocation): """Adjusts the Saturation of an image.""" diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 438c56e312..fa322e7864 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -116,7 +116,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si -@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint") +@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0") class InfillColorInvocation(BaseInvocation): """Infills transparent areas of an image with a solid color""" @@ -151,7 +151,7 @@ class InfillColorInvocation(BaseInvocation): ) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0") class InfillTileInvocation(BaseInvocation): """Infills transparent areas of an image with tiles of the image""" @@ -187,7 +187,9 @@ class InfillTileInvocation(BaseInvocation): ) -@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint") +@invocation( + "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0" +) class InfillPatchMatchInvocation(BaseInvocation): """Infills transparent areas of an image using the PatchMatch algorithm""" @@ -218,7 +220,7 @@ class InfillPatchMatchInvocation(BaseInvocation): ) -@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint") +@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0") class LaMaInfillInvocation(BaseInvocation): """Infills transparent areas of an image using the LaMa model""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c0e53e4e12..8fde088b36 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -74,7 +74,7 @@ class SchedulerOutput(BaseInvocationOutput): scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) -@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents") +@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0") class SchedulerInvocation(BaseInvocation): """Selects a scheduler.""" @@ -86,7 +86,9 @@ class SchedulerInvocation(BaseInvocation): return SchedulerOutput(scheduler=self.scheduler) -@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents") +@invocation( + "create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0" +) class CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -186,6 +188,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", + version="1.0.0", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -544,7 +547,9 @@ class DenoiseLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=result_latents, seed=seed) -@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents") +@invocation( + "l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0" +) class LatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" @@ -641,7 +646,7 @@ class LatentsToImageInvocation(BaseInvocation): LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] -@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents") +@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0") class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" @@ -685,7 +690,7 @@ class ResizeLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) -@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents") +@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0") class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" @@ -721,7 +726,9 @@ class ScaleLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) -@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents") +@invocation( + "i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0" +) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -801,7 +808,7 @@ class ImageToLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=latents, seed=None) -@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents") +@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0") class BlendLatentsInvocation(BaseInvocation): """Blend two latents using a given alpha. Latents must have same size.""" diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 2a8dc12b28..0bc8b7b950 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerOutput from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation -@invocation("add", title="Add Integers", tags=["math", "add"], category="math") +@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") class AddInvocation(BaseInvocation): """Adds two numbers""" @@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation): return IntegerOutput(value=self.a + self.b) -@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math") +@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0") class SubtractInvocation(BaseInvocation): """Subtracts two numbers""" @@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation): return IntegerOutput(value=self.a - self.b) -@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math") +@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0") class MultiplyInvocation(BaseInvocation): """Multiplies two numbers""" @@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation): return IntegerOutput(value=self.a * self.b) -@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math") +@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0") class DivideInvocation(BaseInvocation): """Divides two numbers""" @@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation): return IntegerOutput(value=int(self.a / self.b)) -@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math") +@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0") class RandomIntInvocation(BaseInvocation): """Outputs a single random integer.""" diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 9c028a2dc1..39fa3beba0 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -98,7 +98,9 @@ class MetadataAccumulatorOutput(BaseInvocationOutput): metadata: CoreMetadata = OutputField(description="The core metadata for the image") -@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata") +@invocation( + "metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0" +) class MetadataAccumulatorInvocation(BaseInvocation): """Outputs a Core Metadata Object""" diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 5a1073df0a..571cb2e730 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -73,7 +73,7 @@ class LoRAModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") -@invocation("main_model_loader", title="Main Model", tags=["model"], category="model") +@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0") class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -173,7 +173,7 @@ class LoraLoaderOutput(BaseInvocationOutput): clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -244,7 +244,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2") -@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model") +@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0") class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -338,7 +338,7 @@ class VaeLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -376,7 +376,7 @@ class SeamlessModeOutput(BaseInvocationOutput): vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model") +@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0") class SeamlessModeInvocation(BaseInvocation): """Applies the seamless transformation to the Model UNet and VAE.""" diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 1f1d9fe3ce..c46747aa89 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -78,7 +78,7 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): ) -@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents") +@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0") class NoiseInvocation(BaseInvocation): """Generates latent noise.""" diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index b61ea2da99..d346a5f58f 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -56,7 +56,7 @@ ORT_TO_NP_TYPE = { PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] -@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning") +@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @@ -143,6 +143,7 @@ class ONNXPromptInvocation(BaseInvocation): title="ONNX Text to Latents", tags=["latents", "inference", "txt2img", "onnx"], category="latents", + version="1.0.0", ) class ONNXTextToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" @@ -319,6 +320,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation): title="ONNX Latents to Image", tags=["latents", "image", "vae", "onnx"], category="image", + version="1.0.0", ) class ONNXLatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" @@ -403,7 +405,7 @@ class OnnxModelField(BaseModel): model_type: ModelType = Field(description="Model Type") -@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model") +@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") class OnnxModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index 1b3c0dc09e..9cfe447372 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -45,7 +45,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation -@invocation("float_range", title="Float Range", tags=["math", "range"], category="math") +@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0") class FloatLinearRangeInvocation(BaseInvocation): """Creates a range""" @@ -96,7 +96,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] # actually I think for now could just use CollectionOutput (which is list[Any] -@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step") +@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0") class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index fdadc4b31b..93cf29f7d6 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -44,7 +44,9 @@ class BooleanCollectionOutput(BaseInvocationOutput): ) -@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives") +@invocation( + "boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0" +) class BooleanInvocation(BaseInvocation): """A boolean primitive value""" @@ -59,6 +61,7 @@ class BooleanInvocation(BaseInvocation): title="Boolean Collection Primitive", tags=["primitives", "boolean", "collection"], category="primitives", + version="1.0.0", ) class BooleanCollectionInvocation(BaseInvocation): """A collection of boolean primitive values""" @@ -90,7 +93,9 @@ class IntegerCollectionOutput(BaseInvocationOutput): ) -@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives") +@invocation( + "integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0" +) class IntegerInvocation(BaseInvocation): """An integer primitive value""" @@ -105,6 +110,7 @@ class IntegerInvocation(BaseInvocation): title="Integer Collection Primitive", tags=["primitives", "integer", "collection"], category="primitives", + version="1.0.0", ) class IntegerCollectionInvocation(BaseInvocation): """A collection of integer primitive values""" @@ -136,7 +142,7 @@ class FloatCollectionOutput(BaseInvocationOutput): ) -@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives") +@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0") class FloatInvocation(BaseInvocation): """A float primitive value""" @@ -151,6 +157,7 @@ class FloatInvocation(BaseInvocation): title="Float Collection Primitive", tags=["primitives", "float", "collection"], category="primitives", + version="1.0.0", ) class FloatCollectionInvocation(BaseInvocation): """A collection of float primitive values""" @@ -182,7 +189,7 @@ class StringCollectionOutput(BaseInvocationOutput): ) -@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives") +@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0") class StringInvocation(BaseInvocation): """A string primitive value""" @@ -197,6 +204,7 @@ class StringInvocation(BaseInvocation): title="String Collection Primitive", tags=["primitives", "string", "collection"], category="primitives", + version="1.0.0", ) class StringCollectionInvocation(BaseInvocation): """A collection of string primitive values""" @@ -236,7 +244,7 @@ class ImageCollectionOutput(BaseInvocationOutput): ) -@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives") +@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0") class ImageInvocation(BaseInvocation): """An image primitive value""" @@ -257,6 +265,7 @@ class ImageInvocation(BaseInvocation): title="Image Collection Primitive", tags=["primitives", "image", "collection"], category="primitives", + version="1.0.0", ) class ImageCollectionInvocation(BaseInvocation): """A collection of image primitive values""" @@ -318,7 +327,9 @@ class LatentsCollectionOutput(BaseInvocationOutput): ) -@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives") +@invocation( + "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0" +) class LatentsInvocation(BaseInvocation): """A latents tensor primitive value""" @@ -335,6 +346,7 @@ class LatentsInvocation(BaseInvocation): title="Latents Collection Primitive", tags=["primitives", "latents", "collection"], category="primitives", + version="1.0.0", ) class LatentsCollectionInvocation(BaseInvocation): """A collection of latents tensor primitive values""" @@ -388,7 +400,7 @@ class ColorCollectionOutput(BaseInvocationOutput): ) -@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives") +@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0") class ColorInvocation(BaseInvocation): """A color primitive value""" @@ -430,6 +442,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput): title="Conditioning Primitive", tags=["primitives", "conditioning"], category="primitives", + version="1.0.0", ) class ConditioningInvocation(BaseInvocation): """A conditioning tensor primitive value""" @@ -445,6 +458,7 @@ class ConditioningInvocation(BaseInvocation): title="Conditioning Collection Primitive", tags=["primitives", "conditioning", "collection"], category="primitives", + version="1.0.0", ) class ConditioningCollectionInvocation(BaseInvocation): """A collection of conditioning tensor primitive values""" diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index c42deeaa2c..69ce1dba49 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -10,7 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation -@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt") +@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0") class DynamicPromptInvocation(BaseInvocation): """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" @@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation): return StringCollectionOutput(collection=prompts) -@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt") +@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0") class PromptsFromFileInvocation(BaseInvocation): """Loads prompts from a text file""" diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 288858a173..de4ea604b4 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -33,7 +33,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -119,6 +119,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", + version="1.0.0", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index e9fb3f9963..7dca6d9f21 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -23,7 +23,7 @@ ESRGAN_MODELS = Literal[ ] -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0") class ESRGANInvocation(BaseInvocation): """Upscales an image using RealESRGAN.""" diff --git a/pyproject.toml b/pyproject.toml index 129538264d..4b06944b33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dependencies = [ "rich~=13.3", "safetensors==0.3.1", "scikit-image~=0.21.0", + "semver~=3.0.1", "send2trash", "test-tube~=0.7.5", "torch~=2.0.1", From 4aca264308c382df515100fba2edcfc4d2dae79b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 18:42:10 +1000 Subject: [PATCH 11/13] feat(ui): handle node versions - Node versions are now added to node templates - Node data (including in workflows) include the version of the node - On loading a workflow, we check to see if the node and template versions match exactly. If not, a warning is logged to console. - The node info icon (top-right corner of node, which you may click to open the notes editor) now shows the version and mentions any issues. - Some workflow validation logic has been shifted around and is now executed in a redux listener. --- invokeai/frontend/web/package.json | 1 + .../middleware/listenerMiddleware/index.ts | 4 + .../listeners/workflowLoaded.ts | 55 +++++++++++ .../CurrentImage/CurrentImageButtons.tsx | 16 +--- .../SingleSelectionMenuItems.tsx | 15 +-- .../features/nodes/components/flow/Flow.tsx | 8 ++ .../nodes/Invocation/InvocationNodeNotes.tsx | 54 ++++++++++- .../features/nodes/hooks/useBuildNodeData.ts | 7 +- .../nodes/hooks/useDoNodeVersionsMatch.ts | 33 +++++++ .../nodes/hooks/useLoadWorkflowFromFile.tsx | 32 +------ .../web/src/features/nodes/store/actions.ts | 5 + .../features/nodes/store/reactFlowInstance.ts | 4 + .../web/src/features/nodes/types/types.ts | 43 ++++++--- .../src/features/nodes/util/parseSchema.ts | 4 +- .../features/nodes/util/validateWorkflow.ts | 96 +++++++++++++++++++ .../frontend/web/src/services/api/schema.d.ts | 29 +++--- invokeai/frontend/web/yarn.lock | 5 + 17 files changed, 324 insertions(+), 87 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/reactFlowInstance.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index cc1e17cf51..9a45dd89a5 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -75,6 +75,7 @@ "@reduxjs/toolkit": "^1.9.5", "@roarr/browser-log-writer": "^1.1.5", "@stevebel/png": "^1.5.1", + "compare-versions": "^6.1.0", "dateformat": "^5.0.3", "formik": "^2.4.3", "framer-motion": "^10.16.1", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 4afe023fbb..261edba0af 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -84,6 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; +import { addWorkflowLoadedListener } from './listeners/workflowLoaded'; export const listenerMiddleware = createListenerMiddleware(); @@ -202,6 +203,9 @@ addBoardIdSelectedListener(); // Node schemas addReceivedOpenAPISchemaListener(); +// Workflows +addWorkflowLoadedListener(); + // DND addImageDroppedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts new file mode 100644 index 0000000000..c447720941 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts @@ -0,0 +1,55 @@ +import { logger } from 'app/logging/logger'; +import { workflowLoadRequested } from 'features/nodes/store/actions'; +import { workflowLoaded } from 'features/nodes/store/nodesSlice'; +import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { validateWorkflow } from 'features/nodes/util/validateWorkflow'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { startAppListening } from '..'; + +export const addWorkflowLoadedListener = () => { + startAppListening({ + actionCreator: workflowLoadRequested, + effect: (action, { dispatch, getState }) => { + const log = logger('nodes'); + const workflow = action.payload; + const nodeTemplates = getState().nodes.nodeTemplates; + + const { workflow: validatedWorkflow, errors } = validateWorkflow( + workflow, + nodeTemplates + ); + + dispatch(workflowLoaded(validatedWorkflow)); + + if (!errors.length) { + dispatch( + addToast( + makeToast({ + title: 'Workflow Loaded', + status: 'success', + }) + ) + ); + } else { + dispatch( + addToast( + makeToast({ + title: 'Workflow Loaded with Warnings', + status: 'warning', + }) + ) + ); + errors.forEach(({ message, ...rest }) => { + log.warn(rest, message); + }); + } + + dispatch(setActiveTab('nodes')); + requestAnimationFrame(() => { + $flow.get()?.fitView(); + }); + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx index 3559679fc4..846cf5a6f0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx @@ -17,16 +17,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; -import { workflowLoaded } from 'features/nodes/store/nodesSlice'; +import { workflowLoadRequested } from 'features/nodes/store/actions'; import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { initialImageSelected } from 'features/parameters/store/actions'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { - setActiveTab, setShouldShowImageDetails, setShouldShowProgressInViewer, } from 'features/ui/store/uiSlice'; @@ -124,16 +121,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { if (!workflow) { return; } - dispatch(workflowLoaded(workflow)); - dispatch(setActiveTab('nodes')); - dispatch( - addToast( - makeToast({ - title: 'Workflow Loaded', - status: 'success', - }) - ) - ); + dispatch(workflowLoadRequested(workflow)); }, [dispatch, workflow]); const handleClickUseAllParameters = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index e75a7745bb..90272a3a86 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -7,12 +7,9 @@ import { isModalOpenChanged, } from 'features/changeBoardModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; -import { workflowLoaded } from 'features/nodes/store/nodesSlice'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { initialImageSelected } from 'features/parameters/store/actions'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { memo, useCallback } from 'react'; @@ -36,6 +33,7 @@ import { } from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; +import { workflowLoadRequested } from 'features/nodes/store/actions'; type SingleSelectionMenuItemsProps = { imageDTO: ImageDTO; @@ -102,16 +100,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { if (!workflow) { return; } - dispatch(workflowLoaded(workflow)); - dispatch(setActiveTab('nodes')); - dispatch( - addToast( - makeToast({ - title: 'Workflow Loaded', - status: 'success', - }) - ) - ); + dispatch(workflowLoadRequested(workflow)); }, [dispatch, workflow]); const handleSendToImageToImage = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index e8fb66d074..16af1fe12c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { $flow } from 'features/nodes/store/reactFlowInstance'; import { contextMenusClosed } from 'features/ui/store/uiSlice'; import { useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; @@ -13,6 +14,7 @@ import { OnConnectStart, OnEdgesChange, OnEdgesDelete, + OnInit, OnMoveEnd, OnNodesChange, OnNodesDelete, @@ -147,6 +149,11 @@ export const Flow = () => { dispatch(contextMenusClosed()); }, [dispatch]); + const onInit: OnInit = useCallback((flow) => { + $flow.set(flow); + flow.fitView(); + }, []); + useHotkeys(['Ctrl+c', 'Meta+c'], (e) => { e.preventDefault(); dispatch(selectionCopied()); @@ -170,6 +177,7 @@ export const Flow = () => { edgeTypes={edgeTypes} nodes={nodes} edges={edges} + onInit={onInit} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} onEdgesDelete={onEdgesDelete} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx index eae05688b5..143785ecfe 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx @@ -12,6 +12,7 @@ import { Tooltip, useDisclosure, } from '@chakra-ui/react'; +import { compare } from 'compare-versions'; import { useNodeData } from 'features/nodes/hooks/useNodeData'; import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; @@ -20,6 +21,7 @@ import { isInvocationNodeData } from 'features/nodes/types/types'; import { memo, useMemo } from 'react'; import { FaInfoCircle } from 'react-icons/fa'; import NotesTextarea from './NotesTextarea'; +import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch'; interface Props { nodeId: string; @@ -29,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => { const { isOpen, onOpen, onClose } = useDisclosure(); const label = useNodeLabel(nodeId); const title = useNodeTemplateTitle(nodeId); + const doVersionsMatch = useDoNodeVersionsMatch(nodeId); return ( <> @@ -50,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => { > @@ -92,16 +99,59 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => { return 'Unknown Node'; }, [data, nodeTemplate]); + const versionComponent = useMemo(() => { + if (!isInvocationNodeData(data) || !nodeTemplate) { + return null; + } + + if (!data.version) { + return ( + + Version unknown + + ); + } + + if (!nodeTemplate.version) { + return ( + + Version {data.version} (unknown template) + + ); + } + + if (compare(data.version, nodeTemplate.version, '<')) { + return ( + + Version {data.version} (update node) + + ); + } + + if (compare(data.version, nodeTemplate.version, '>')) { + return ( + + Version {data.version} (update app) + + ); + } + + return Version {data.version}; + }, [data, nodeTemplate]); + if (!isInvocationNodeData(data)) { return Unknown Node; } return ( - {title} + + {title} + {nodeTemplate?.description} + {versionComponent} {data?.notes && {data.notes}} ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts index a88a82e1fc..24982f591e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts @@ -138,13 +138,14 @@ export const useBuildNodeData = () => { data: { id: nodeId, type, - inputs, - outputs, - isOpen: true, + version: template.version, label: '', notes: '', + isOpen: true, embedWorkflow: false, isIntermediate: true, + inputs, + outputs, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts new file mode 100644 index 0000000000..926c56ac1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -0,0 +1,33 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { compareVersions } from 'compare-versions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/types'; + +export const useDoNodeVersionsMatch = (nodeId: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return false; + } + const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + if (!nodeTemplate?.version || !node.data?.version) { + return false; + } + return compareVersions(nodeTemplate.version, node.data.version) === 0; + }, + defaultSelectorOptions + ), + [nodeId] + ); + + const nodeTemplate = useAppSelector(selector); + + return nodeTemplate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx b/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx index 97f2cea77b..7f015ac5eb 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx +++ b/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx @@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react'; import { useLogger } from 'app/logging/useLogger'; import { useAppDispatch } from 'app/store/storeHooks'; import { parseify } from 'common/util/serialize'; -import { workflowLoaded } from 'features/nodes/store/nodesSlice'; -import { zValidatedWorkflow } from 'features/nodes/types/types'; +import { zWorkflow } from 'features/nodes/types/types'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback } from 'react'; import { ZodError } from 'zod'; import { fromZodError, fromZodIssue } from 'zod-validation-error'; +import { workflowLoadRequested } from '../store/actions'; export const useLoadWorkflowFromFile = () => { const dispatch = useAppDispatch(); @@ -24,7 +24,7 @@ export const useLoadWorkflowFromFile = () => { try { const parsedJSON = JSON.parse(String(rawJSON)); - const result = zValidatedWorkflow.safeParse(parsedJSON); + const result = zWorkflow.safeParse(parsedJSON); if (!result.success) { const { message } = fromZodError(result.error, { @@ -45,32 +45,8 @@ export const useLoadWorkflowFromFile = () => { reader.abort(); return; } - dispatch(workflowLoaded(result.data.workflow)); - if (!result.data.warnings.length) { - dispatch( - addToast( - makeToast({ - title: 'Workflow Loaded', - status: 'success', - }) - ) - ); - reader.abort(); - return; - } - - dispatch( - addToast( - makeToast({ - title: 'Workflow Loaded with Warnings', - status: 'warning', - }) - ) - ); - result.data.warnings.forEach(({ message, ...rest }) => { - logger.warn(rest, message); - }); + dispatch(workflowLoadRequested(result.data)); reader.abort(); } catch { diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 2463a1e945..cf7ccf8238 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,5 +1,6 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; import { Graph } from 'services/api/types'; +import { Workflow } from '../types/types'; export const textToImageGraphBuilt = createAction( 'nodes/textToImageGraphBuilt' @@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf( canvasGraphBuilt, nodesGraphBuilt ); + +export const workflowLoadRequested = createAction( + 'nodes/workflowLoadRequested' +); diff --git a/invokeai/frontend/web/src/features/nodes/store/reactFlowInstance.ts b/invokeai/frontend/web/src/features/nodes/store/reactFlowInstance.ts new file mode 100644 index 0000000000..e9094a9310 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/reactFlowInstance.ts @@ -0,0 +1,4 @@ +import { atom } from 'nanostores'; +import { ReactFlowInstance } from 'reactflow'; + +export const $flow = atom(null); diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index f7986a5028..402ef4ac7a 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -52,6 +52,10 @@ export type InvocationTemplate = { * The type of this node's output */ outputType: string; // TODO: generate a union of output types + /** + * The invocation's version. + */ + version?: string; }; export type FieldUIConfig = { @@ -962,6 +966,7 @@ export type InvocationSchemaExtra = { title: string; category?: string; tags?: string[]; + version?: string; properties: Omit< NonNullable & (_InputField | _OutputField), @@ -1095,6 +1100,29 @@ export const zCoreMetadata = z export type CoreMetadata = z.infer; +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + major !== undefined && + Number.isInteger(Number(major)) && + minor !== undefined && + Number.isInteger(Number(minor)) && + patch !== undefined && + Number.isInteger(Number(patch)) + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); + +export type SemVer = z.infer; + export const zInvocationNodeData = z.object({ id: z.string().trim().min(1), // no easy way to build this dynamically, and we don't want to anyways, because this will be used @@ -1107,6 +1135,7 @@ export const zInvocationNodeData = z.object({ notes: z.string(), embedWorkflow: z.boolean(), isIntermediate: z.boolean(), + version: zSemVer.optional(), }); // Massage this to get better type safety while developing @@ -1195,20 +1224,6 @@ export const zFieldIdentifier = z.object({ export type FieldIdentifier = z.infer; -export const zSemVer = z.string().refine((val) => { - const [major, minor, patch] = val.split('.'); - return ( - major !== undefined && - minor !== undefined && - patch !== undefined && - Number.isInteger(Number(major)) && - Number.isInteger(Number(minor)) && - Number.isInteger(Number(patch)) - ); -}); - -export type SemVer = z.infer; - export type WorkflowWarning = { message: string; issues: string[]; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 78e7495481..d8bb189abc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -73,6 +73,7 @@ export const parseSchema = ( const title = schema.title.replace('Invocation', ''); const tags = schema.tags ?? []; const description = schema.description ?? ''; + const version = schema.version ?? ''; const inputs = reduce( schema.properties, @@ -225,11 +226,12 @@ export const parseSchema = ( const invocation: InvocationTemplate = { title, type, + version, tags, description, + outputType, inputs, outputs, - outputType, }; Object.assign(invocationsAccumulator, { [type]: invocation }); diff --git a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts new file mode 100644 index 0000000000..a3085d516b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts @@ -0,0 +1,96 @@ +import { compareVersions } from 'compare-versions'; +import { cloneDeep, keyBy } from 'lodash-es'; +import { + InvocationTemplate, + Workflow, + WorkflowWarning, + isWorkflowInvocationNode, +} from '../types/types'; +import { parseify } from 'common/util/serialize'; + +export const validateWorkflow = ( + workflow: Workflow, + nodeTemplates: Record +) => { + const clone = cloneDeep(workflow); + const { nodes, edges } = clone; + const errors: WorkflowWarning[] = []; + const invocationNodes = nodes.filter(isWorkflowInvocationNode); + const keyedNodes = keyBy(invocationNodes, 'id'); + nodes.forEach((node) => { + if (!isWorkflowInvocationNode(node)) { + return; + } + + const nodeTemplate = nodeTemplates[node.data.type]; + if (!nodeTemplate) { + errors.push({ + message: `Node "${node.data.type}" skipped`, + issues: [`Node type "${node.data.type}" does not exist`], + data: node, + }); + return; + } + + if ( + nodeTemplate.version && + node.data.version && + compareVersions(nodeTemplate.version, node.data.version) !== 0 + ) { + errors.push({ + message: `Node "${node.data.type}" has mismatched version`, + issues: [ + `Node "${node.data.type}" v${node.data.version} may be incompatible with installed v${nodeTemplate.version}`, + ], + data: { node, nodeTemplate: parseify(nodeTemplate) }, + }); + return; + } + }); + edges.forEach((edge, i) => { + const sourceNode = keyedNodes[edge.source]; + const targetNode = keyedNodes[edge.target]; + const issues: string[] = []; + if (!sourceNode) { + issues.push(`Output node ${edge.source} does not exist`); + } else if ( + edge.type === 'default' && + !(edge.sourceHandle in sourceNode.data.outputs) + ) { + issues.push( + `Output field "${edge.source}.${edge.sourceHandle}" does not exist` + ); + } + if (!targetNode) { + issues.push(`Input node ${edge.target} does not exist`); + } else if ( + edge.type === 'default' && + !(edge.targetHandle in targetNode.data.inputs) + ) { + issues.push( + `Input field "${edge.target}.${edge.targetHandle}" does not exist` + ); + } + if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) { + issues.push( + `Source node "${edge.source}" missing template "${sourceNode?.data.type}"` + ); + } + if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) { + issues.push( + `Source node "${edge.target}" missing template "${targetNode?.data.type}"` + ); + } + if (issues.length) { + delete edges[i]; + const src = edge.type === 'default' ? edge.sourceHandle : edge.source; + const tgt = edge.type === 'default' ? edge.targetHandle : edge.target; + errors.push({ + message: `Edge "${src} -> ${tgt}" skipped`, + issues, + data: edge, + }); + } + }); + return { workflow: clone, errors }; +}; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index f48892113d..6e00d1b38b 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -6981,6 +6981,11 @@ export type components = { * @description The node's category */ category?: string; + /** + * Version + * @description The node's version. Should be a valid semver string e.g. "1.0.0" or "3.8.13". + */ + version?: string; }; /** * Input @@ -7036,24 +7041,12 @@ export type components = { /** Ui Order */ ui_order?: number; }; - /** - * StableDiffusionOnnxModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -7066,6 +7059,18 @@ export type components = { * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + /** + * ControlNetModelFormat + * @description An enumeration. + * @enum {string} + */ + ControlNetModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusionOnnxModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionOnnxModelFormat: "olive" | "onnx"; }; responses: never; parameters: never; diff --git a/invokeai/frontend/web/yarn.lock b/invokeai/frontend/web/yarn.lock index 2c3c9ae88f..787c81a756 100644 --- a/invokeai/frontend/web/yarn.lock +++ b/invokeai/frontend/web/yarn.lock @@ -2970,6 +2970,11 @@ commondir@^1.0.1: resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b" integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg== +compare-versions@^6.1.0: + version "6.1.0" + resolved "https://registry.yarnpkg.com/compare-versions/-/compare-versions-6.1.0.tgz#3f2131e3ae93577df111dba133e6db876ffe127a" + integrity sha512-LNZQXhqUvqUTotpZ00qLSaify3b4VFD588aRr8MKFw4CMUr98ytzCW5wDH5qx/DEY5kCDXcbcRuCqL0szEf2tg== + compute-scroll-into-view@1.0.20: version "1.0.20" resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.20.tgz#1768b5522d1172754f5d0c9b02de3af6be506a43" From d6317bc53f1487b6a2a6838177f0715871cd6227 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 18:45:58 +1000 Subject: [PATCH 12/13] docs: update INVOCATIONS.md with version info --- docs/contributing/INVOCATIONS.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index b34a2f25ac..be5d9d0805 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -244,8 +244,12 @@ copy-paste the template above. We can use the `@invocation` decorator to provide some additional info to the UI, like a custom title, tags and category. +We also encourage providing a version. This must be a +[semver](https://semver.org/) version string ("$MAJOR.$MINOR.$PATCH"). The UI +will let users know if their workflow is using a mismatched version of the node. + ```python -@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations") +@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations", version="1.0.0") class ResizeInvocation(BaseInvocation): """Resizes an image""" @@ -279,8 +283,6 @@ take a look a at our [contributing nodes overview](contributingNodes). ## Advanced ---> - ### Custom Output Types Like with custom inputs, sometimes you might find yourself needing custom From 3dbb0e1bfb6c88388f962c7b58587ee04d59c74c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:16:44 +1000 Subject: [PATCH 13/13] feat(tests): add tests for node versions --- invokeai/app/invocations/baseinvocation.py | 9 +++++- tests/nodes/test_node_graph.py | 36 +++++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 540571762f..65a8734690 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -32,6 +32,10 @@ if TYPE_CHECKING: from ..services.invocation_services import InvocationServices +class InvalidVersionError(ValueError): + pass + + class FieldDescriptions: denoising_start = "When to start denoising, expressed a percentage of total steps" denoising_end = "When to stop denoising, expressed a percentage of total steps" @@ -605,7 +609,10 @@ def invocation( if category is not None: cls.UIConfig.category = category if version is not None: - semver.Version.parse(version) # raises ValueError if invalid semver + try: + semver.Version.parse(version) + except ValueError as e: + raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e cls.UIConfig.version = version # Add the invocation type to the pydantic model of the invocation diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index 56bf823d14..0e1be8f343 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,4 +1,10 @@ -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvalidVersionError, + invocation, + invocation_output, +) from .test_nodes import ( ImageToImageTestInvocation, TextToImageTestInvocation, @@ -616,18 +622,38 @@ def test_invocation_decorator(): title = "Test Invocation" tags = ["first", "second", "third"] category = "category" + version = "1.2.3" - @invocation(invocation_type, title=title, tags=tags, category=category) - class Test(BaseInvocation): + @invocation(invocation_type, title=title, tags=tags, category=category, version=version) + class TestInvocation(BaseInvocation): def invoke(self): pass - schema = Test.schema() + schema = TestInvocation.schema() assert schema.get("title") == title assert schema.get("tags") == tags assert schema.get("category") == category - assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added) + assert schema.get("version") == version + assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added) + + +def test_invocation_version_must_be_semver(): + invocation_type = "test_invocation" + valid_version = "1.0.0" + invalid_version = "not_semver" + + @invocation(invocation_type, version=valid_version) + class ValidVersionInvocation(BaseInvocation): + def invoke(self): + pass + + with pytest.raises(InvalidVersionError): + + @invocation(invocation_type, version=invalid_version) + class InvalidVersionInvocation(BaseInvocation): + def invoke(self): + pass def test_invocation_output_decorator():