Compare commits

...

14 Commits

Author SHA1 Message Date
b65acc0137 chore(ui): add TODO for revising field type names 2023-11-20 10:29:04 +11:00
0e640adc2c feat(nodes): add runtime check for custom field names
"Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names.
2023-11-20 10:28:38 +11:00
f280a2ecbd fix(ui): fix ts error 2023-11-19 18:09:32 +11:00
e047d43111 fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType'
This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type.
2023-11-19 18:05:25 +11:00
57567d4fc3 chore(ui): remove errant console.log 2023-11-19 17:45:47 +11:00
9ebffcd26b feat(ui): add validation for CustomCollection & CustomPolymorphic types
- Update connection validation for custom types
- Use simple string parsing to determine if a field is a collection or polymorphic type.
- No longer need to keep a list of collection and polymorphic types.
- Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing
2023-11-19 17:44:42 +11:00
e30f22ae7e feat(ui): add CustomCollection and CustomPolymorphic field types 2023-11-19 14:30:08 +11:00
3ff13dc93c fix(ui): typo 2023-11-19 12:47:05 +11:00
5e4b0932fd Merge remote-tracking branch 'origin/main' into feat/arbitrary-field-types 2023-11-18 10:41:36 +11:00
98a0ce0f42 feat(ui): custom field types connection validation
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.

*Actually, it was `"Unknown"`, but I changed it to custom for clarity.

Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.

To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.

This ended up needing a bit of fanagling:

- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.

While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.

(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)

- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.

- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.

Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.

This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
2023-11-18 10:40:19 +11:00
7b93b5e928 Merge branch 'main' into feat/arbitrary-field-types 2023-11-17 15:00:24 +11:00
dc44debbab fix(ui): fix ts error with custom fields 2023-11-17 12:09:15 +11:00
5ce2dc3a58 feat(ui): fix tooltips for custom types
We need to hold onto the original type of the field so they don't all just show up as "Unknown".
2023-11-17 12:01:39 +11:00
27fd9071ba feat(ui): add support for custom field types
Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported.

Two notes:
1. Your field type's class name must be unique.

Suggest prefixing fields with something related to the node pack as a kind of namespace.

2. Custom field types function as connection-only fields.

For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type.

This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection.
2023-11-17 11:32:35 +11:00
26 changed files with 425 additions and 178 deletions

View File

@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, ForwardRef, Iterable, Literal, Optional, Type, TypeVar, Union
import semver
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
@ -648,17 +648,40 @@ class _Model(BaseModel):
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
RESERVED_INVOKEAI_FIELD_NAMES = {"Custom", "CustomCollection", "CustomPolymorphic"}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
"""
Validates the fields of an invocation or invocation output:
- must not override any pydantic reserved fields
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
if not field.annotation:
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
annotation_name = (
field.annotation.__forward_arg__ if isinstance(field.annotation, ForwardRef) else field.annotation.__name__
)
if annotation_name.endswith("Polymorphic"):
raise InvalidFieldError(
f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Polymorphic")'
)
if annotation_name.endswith("Collection"):
raise InvalidFieldError(
f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Collection")'
)
if annotation_name in RESERVED_INVOKEAI_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (reserved)')
field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None

View File

@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
return null;
}
console.log(metadata);
return (
<>
{metadata.created_by && (

View File

@ -57,7 +57,7 @@ const AddNodePopover = () => {
const { t } = useTranslation();
const fieldFilter = useAppSelector(
(state) => state.nodes.currentConnectionFieldType
(state) => state.nodes.connectionStartFieldType
);
const handleFilter = useAppSelector(
(state) => state.nodes.connectionStartParams?.handleType
@ -74,9 +74,13 @@ const AddNodePopover = () => {
return some(handles, (handle) => {
const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type;
handleFilter == 'source'
? fieldFilter
: handle.originalType ?? handle.type;
const targetType =
handleFilter == 'target' ? fieldFilter : handle.type;
handleFilter == 'target'
? fieldFilter
: handle.originalType ?? handle.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
@ -111,7 +115,7 @@ const AddNodePopover = () => {
data.sort((a, b) => a.label.localeCompare(b.label));
return { data, t };
return { data };
},
defaultSelectorOptions
);

View File

@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { memo } from 'react';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { getFieldColor } from '../edges/util/getEdgeColor';
const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
nodes;
const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
: colorTokenToCssVar('base.500');
const stroke = shouldColorEdges
? getFieldColor(connectionStartFieldType)
: colorTokenToCssVar('base.500');
let className = 'react-flow__custom_connection-path';

View File

@ -0,0 +1,12 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
export const getFieldColor = (fieldType: FieldType | string | null): string => {
if (!fieldType) {
return colorTokenToCssVar('base.500');
}
const color = FIELDS[fieldType]?.color;
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { isInvocationNode } from 'features/nodes/types/types';
import { getFieldColor } from './getEdgeColor';
export const makeEdgeSelector = (
source: string,
@ -29,7 +29,7 @@ export const makeEdgeSelector = (
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
? getFieldColor(sourceType)
: colorTokenToCssVar('base.500');
return {

View File

@ -1,11 +1,12 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import {
COLLECTION_TYPES,
FIELDS,
getIsCollection,
getIsPolymorphic,
} from 'features/nodes/store/util/parseFieldType';
import {
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import {
InputFieldTemplate,
@ -13,6 +14,7 @@ import {
} from 'features/nodes/types/types';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { getFieldColor } from '../../../edges/util/getEdgeColor';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
@ -47,23 +49,21 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionStartField,
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color: typeColor, title } = FIELDS[type];
const { name } = fieldTemplate;
const type = fieldTemplate.originalType ?? fieldTemplate.type;
const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const isCollection = getIsCollection(fieldTemplate.type);
const isPolymorphic = getIsPolymorphic(fieldTemplate.type);
const isModelType = MODEL_TYPES.some((t) => t === type);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
: color,
isCollection || isPolymorphic ? colorTokenToCssVar('base.900') : color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderWidth: isCollection || isPolymorphic ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
@ -93,22 +93,19 @@ const FieldHandle = (props: FieldHandleProps) => {
return s;
}, [
connectionError,
fieldTemplate.type,
handleType,
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return title;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? title;
return connectionError;
}
return title;
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
return type;
}, [connectionError, isConnectionInProgress, type]);
return (
<Tooltip

View File

@ -1,7 +1,6 @@
import { Flex, Text } from '@chakra-ui/react';
import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { FIELDS } from 'features/nodes/types/constants';
import {
isInputFieldTemplate,
isInputFieldValue,
@ -9,7 +8,6 @@ import {
import { startCase } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
interface Props {
nodeId: string;
fieldName: string;
@ -49,7 +47,7 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
{fieldTemplate.description}
</Text>
)}
{fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
{fieldTemplate && <Text>Type: {fieldTemplate.originalType}</Text>}
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex>
);

View File

@ -4,11 +4,9 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import { getIsPolymorphic } from '../store/util/parseFieldType';
import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
@ -28,7 +26,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(['any', 'direct'].includes(field.input) ||
POLYMORPHIC_TYPES.includes(field.type)) &&
getIsPolymorphic(field.type)) &&
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
);
return getSortedFilteredFieldNames(fields);

View File

@ -4,10 +4,8 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { getIsPolymorphic } from '../store/util/parseFieldType';
import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
@ -29,8 +27,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
// get the visible fields
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(field.input === 'connection' &&
!POLYMORPHIC_TYPES.includes(field.type)) ||
(field.input === 'connection' && !getIsPolymorphic(field.type)) ||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
);

View File

@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
stateSelector,
({ nodes }) =>
nodes.currentConnectionFieldType !== null &&
nodes.connectionStartFieldType !== null &&
nodes.connectionStartParams !== null
);

View File

@ -20,7 +20,8 @@ export const useFieldType = (
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
const field = node.data[KIND_MAP[kind]][fieldName];
return field?.originalType ?? field?.type;
},
defaultSelectorOptions
),

View File

@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
return false;
}
const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
const targetType = targetNode.data.inputs[targetHandle]?.type;
const sourceField = sourceNode.data.outputs[sourceHandle];
const targetField = targetNode.data.inputs[targetHandle];
if (!sourceType || !targetType) {
if (!sourceField || !targetField) {
// something has gone terribly awry
return false;
}
@ -70,12 +70,18 @@ export const useIsValidConnection = () => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType !== 'CollectionItem'
targetField.type !== 'CollectionItem'
) {
return false;
}
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
// Must use the originalType here if it exists
if (
!validateSourceAndTargetTypes(
sourceField?.originalType ?? sourceField.type,
targetField?.originalType ?? targetField.type
)
) {
return false;
}

View File

@ -6,7 +6,7 @@ import { NodesState } from './types';
export const nodesPersistDenylist: (keyof NodesState)[] = [
'nodeTemplates',
'connectionStartParams',
'currentConnectionFieldType',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'isReady',

View File

@ -93,7 +93,7 @@ export const initialNodesState: NodesState = {
nodeTemplates: {},
isReady: false,
connectionStartParams: null,
currentConnectionFieldType: null,
connectionStartFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
@ -203,7 +203,7 @@ const nodesSlice = createSlice({
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
state.connectionStartFieldType
) {
const newConnection = findConnectionToValidHandle(
node,
@ -212,7 +212,7 @@ const nodesSlice = createSlice({
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge(
@ -224,7 +224,7 @@ const nodesSlice = createSlice({
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
@ -258,10 +258,11 @@ const nodesSlice = createSlice({
handleType === 'source'
? node.data.outputs[handleId]
: node.data.inputs[handleId];
state.currentConnectionFieldType = field?.type ?? null;
state.connectionStartFieldType =
field?.originalType ?? field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.currentConnectionFieldType;
const fieldType = state.connectionStartFieldType;
if (!fieldType) {
return;
}
@ -286,7 +287,7 @@ const nodesSlice = createSlice({
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
state.connectionStartFieldType
) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
@ -295,7 +296,7 @@ const nodesSlice = createSlice({
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge(
@ -306,14 +307,14 @@ const nodesSlice = createSlice({
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
} else {
state.addNewNodePosition = action.payload.cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
}
state.modifyingEdge = false;
},
@ -942,7 +943,7 @@ const nodesSlice = createSlice({
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
},
addNodePopoverToggled: (state) => {
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;

View File

@ -21,7 +21,7 @@ export type NodesState = {
edges: Edge<InvocationEdgeExtra>[];
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
currentConnectionFieldType: FieldType | null;
connectionStartFieldType: FieldType | string | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean;

View File

@ -94,6 +94,7 @@ export const buildNodeData = (
name: outputName,
type: outputTemplate.type,
fieldKind: 'output',
originalType: outputTemplate.originalType,
};
outputsAccumulator[outputName] = outputFieldValue;

View File

@ -12,7 +12,7 @@ import { getIsGraphAcyclic } from './getIsGraphAcyclic';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
handleCurrentFieldType: FieldType | string,
node: Node,
handle: InputFieldValue | OutputFieldValue
) => {
@ -35,7 +35,12 @@ const isValidConnection = (
}
}
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
if (
!validateSourceAndTargetTypes(
handleCurrentFieldType,
handle.originalType ?? handle.type
)
) {
isValidConnection = false;
}
@ -49,7 +54,7 @@ export const findConnectionToValidHandle = (
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
handleCurrentFieldType: FieldType | string
): Connection | null => {
if (node.id === handleCurrentNodeId) {
return null;

View File

@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { FieldType } from 'features/nodes/types/types';
import i18n from 'i18next';
import { HandleType } from 'reactflow';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
/**
@ -15,17 +15,17 @@ export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType
fieldType?: FieldType | string
) => {
return createSelector(stateSelector, (state) => {
if (!fieldType) {
return i18n.t('nodes.noFieldType');
}
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
const { connectionStartFieldType, connectionStartParams, nodes, edges } =
state.nodes;
if (!connectionStartParams || !currentConnectionFieldType) {
if (!connectionStartParams || !connectionStartFieldType) {
return i18n.t('nodes.noConnectionInProgress');
}
@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = (
}
const targetType =
handleType === 'target' ? fieldType : currentConnectionFieldType;
handleType === 'target' ? fieldType : connectionStartFieldType;
const sourceType =
handleType === 'source' ? fieldType : currentConnectionFieldType;
handleType === 'source' ? fieldType : connectionStartFieldType;
if (nodeId === connectionNodeId) {
return i18n.t('nodes.cannotConnectToSelf');

View File

@ -0,0 +1,14 @@
import { FieldType } from 'features/nodes/types/types';
export const getIsPolymorphic = (type: FieldType | string): boolean =>
type.endsWith('Polymorphic');
export const getIsCollection = (type: FieldType | string): boolean =>
type.endsWith('Collection');
export const getBaseType = (type: FieldType | string): FieldType | string =>
getIsPolymorphic(type)
? type.replace(/Polymorphic$/, '')
: getIsCollection(type)
? type.replace(/Collection$/, '')
: type;

View File

@ -1,18 +1,32 @@
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
import {
getBaseType,
getIsCollection,
getIsPolymorphic,
} from './parseFieldType';
/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field. Must be the originalType if it exists.
* @param targetType The type of the target field. Must be the originalType if it exists.
* @returns True if the connection is valid, false otherwise.
*/
export const validateSourceAndTargetTypes = (
sourceType: FieldType,
targetType: FieldType
sourceType: FieldType | string,
targetType: FieldType | string
) => {
const isSourcePolymorphic = getIsPolymorphic(sourceType);
const isSourceCollection = getIsCollection(sourceType);
const sourceBaseType = getBaseType(sourceType);
const isTargetPolymorphic = getIsPolymorphic(targetType);
const isTargetCollection = getIsCollection(targetType);
const targetBaseType = getBaseType(targetType);
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
// Note that 'Collection' here is a field type, not node type.
if (sourceType === 'Collection' && targetType === 'Collection') {
return false;
}
@ -31,37 +45,21 @@ export const validateSourceAndTargetTypes = (
*/
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);
sourceType === 'CollectionItem' && !isTargetCollection;
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
!isSourceCollection &&
!isSourcePolymorphic;
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
isTargetPolymorphic && sourceBaseType === targetBaseType;
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
sourceType === 'Collection' && (isTargetCollection || isTargetPolymorphic);
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
targetType === 'Collection' && isSourceCollection;
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
@ -71,6 +69,7 @@ export const validateSourceAndTargetTypes = (
const isTargetAnyType = targetType === 'Any';
// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||

View File

@ -20,38 +20,6 @@ export const KIND_MAP = {
output: 'outputs' as const,
};
export const COLLECTION_TYPES: FieldType[] = [
'Collection',
'IntegerCollection',
'BooleanCollection',
'FloatCollection',
'StringCollection',
'ImageCollection',
'LatentsCollection',
'ConditioningCollection',
'ControlCollection',
'ColorCollection',
'T2IAdapterCollection',
'IPAdapterCollection',
'MetadataItemCollection',
'MetadataCollection',
];
export const POLYMORPHIC_TYPES: FieldType[] = [
'IntegerPolymorphic',
'BooleanPolymorphic',
'FloatPolymorphic',
'StringPolymorphic',
'ImagePolymorphic',
'LatentsPolymorphic',
'ConditioningPolymorphic',
'ControlPolymorphic',
'ColorPolymorphic',
'T2IAdapterPolymorphic',
'IPAdapterPolymorphic',
'MetadataItemPolymorphic',
];
export const MODEL_TYPES: FieldType[] = [
'IPAdapterModelField',
'ControlNetModelField',
@ -68,6 +36,26 @@ export const MODEL_TYPES: FieldType[] = [
'IPAdapterModelField',
];
/**
* TODO: Revise the field type naming scheme
*
* Unfortunately, due to inconsistent naming of types, we need to keep the below map objects/callbacks.
*
* Problems:
* - some types do not use the word "Field" in their name, e.g. "Scheduler"
* - primitive types use all-lowercase names, e.g. "integer"
* - collection and polymorphic types do not use the word "Field"
*
* If these inconsistencies were resolved, we could remove these mappings and use simple string
* parsing/manipulation to handle field types.
*
* It would make some of the parsing logic simpler and reduce the maintenance overhead of adding new
* "official" field types.
*
* This will require migration logic for workflows to update their field types. Workflows *do* have a
* version attached to them, so this shouldn't be too difficult.
*/
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
integer: 'IntegerCollection',
boolean: 'BooleanCollection',
@ -83,6 +71,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
IPAdapterField: 'IPAdapterCollection',
MetadataItemField: 'MetadataItemCollection',
MetadataField: 'MetadataCollection',
Custom: 'CustomCollection',
};
export const isCollectionItemType = (
itemType: string | undefined
@ -103,6 +92,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
T2IAdapterField: 'T2IAdapterPolymorphic',
IPAdapterField: 'IPAdapterPolymorphic',
MetadataItemField: 'MetadataItemPolymorphic',
Custom: 'CustomPolymorphic',
};
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
@ -118,6 +108,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
T2IAdapterPolymorphic: 'T2IAdapterField',
IPAdapterPolymorphic: 'IPAdapterField',
MetadataItemPolymorphic: 'MetadataItemField',
CustomPolymorphic: 'Custom',
};
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
@ -150,12 +141,27 @@ export const isPolymorphicItemType = (
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
export const FIELDS: Record<FieldType, FieldUIConfig> = {
export const FIELDS: Record<FieldType | string, FieldUIConfig> = {
Any: {
color: 'gray.500',
description: 'Any field type is accepted.',
title: 'Any',
},
Custom: {
color: 'gray.500',
description: 'A custom field, provided by an external node.',
title: 'Custom',
},
CustomCollection: {
color: 'gray.500',
description: 'A custom field collection, provided by an external node.',
title: 'Custom Collection',
},
CustomPolymorphic: {
color: 'gray.500',
description: 'A custom field polymorphic, provided by an external node.',
title: 'Custom Polymorphic',
},
MetadataField: {
color: 'gray.500',
description: 'A metadata dict.',

View File

@ -133,6 +133,9 @@ export const zFieldType = z.enum([
'UNetField',
'VaeField',
'VaeModelField',
'Custom',
'CustomCollection',
'CustomPolymorphic',
]);
export type FieldType = z.infer<typeof zFieldType>;
@ -143,7 +146,7 @@ export type FieldTypeMapWithNumber = {
export const zReservedFieldType = z.enum([
'WorkflowField',
'IsIntermediate',
'IsIntermediate', // this is technically a reserved field type!
'MetadataField',
]);
@ -163,6 +166,7 @@ export const zFieldValueBase = z.object({
id: z.string().trim().min(1),
name: z.string().trim().min(1),
type: zFieldType,
originalType: z.string().optional(),
});
export type FieldValueBase = z.infer<typeof zFieldValueBase>;
@ -190,6 +194,7 @@ export type OutputFieldTemplate = {
type: FieldType;
title: string;
description: string;
originalType?: string; // used for custom types
} & _OutputField;
export const zInputFieldValueBase = zFieldValueBase.extend({
@ -789,6 +794,21 @@ export const zAnyInputFieldValue = zInputFieldValueBase.extend({
value: z.any().optional(),
});
export const zCustomInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Custom'),
value: z.any().optional(),
});
export const zCustomCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('CustomCollection'),
value: z.array(z.any()).optional(),
});
export const zCustomPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('CustomPolymorphic'),
value: z.union([z.any(), z.array(z.any())]).optional(),
});
export const zInputFieldValue = z.discriminatedUnion('type', [
zAnyInputFieldValue,
zBoardInputFieldValue,
@ -846,6 +866,9 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zMetadataItemPolymorphicInputFieldValue,
zMetadataInputFieldValue,
zMetadataCollectionInputFieldValue,
zCustomInputFieldValue,
zCustomCollectionInputFieldValue,
zCustomPolymorphicInputFieldValue,
]);
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@ -856,6 +879,7 @@ export type InputFieldTemplateBase = {
description: string;
required: boolean;
fieldKind: 'input';
originalType?: string; // used for custom types
} & _InputField;
export type AnyInputFieldTemplate = InputFieldTemplateBase & {
@ -863,6 +887,21 @@ export type AnyInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
};
export type CustomInputFieldTemplate = InputFieldTemplateBase & {
type: 'Custom';
default: undefined;
};
export type CustomCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'CustomCollection';
default: [];
};
export type CustomPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
type: 'CustomPolymorphic';
default: undefined;
};
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
type: 'integer';
default: number;
@ -1259,7 +1298,10 @@ export type InputFieldTemplate =
| MetadataItemCollectionInputFieldTemplate
| MetadataInputFieldTemplate
| MetadataItemPolymorphicInputFieldTemplate
| MetadataCollectionInputFieldTemplate;
| MetadataCollectionInputFieldTemplate
| CustomInputFieldTemplate
| CustomCollectionInputFieldTemplate
| CustomPolymorphicInputFieldTemplate;
export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue

View File

@ -8,9 +8,9 @@ import {
} from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
import { ControlField } from 'services/api/types';
import { getIsPolymorphic } from '../store/util/parseFieldType';
import {
COLLECTION_MAP,
POLYMORPHIC_TYPES,
SINGLE_TO_POLYMORPHIC_MAP,
isCollectionItemType,
isPolymorphicItemType,
@ -35,6 +35,9 @@ import {
ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
ControlPolymorphicInputFieldTemplate,
CustomCollectionInputFieldTemplate,
CustomInputFieldTemplate,
CustomPolymorphicInputFieldTemplate,
DenoiseMaskInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
@ -84,6 +87,7 @@ import {
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
isArraySchemaObject,
isFieldType,
isNonArraySchemaObject,
isRefObject,
isSchemaObject,
@ -981,9 +985,45 @@ const buildSchedulerInputFieldTemplate = ({
return template;
};
const buildCustomCollectionInputFieldTemplate = ({
baseField,
}: BuildInputFieldArg): CustomCollectionInputFieldTemplate => {
const template: CustomCollectionInputFieldTemplate = {
...baseField,
type: 'CustomCollection',
default: [],
};
return template;
};
const buildCustomPolymorphicInputFieldTemplate = ({
baseField,
}: BuildInputFieldArg): CustomPolymorphicInputFieldTemplate => {
const template: CustomPolymorphicInputFieldTemplate = {
...baseField,
type: 'CustomPolymorphic',
default: undefined,
};
return template;
};
const buildCustomInputFieldTemplate = ({
baseField,
}: BuildInputFieldArg): CustomInputFieldTemplate => {
const template: CustomInputFieldTemplate = {
...baseField,
type: 'Custom',
default: undefined,
};
return template;
};
export const getFieldType = (
schemaObject: OpenAPIV3_1SchemaOrRef
): string | undefined => {
): { type: string; originalType: string } | undefined => {
if (isSchemaObject(schemaObject)) {
if (!schemaObject.type) {
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
@ -991,7 +1031,17 @@ export const getFieldType = (
if (schemaObject.allOf) {
const allOf = schemaObject.allOf;
if (allOf && allOf[0] && isRefObject(allOf[0])) {
return refObjectToSchemaName(allOf[0]);
// This is a single ref type
const originalType = refObjectToSchemaName(allOf[0]);
if (!originalType) {
// something has gone terribly awry
return;
}
return {
type: isFieldType(originalType) ? originalType : 'Custom',
originalType,
};
}
} else if (schemaObject.anyOf) {
// ignore null types
@ -1004,8 +1054,17 @@ export const getFieldType = (
return true;
});
if (anyOf.length === 1) {
// This is a single ref type
if (isRefObject(anyOf[0])) {
return refObjectToSchemaName(anyOf[0]);
const originalType = refObjectToSchemaName(anyOf[0]);
if (!originalType) {
return;
}
return {
type: isFieldType(originalType) ? originalType : 'Custom',
originalType,
};
} else if (isSchemaObject(anyOf[0])) {
return getFieldType(anyOf[0]);
}
@ -1051,16 +1110,29 @@ export const getFieldType = (
secondType = second.type;
}
}
if (firstType === secondType && isPolymorphicItemType(firstType)) {
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
if (firstType === secondType) {
if (isPolymorphicItemType(firstType)) {
// Known polymorphic field type
const originalType = SINGLE_TO_POLYMORPHIC_MAP[firstType];
if (!originalType) {
return;
}
return { type: originalType, originalType };
}
// else custom polymorphic
return {
type: 'CustomPolymorphic',
originalType: `${firstType}Polymorphic`,
};
}
}
} else if (schemaObject.enum) {
return 'enum';
return { type: 'enum', originalType: 'enum' };
} else if (schemaObject.type) {
if (schemaObject.type === 'number') {
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
return 'float';
return { type: 'float', originalType: 'float' };
} else if (schemaObject.type === 'array') {
const itemType = isSchemaObject(schemaObject.items)
? schemaObject.items.type
@ -1072,16 +1144,39 @@ export const getFieldType = (
}
if (isCollectionItemType(itemType)) {
return COLLECTION_MAP[itemType];
// known collection field type
const originalType = COLLECTION_MAP[itemType];
if (!originalType) {
return;
}
return { type: originalType, originalType };
}
return;
} else if (!isArray(schemaObject.type)) {
return schemaObject.type;
return {
type: 'CustomCollection',
originalType: `${itemType}Collection`,
};
} else if (
!isArray(schemaObject.type) &&
schemaObject.type !== 'null' && // 'null' is not valid
schemaObject.type !== 'object' // 'object' is not valid
) {
const originalType = schemaObject.type;
return { type: originalType, originalType };
}
// else ignore
return;
}
} else if (isRefObject(schemaObject)) {
return refObjectToSchemaName(schemaObject);
const originalType = refObjectToSchemaName(schemaObject);
if (!originalType) {
return;
}
return {
type: isFieldType(originalType) ? originalType : 'Custom',
originalType,
};
}
return;
};
@ -1145,13 +1240,11 @@ const TEMPLATE_BUILDER_MAP: {
UNetField: buildUNetInputFieldTemplate,
VaeField: buildVaeInputFieldTemplate,
VaeModelField: buildVaeModelInputFieldTemplate,
Custom: buildCustomInputFieldTemplate,
CustomCollection: buildCustomCollectionInputFieldTemplate,
CustomPolymorphic: buildCustomPolymorphicInputFieldTemplate,
};
const isTemplatedFieldType = (
fieldType: string | undefined
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
/**
* Builds an input field from an invocation schema property.
* @param fieldSchema The schema object
@ -1161,7 +1254,8 @@ export const buildInputFieldTemplate = (
nodeSchema: InvocationSchemaObject,
fieldSchema: InvocationFieldSchema,
name: string,
fieldType: FieldType
fieldType: FieldType,
originalType: string
) => {
const {
input,
@ -1175,7 +1269,7 @@ export const buildInputFieldTemplate = (
const extra = {
// TODO: Can we support polymorphic inputs in the UI?
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
input: getIsPolymorphic(fieldType) ? 'connection' : input,
ui_hidden,
ui_component,
ui_type,
@ -1183,6 +1277,7 @@ export const buildInputFieldTemplate = (
ui_order,
ui_choice_labels,
item_default,
originalType,
};
const baseField = {
@ -1193,10 +1288,6 @@ export const buildInputFieldTemplate = (
...extra,
};
if (!isTemplatedFieldType(fieldType)) {
return;
}
const builder = TEMPLATE_BUILDER_MAP[fieldType];
if (!builder) {

View File

@ -60,6 +60,9 @@ const FIELD_VALUE_FALLBACK_MAP: {
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,
Custom: undefined,
CustomCollection: [],
CustomPolymorphic: undefined,
};
export const buildInputFieldValue = (
@ -76,10 +79,9 @@ export const buildInputFieldValue = (
type: template.type,
label: '',
fieldKind: 'input',
originalType: template.originalType,
value: template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type],
} as InputFieldValue;
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
return fieldValue;
};

View File

@ -103,14 +103,15 @@ export const parseSchema = (
return inputsAccumulator;
}
const fieldType = property.ui_type ?? getFieldType(property);
const fieldTypeResult = property.ui_type
? { type: property.ui_type, originalType: property.ui_type }
: getFieldType(property);
if (!fieldType) {
if (!fieldTypeResult) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
'Missing input field type'
@ -118,6 +119,9 @@ export const parseSchema = (
return inputsAccumulator;
}
// stash this for custom types
const { type: fieldType, originalType } = fieldTypeResult;
if (fieldType === 'WorkflowField') {
withWorkflow = true;
return inputsAccumulator;
@ -136,6 +140,18 @@ export const parseSchema = (
return inputsAccumulator;
}
if (!isFieldType(originalType)) {
logger('nodes').debug(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
`Fallback handling for unknown input field type: ${fieldType}`
);
}
if (!isFieldType(fieldType)) {
logger('nodes').warn(
{
@ -144,7 +160,7 @@ export const parseSchema = (
fieldType,
field: parseify(property),
},
`Skipping unknown input field type: ${fieldType}`
`Unable to parse field type: ${fieldType}`
);
return inputsAccumulator;
}
@ -153,7 +169,8 @@ export const parseSchema = (
schema,
property,
propertyName,
fieldType
fieldType,
originalType
);
if (!field) {
@ -162,6 +179,7 @@ export const parseSchema = (
node: type,
fieldName: propertyName,
fieldType,
originalType,
field: parseify(property),
},
'Skipping input field with no template'
@ -220,12 +238,46 @@ export const parseSchema = (
return outputsAccumulator;
}
const fieldType = property.ui_type ?? getFieldType(property);
const fieldTypeResult = property.ui_type
? { type: property.ui_type, originalType: property.ui_type }
: getFieldType(property);
if (!fieldTypeResult) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
field: parseify(property),
},
'Missing output field type'
);
return outputsAccumulator;
}
const { type: fieldType, originalType } = fieldTypeResult;
if (!isFieldType(fieldType)) {
logger('nodes').debug(
{
node: type,
fieldName: propertyName,
fieldType,
originalType,
field: parseify(property),
},
`Fallback handling for unknown output field type: ${fieldType}`
);
}
if (!isFieldType(fieldType)) {
logger('nodes').warn(
{ fieldName: propertyName, fieldType, field: parseify(property) },
'Skipping unknown output field type'
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
`Unable to parse field type: ${fieldType}`
);
return outputsAccumulator;
}
@ -240,6 +292,7 @@ export const parseSchema = (
ui_hidden: property.ui_hidden ?? false,
ui_type: property.ui_type,
ui_order: property.ui_order,
originalType,
};
return outputsAccumulator;