mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
14 Commits
next-fix-t
...
feat/arbit
Author | SHA1 | Date | |
---|---|---|---|
b65acc0137 | |||
0e640adc2c | |||
f280a2ecbd | |||
e047d43111 | |||
57567d4fc3 | |||
9ebffcd26b | |||
e30f22ae7e | |||
3ff13dc93c | |||
5e4b0932fd | |||
98a0ce0f42 | |||
7b93b5e928 | |||
dc44debbab | |||
5ce2dc3a58 | |||
27fd9071ba |
@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, ForwardRef, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
@ -648,17 +648,40 @@ class _Model(BaseModel):
|
||||
# Get all pydantic model attrs, methods, etc
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||
|
||||
RESERVED_INVOKEAI_FIELD_NAMES = {"Custom", "CustomCollection", "CustomPolymorphic"}
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
"""
|
||||
Validates the fields of an invocation or invocation output:
|
||||
- must not override any pydantic reserved fields
|
||||
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use
|
||||
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
||||
"""
|
||||
for name, field in model_fields.items():
|
||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||
|
||||
if not field.annotation:
|
||||
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
|
||||
|
||||
annotation_name = (
|
||||
field.annotation.__forward_arg__ if isinstance(field.annotation, ForwardRef) else field.annotation.__name__
|
||||
)
|
||||
|
||||
if annotation_name.endswith("Polymorphic"):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Polymorphic")'
|
||||
)
|
||||
|
||||
if annotation_name.endswith("Collection"):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Collection")'
|
||||
)
|
||||
|
||||
if annotation_name in RESERVED_INVOKEAI_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (reserved)')
|
||||
|
||||
field_kind = (
|
||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
|
||||
|
@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
return null;
|
||||
}
|
||||
|
||||
console.log(metadata);
|
||||
|
||||
return (
|
||||
<>
|
||||
{metadata.created_by && (
|
||||
|
@ -57,7 +57,7 @@ const AddNodePopover = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const fieldFilter = useAppSelector(
|
||||
(state) => state.nodes.currentConnectionFieldType
|
||||
(state) => state.nodes.connectionStartFieldType
|
||||
);
|
||||
const handleFilter = useAppSelector(
|
||||
(state) => state.nodes.connectionStartParams?.handleType
|
||||
@ -74,9 +74,13 @@ const AddNodePopover = () => {
|
||||
|
||||
return some(handles, (handle) => {
|
||||
const sourceType =
|
||||
handleFilter == 'source' ? fieldFilter : handle.type;
|
||||
handleFilter == 'source'
|
||||
? fieldFilter
|
||||
: handle.originalType ?? handle.type;
|
||||
const targetType =
|
||||
handleFilter == 'target' ? fieldFilter : handle.type;
|
||||
handleFilter == 'target'
|
||||
? fieldFilter
|
||||
: handle.originalType ?? handle.type;
|
||||
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
@ -111,7 +115,7 @@ const AddNodePopover = () => {
|
||||
|
||||
data.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return { data, t };
|
||||
return { data };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { memo } from 'react';
|
||||
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
|
||||
import { getFieldColor } from '../edges/util/getEdgeColor';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
|
||||
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
|
||||
nodes;
|
||||
|
||||
const stroke =
|
||||
currentConnectionFieldType && shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
|
||||
: colorTokenToCssVar('base.500');
|
||||
const stroke = shouldColorEdges
|
||||
? getFieldColor(connectionStartFieldType)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
let className = 'react-flow__custom_connection-path';
|
||||
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
|
||||
export const getFieldColor = (fieldType: FieldType | string | null): string => {
|
||||
if (!fieldType) {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
const color = FIELDS[fieldType]?.color;
|
||||
|
||||
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
|
||||
};
|
@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
source: string,
|
||||
@ -29,7 +29,7 @@ export const makeEdgeSelector = (
|
||||
|
||||
const stroke =
|
||||
sourceType && nodes.shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||
? getFieldColor(sourceType)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
return {
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import {
|
||||
COLLECTION_TYPES,
|
||||
FIELDS,
|
||||
getIsCollection,
|
||||
getIsPolymorphic,
|
||||
} from 'features/nodes/store/util/parseFieldType';
|
||||
import {
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
MODEL_TYPES,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
@ -13,6 +14,7 @@ import {
|
||||
} from 'features/nodes/types/types';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, HandleType, Position } from 'reactflow';
|
||||
import { getFieldColor } from '../../../edges/util/getEdgeColor';
|
||||
|
||||
export const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
@ -47,23 +49,21 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
isConnectionStartField,
|
||||
connectionError,
|
||||
} = props;
|
||||
const { name, type } = fieldTemplate;
|
||||
const { color: typeColor, title } = FIELDS[type];
|
||||
const { name } = fieldTemplate;
|
||||
const type = fieldTemplate.originalType ?? fieldTemplate.type;
|
||||
|
||||
const styles: CSSProperties = useMemo(() => {
|
||||
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||
const isModelType = MODEL_TYPES.includes(type);
|
||||
const color = colorTokenToCssVar(typeColor);
|
||||
const isCollection = getIsCollection(fieldTemplate.type);
|
||||
const isPolymorphic = getIsPolymorphic(fieldTemplate.type);
|
||||
const isModelType = MODEL_TYPES.some((t) => t === type);
|
||||
const color = getFieldColor(type);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor:
|
||||
isCollectionType || isPolymorphicType
|
||||
? 'var(--invokeai-colors-base-900)'
|
||||
: color,
|
||||
isCollection || isPolymorphic ? colorTokenToCssVar('base.900') : color,
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||
borderWidth: isCollection || isPolymorphic ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
@ -93,22 +93,19 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
return s;
|
||||
}, [
|
||||
connectionError,
|
||||
fieldTemplate.type,
|
||||
handleType,
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
type,
|
||||
typeColor,
|
||||
]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && isConnectionStartField) {
|
||||
return title;
|
||||
}
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
return connectionError ?? title;
|
||||
return connectionError;
|
||||
}
|
||||
return title;
|
||||
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
|
||||
return type;
|
||||
}, [connectionError, isConnectionInProgress, type]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import {
|
||||
isInputFieldTemplate,
|
||||
isInputFieldValue,
|
||||
@ -9,7 +8,6 @@ import {
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@ -49,7 +47,7 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
{fieldTemplate.description}
|
||||
</Text>
|
||||
)}
|
||||
{fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
|
||||
{fieldTemplate && <Text>Type: {fieldTemplate.originalType}</Text>}
|
||||
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -4,11 +4,9 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { getIsPolymorphic } from '../store/util/parseFieldType';
|
||||
import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
@ -28,7 +26,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) ||
|
||||
POLYMORPHIC_TYPES.includes(field.type)) &&
|
||||
getIsPolymorphic(field.type)) &&
|
||||
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
);
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
|
@ -4,10 +4,8 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { getIsPolymorphic } from '../store/util/parseFieldType';
|
||||
import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
|
||||
@ -29,8 +27,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
// get the visible fields
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' &&
|
||||
!POLYMORPHIC_TYPES.includes(field.type)) ||
|
||||
(field.input === 'connection' && !getIsPolymorphic(field.type)) ||
|
||||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
);
|
||||
|
||||
|
@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
|
||||
const selectIsConnectionInProgress = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) =>
|
||||
nodes.currentConnectionFieldType !== null &&
|
||||
nodes.connectionStartFieldType !== null &&
|
||||
nodes.connectionStartParams !== null
|
||||
);
|
||||
|
||||
|
@ -20,7 +20,8 @@ export const useFieldType = (
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data[KIND_MAP[kind]][fieldName]?.type;
|
||||
const field = node.data[KIND_MAP[kind]][fieldName];
|
||||
return field?.originalType ?? field?.type;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
|
||||
const targetType = targetNode.data.inputs[targetHandle]?.type;
|
||||
const sourceField = sourceNode.data.outputs[sourceHandle];
|
||||
const targetField = targetNode.data.inputs[targetHandle];
|
||||
|
||||
if (!sourceType || !targetType) {
|
||||
if (!sourceField || !targetField) {
|
||||
// something has gone terribly awry
|
||||
return false;
|
||||
}
|
||||
@ -70,12 +70,18 @@ export const useIsValidConnection = () => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
targetField.type !== 'CollectionItem'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
||||
// Must use the originalType here if it exists
|
||||
if (
|
||||
!validateSourceAndTargetTypes(
|
||||
sourceField?.originalType ?? sourceField.type,
|
||||
targetField?.originalType ?? targetField.type
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,7 @@ import { NodesState } from './types';
|
||||
export const nodesPersistDenylist: (keyof NodesState)[] = [
|
||||
'nodeTemplates',
|
||||
'connectionStartParams',
|
||||
'currentConnectionFieldType',
|
||||
'connectionStartFieldType',
|
||||
'selectedNodes',
|
||||
'selectedEdges',
|
||||
'isReady',
|
||||
|
@ -93,7 +93,7 @@ export const initialNodesState: NodesState = {
|
||||
nodeTemplates: {},
|
||||
isReady: false,
|
||||
connectionStartParams: null,
|
||||
currentConnectionFieldType: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
modifyingEdge: false,
|
||||
addNewNodePosition: null,
|
||||
@ -203,7 +203,7 @@ const nodesSlice = createSlice({
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
node,
|
||||
@ -212,7 +212,7 @@ const nodesSlice = createSlice({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
@ -224,7 +224,7 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
edgeChangeStarted: (state) => {
|
||||
state.modifyingEdge = true;
|
||||
@ -258,10 +258,11 @@ const nodesSlice = createSlice({
|
||||
handleType === 'source'
|
||||
? node.data.outputs[handleId]
|
||||
: node.data.inputs[handleId];
|
||||
state.currentConnectionFieldType = field?.type ?? null;
|
||||
state.connectionStartFieldType =
|
||||
field?.originalType ?? field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
const fieldType = state.currentConnectionFieldType;
|
||||
const fieldType = state.connectionStartFieldType;
|
||||
if (!fieldType) {
|
||||
return;
|
||||
}
|
||||
@ -286,7 +287,7 @@ const nodesSlice = createSlice({
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
mouseOverNode,
|
||||
@ -295,7 +296,7 @@ const nodesSlice = createSlice({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
@ -306,14 +307,14 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
}
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
} else {
|
||||
state.addNewNodePosition = action.payload.cursorPosition;
|
||||
state.isAddNodePopoverOpen = true;
|
||||
}
|
||||
} else {
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
}
|
||||
state.modifyingEdge = false;
|
||||
},
|
||||
@ -942,7 +943,7 @@ const nodesSlice = createSlice({
|
||||
|
||||
//Make sure these get reset if we close the popover and haven't selected a node
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
addNodePopoverToggled: (state) => {
|
||||
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
|
||||
|
@ -21,7 +21,7 @@ export type NodesState = {
|
||||
edges: Edge<InvocationEdgeExtra>[];
|
||||
nodeTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
currentConnectionFieldType: FieldType | null;
|
||||
connectionStartFieldType: FieldType | string | null;
|
||||
connectionMade: boolean;
|
||||
modifyingEdge: boolean;
|
||||
shouldShowFieldTypeLegend: boolean;
|
||||
|
@ -94,6 +94,7 @@ export const buildNodeData = (
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
fieldKind: 'output',
|
||||
originalType: outputTemplate.originalType,
|
||||
};
|
||||
|
||||
outputsAccumulator[outputName] = outputFieldValue;
|
||||
|
@ -12,7 +12,7 @@ import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
const isValidConnection = (
|
||||
edges: Edge[],
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
handleCurrentFieldType: FieldType | string,
|
||||
node: Node,
|
||||
handle: InputFieldValue | OutputFieldValue
|
||||
) => {
|
||||
@ -35,7 +35,12 @@ const isValidConnection = (
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
|
||||
if (
|
||||
!validateSourceAndTargetTypes(
|
||||
handleCurrentFieldType,
|
||||
handle.originalType ?? handle.type
|
||||
)
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
|
||||
@ -49,7 +54,7 @@ export const findConnectionToValidHandle = (
|
||||
handleCurrentNodeId: string,
|
||||
handleCurrentName: string,
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType
|
||||
handleCurrentFieldType: FieldType | string
|
||||
): Connection | null => {
|
||||
if (node.id === handleCurrentNodeId) {
|
||||
return null;
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import i18n from 'i18next';
|
||||
import { HandleType } from 'reactflow';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
/**
|
||||
@ -15,17 +15,17 @@ export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType
|
||||
fieldType?: FieldType | string
|
||||
) => {
|
||||
return createSelector(stateSelector, (state) => {
|
||||
if (!fieldType) {
|
||||
return i18n.t('nodes.noFieldType');
|
||||
}
|
||||
|
||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||
const { connectionStartFieldType, connectionStartParams, nodes, edges } =
|
||||
state.nodes;
|
||||
|
||||
if (!connectionStartParams || !currentConnectionFieldType) {
|
||||
if (!connectionStartParams || !connectionStartFieldType) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
const targetType =
|
||||
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
||||
handleType === 'target' ? fieldType : connectionStartFieldType;
|
||||
const sourceType =
|
||||
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
||||
handleType === 'source' ? fieldType : connectionStartFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
return i18n.t('nodes.cannotConnectToSelf');
|
||||
|
@ -0,0 +1,14 @@
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
|
||||
export const getIsPolymorphic = (type: FieldType | string): boolean =>
|
||||
type.endsWith('Polymorphic');
|
||||
|
||||
export const getIsCollection = (type: FieldType | string): boolean =>
|
||||
type.endsWith('Collection');
|
||||
|
||||
export const getBaseType = (type: FieldType | string): FieldType | string =>
|
||||
getIsPolymorphic(type)
|
||||
? type.replace(/Polymorphic$/, '')
|
||||
: getIsCollection(type)
|
||||
? type.replace(/Collection$/, '')
|
||||
: type;
|
@ -1,18 +1,32 @@
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import {
|
||||
getBaseType,
|
||||
getIsCollection,
|
||||
getIsPolymorphic,
|
||||
} from './parseFieldType';
|
||||
|
||||
/**
|
||||
* Validates that the source and target types are compatible for a connection.
|
||||
* @param sourceType The type of the source field. Must be the originalType if it exists.
|
||||
* @param targetType The type of the target field. Must be the originalType if it exists.
|
||||
* @returns True if the connection is valid, false otherwise.
|
||||
*/
|
||||
export const validateSourceAndTargetTypes = (
|
||||
sourceType: FieldType,
|
||||
targetType: FieldType
|
||||
sourceType: FieldType | string,
|
||||
targetType: FieldType | string
|
||||
) => {
|
||||
const isSourcePolymorphic = getIsPolymorphic(sourceType);
|
||||
const isSourceCollection = getIsCollection(sourceType);
|
||||
const sourceBaseType = getBaseType(sourceType);
|
||||
|
||||
const isTargetPolymorphic = getIsPolymorphic(targetType);
|
||||
const isTargetCollection = getIsCollection(targetType);
|
||||
const targetBaseType = getBaseType(targetType);
|
||||
|
||||
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||
// Once this is resolved, we can remove this check.
|
||||
// Note that 'Collection' here is a field type, not node type.
|
||||
if (sourceType === 'Collection' && targetType === 'Collection') {
|
||||
return false;
|
||||
}
|
||||
@ -31,37 +45,21 @@ export const validateSourceAndTargetTypes = (
|
||||
*/
|
||||
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);
|
||||
sourceType === 'CollectionItem' && !isTargetCollection;
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
!isSourceCollection &&
|
||||
!isSourcePolymorphic;
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
isTargetPolymorphic && sourceBaseType === targetBaseType;
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
sourceType === 'Collection' && (isTargetCollection || isTargetPolymorphic);
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
targetType === 'Collection' && isSourceCollection;
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
@ -71,6 +69,7 @@ export const validateSourceAndTargetTypes = (
|
||||
|
||||
const isTargetAnyType = targetType === 'Any';
|
||||
|
||||
// One of these must be true for the connection to be valid
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
|
@ -20,38 +20,6 @@ export const KIND_MAP = {
|
||||
output: 'outputs' as const,
|
||||
};
|
||||
|
||||
export const COLLECTION_TYPES: FieldType[] = [
|
||||
'Collection',
|
||||
'IntegerCollection',
|
||||
'BooleanCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'ImageCollection',
|
||||
'LatentsCollection',
|
||||
'ConditioningCollection',
|
||||
'ControlCollection',
|
||||
'ColorCollection',
|
||||
'T2IAdapterCollection',
|
||||
'IPAdapterCollection',
|
||||
'MetadataItemCollection',
|
||||
'MetadataCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
'IntegerPolymorphic',
|
||||
'BooleanPolymorphic',
|
||||
'FloatPolymorphic',
|
||||
'StringPolymorphic',
|
||||
'ImagePolymorphic',
|
||||
'LatentsPolymorphic',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
'T2IAdapterPolymorphic',
|
||||
'IPAdapterPolymorphic',
|
||||
'MetadataItemPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES: FieldType[] = [
|
||||
'IPAdapterModelField',
|
||||
'ControlNetModelField',
|
||||
@ -68,6 +36,26 @@ export const MODEL_TYPES: FieldType[] = [
|
||||
'IPAdapterModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
* TODO: Revise the field type naming scheme
|
||||
*
|
||||
* Unfortunately, due to inconsistent naming of types, we need to keep the below map objects/callbacks.
|
||||
*
|
||||
* Problems:
|
||||
* - some types do not use the word "Field" in their name, e.g. "Scheduler"
|
||||
* - primitive types use all-lowercase names, e.g. "integer"
|
||||
* - collection and polymorphic types do not use the word "Field"
|
||||
*
|
||||
* If these inconsistencies were resolved, we could remove these mappings and use simple string
|
||||
* parsing/manipulation to handle field types.
|
||||
*
|
||||
* It would make some of the parsing logic simpler and reduce the maintenance overhead of adding new
|
||||
* "official" field types.
|
||||
*
|
||||
* This will require migration logic for workflows to update their field types. Workflows *do* have a
|
||||
* version attached to them, so this shouldn't be too difficult.
|
||||
*/
|
||||
|
||||
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||
integer: 'IntegerCollection',
|
||||
boolean: 'BooleanCollection',
|
||||
@ -83,6 +71,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||
IPAdapterField: 'IPAdapterCollection',
|
||||
MetadataItemField: 'MetadataItemCollection',
|
||||
MetadataField: 'MetadataCollection',
|
||||
Custom: 'CustomCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
@ -103,6 +92,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
||||
T2IAdapterField: 'T2IAdapterPolymorphic',
|
||||
IPAdapterField: 'IPAdapterPolymorphic',
|
||||
MetadataItemField: 'MetadataItemPolymorphic',
|
||||
Custom: 'CustomPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
@ -118,6 +108,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||
IPAdapterPolymorphic: 'IPAdapterField',
|
||||
MetadataItemPolymorphic: 'MetadataItemField',
|
||||
CustomPolymorphic: 'Custom',
|
||||
};
|
||||
|
||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||
@ -150,12 +141,27 @@ export const isPolymorphicItemType = (
|
||||
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
export const FIELDS: Record<FieldType | string, FieldUIConfig> = {
|
||||
Any: {
|
||||
color: 'gray.500',
|
||||
description: 'Any field type is accepted.',
|
||||
title: 'Any',
|
||||
},
|
||||
Custom: {
|
||||
color: 'gray.500',
|
||||
description: 'A custom field, provided by an external node.',
|
||||
title: 'Custom',
|
||||
},
|
||||
CustomCollection: {
|
||||
color: 'gray.500',
|
||||
description: 'A custom field collection, provided by an external node.',
|
||||
title: 'Custom Collection',
|
||||
},
|
||||
CustomPolymorphic: {
|
||||
color: 'gray.500',
|
||||
description: 'A custom field polymorphic, provided by an external node.',
|
||||
title: 'Custom Polymorphic',
|
||||
},
|
||||
MetadataField: {
|
||||
color: 'gray.500',
|
||||
description: 'A metadata dict.',
|
||||
|
@ -133,6 +133,9 @@ export const zFieldType = z.enum([
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'VaeModelField',
|
||||
'Custom',
|
||||
'CustomCollection',
|
||||
'CustomPolymorphic',
|
||||
]);
|
||||
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
@ -143,7 +146,7 @@ export type FieldTypeMapWithNumber = {
|
||||
|
||||
export const zReservedFieldType = z.enum([
|
||||
'WorkflowField',
|
||||
'IsIntermediate',
|
||||
'IsIntermediate', // this is technically a reserved field type!
|
||||
'MetadataField',
|
||||
]);
|
||||
|
||||
@ -163,6 +166,7 @@ export const zFieldValueBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
name: z.string().trim().min(1),
|
||||
type: zFieldType,
|
||||
originalType: z.string().optional(),
|
||||
});
|
||||
export type FieldValueBase = z.infer<typeof zFieldValueBase>;
|
||||
|
||||
@ -190,6 +194,7 @@ export type OutputFieldTemplate = {
|
||||
type: FieldType;
|
||||
title: string;
|
||||
description: string;
|
||||
originalType?: string; // used for custom types
|
||||
} & _OutputField;
|
||||
|
||||
export const zInputFieldValueBase = zFieldValueBase.extend({
|
||||
@ -789,6 +794,21 @@ export const zAnyInputFieldValue = zInputFieldValueBase.extend({
|
||||
value: z.any().optional(),
|
||||
});
|
||||
|
||||
export const zCustomInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Custom'),
|
||||
value: z.any().optional(),
|
||||
});
|
||||
|
||||
export const zCustomCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('CustomCollection'),
|
||||
value: z.array(z.any()).optional(),
|
||||
});
|
||||
|
||||
export const zCustomPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('CustomPolymorphic'),
|
||||
value: z.union([z.any(), z.array(z.any())]).optional(),
|
||||
});
|
||||
|
||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zAnyInputFieldValue,
|
||||
zBoardInputFieldValue,
|
||||
@ -846,6 +866,9 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zMetadataItemPolymorphicInputFieldValue,
|
||||
zMetadataInputFieldValue,
|
||||
zMetadataCollectionInputFieldValue,
|
||||
zCustomInputFieldValue,
|
||||
zCustomCollectionInputFieldValue,
|
||||
zCustomPolymorphicInputFieldValue,
|
||||
]);
|
||||
|
||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||
@ -856,6 +879,7 @@ export type InputFieldTemplateBase = {
|
||||
description: string;
|
||||
required: boolean;
|
||||
fieldKind: 'input';
|
||||
originalType?: string; // used for custom types
|
||||
} & _InputField;
|
||||
|
||||
export type AnyInputFieldTemplate = InputFieldTemplateBase & {
|
||||
@ -863,6 +887,21 @@ export type AnyInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
};
|
||||
|
||||
export type CustomInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'Custom';
|
||||
default: undefined;
|
||||
};
|
||||
|
||||
export type CustomCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'CustomCollection';
|
||||
default: [];
|
||||
};
|
||||
|
||||
export type CustomPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'CustomPolymorphic';
|
||||
default: undefined;
|
||||
};
|
||||
|
||||
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'integer';
|
||||
default: number;
|
||||
@ -1259,7 +1298,10 @@ export type InputFieldTemplate =
|
||||
| MetadataItemCollectionInputFieldTemplate
|
||||
| MetadataInputFieldTemplate
|
||||
| MetadataItemPolymorphicInputFieldTemplate
|
||||
| MetadataCollectionInputFieldTemplate;
|
||||
| MetadataCollectionInputFieldTemplate
|
||||
| CustomInputFieldTemplate
|
||||
| CustomCollectionInputFieldTemplate
|
||||
| CustomPolymorphicInputFieldTemplate;
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
|
@ -8,9 +8,9 @@ import {
|
||||
} from 'lodash-es';
|
||||
import { OpenAPIV3_1 } from 'openapi-types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
import { getIsPolymorphic } from '../store/util/parseFieldType';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
SINGLE_TO_POLYMORPHIC_MAP,
|
||||
isCollectionItemType,
|
||||
isPolymorphicItemType,
|
||||
@ -35,6 +35,9 @@ import {
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
CustomCollectionInputFieldTemplate,
|
||||
CustomInputFieldTemplate,
|
||||
CustomPolymorphicInputFieldTemplate,
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
@ -84,6 +87,7 @@ import {
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
isArraySchemaObject,
|
||||
isFieldType,
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
@ -981,9 +985,45 @@ const buildSchedulerInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCustomCollectionInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): CustomCollectionInputFieldTemplate => {
|
||||
const template: CustomCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'CustomCollection',
|
||||
default: [],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCustomPolymorphicInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): CustomPolymorphicInputFieldTemplate => {
|
||||
const template: CustomPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'CustomPolymorphic',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCustomInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): CustomInputFieldTemplate => {
|
||||
const template: CustomInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'Custom',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const getFieldType = (
|
||||
schemaObject: OpenAPIV3_1SchemaOrRef
|
||||
): string | undefined => {
|
||||
): { type: string; originalType: string } | undefined => {
|
||||
if (isSchemaObject(schemaObject)) {
|
||||
if (!schemaObject.type) {
|
||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||
@ -991,7 +1031,17 @@ export const getFieldType = (
|
||||
if (schemaObject.allOf) {
|
||||
const allOf = schemaObject.allOf;
|
||||
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||
return refObjectToSchemaName(allOf[0]);
|
||||
// This is a single ref type
|
||||
const originalType = refObjectToSchemaName(allOf[0]);
|
||||
|
||||
if (!originalType) {
|
||||
// something has gone terribly awry
|
||||
return;
|
||||
}
|
||||
return {
|
||||
type: isFieldType(originalType) ? originalType : 'Custom',
|
||||
originalType,
|
||||
};
|
||||
}
|
||||
} else if (schemaObject.anyOf) {
|
||||
// ignore null types
|
||||
@ -1004,8 +1054,17 @@ export const getFieldType = (
|
||||
return true;
|
||||
});
|
||||
if (anyOf.length === 1) {
|
||||
// This is a single ref type
|
||||
if (isRefObject(anyOf[0])) {
|
||||
return refObjectToSchemaName(anyOf[0]);
|
||||
const originalType = refObjectToSchemaName(anyOf[0]);
|
||||
if (!originalType) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
type: isFieldType(originalType) ? originalType : 'Custom',
|
||||
originalType,
|
||||
};
|
||||
} else if (isSchemaObject(anyOf[0])) {
|
||||
return getFieldType(anyOf[0]);
|
||||
}
|
||||
@ -1051,16 +1110,29 @@ export const getFieldType = (
|
||||
secondType = second.type;
|
||||
}
|
||||
}
|
||||
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
||||
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||
if (firstType === secondType) {
|
||||
if (isPolymorphicItemType(firstType)) {
|
||||
// Known polymorphic field type
|
||||
const originalType = SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||
if (!originalType) {
|
||||
return;
|
||||
}
|
||||
return { type: originalType, originalType };
|
||||
}
|
||||
|
||||
// else custom polymorphic
|
||||
return {
|
||||
type: 'CustomPolymorphic',
|
||||
originalType: `${firstType}Polymorphic`,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if (schemaObject.enum) {
|
||||
return 'enum';
|
||||
return { type: 'enum', originalType: 'enum' };
|
||||
} else if (schemaObject.type) {
|
||||
if (schemaObject.type === 'number') {
|
||||
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
||||
return 'float';
|
||||
return { type: 'float', originalType: 'float' };
|
||||
} else if (schemaObject.type === 'array') {
|
||||
const itemType = isSchemaObject(schemaObject.items)
|
||||
? schemaObject.items.type
|
||||
@ -1072,16 +1144,39 @@ export const getFieldType = (
|
||||
}
|
||||
|
||||
if (isCollectionItemType(itemType)) {
|
||||
return COLLECTION_MAP[itemType];
|
||||
// known collection field type
|
||||
const originalType = COLLECTION_MAP[itemType];
|
||||
if (!originalType) {
|
||||
return;
|
||||
}
|
||||
return { type: originalType, originalType };
|
||||
}
|
||||
|
||||
return;
|
||||
} else if (!isArray(schemaObject.type)) {
|
||||
return schemaObject.type;
|
||||
return {
|
||||
type: 'CustomCollection',
|
||||
originalType: `${itemType}Collection`,
|
||||
};
|
||||
} else if (
|
||||
!isArray(schemaObject.type) &&
|
||||
schemaObject.type !== 'null' && // 'null' is not valid
|
||||
schemaObject.type !== 'object' // 'object' is not valid
|
||||
) {
|
||||
const originalType = schemaObject.type;
|
||||
return { type: originalType, originalType };
|
||||
}
|
||||
// else ignore
|
||||
return;
|
||||
}
|
||||
} else if (isRefObject(schemaObject)) {
|
||||
return refObjectToSchemaName(schemaObject);
|
||||
const originalType = refObjectToSchemaName(schemaObject);
|
||||
if (!originalType) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
type: isFieldType(originalType) ? originalType : 'Custom',
|
||||
originalType,
|
||||
};
|
||||
}
|
||||
return;
|
||||
};
|
||||
@ -1145,13 +1240,11 @@ const TEMPLATE_BUILDER_MAP: {
|
||||
UNetField: buildUNetInputFieldTemplate,
|
||||
VaeField: buildVaeInputFieldTemplate,
|
||||
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||
Custom: buildCustomInputFieldTemplate,
|
||||
CustomCollection: buildCustomCollectionInputFieldTemplate,
|
||||
CustomPolymorphic: buildCustomPolymorphicInputFieldTemplate,
|
||||
};
|
||||
|
||||
const isTemplatedFieldType = (
|
||||
fieldType: string | undefined
|
||||
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
|
||||
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
|
||||
|
||||
/**
|
||||
* Builds an input field from an invocation schema property.
|
||||
* @param fieldSchema The schema object
|
||||
@ -1161,7 +1254,8 @@ export const buildInputFieldTemplate = (
|
||||
nodeSchema: InvocationSchemaObject,
|
||||
fieldSchema: InvocationFieldSchema,
|
||||
name: string,
|
||||
fieldType: FieldType
|
||||
fieldType: FieldType,
|
||||
originalType: string
|
||||
) => {
|
||||
const {
|
||||
input,
|
||||
@ -1175,7 +1269,7 @@ export const buildInputFieldTemplate = (
|
||||
|
||||
const extra = {
|
||||
// TODO: Can we support polymorphic inputs in the UI?
|
||||
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
|
||||
input: getIsPolymorphic(fieldType) ? 'connection' : input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
@ -1183,6 +1277,7 @@ export const buildInputFieldTemplate = (
|
||||
ui_order,
|
||||
ui_choice_labels,
|
||||
item_default,
|
||||
originalType,
|
||||
};
|
||||
|
||||
const baseField = {
|
||||
@ -1193,10 +1288,6 @@ export const buildInputFieldTemplate = (
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (!isTemplatedFieldType(fieldType)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType];
|
||||
|
||||
if (!builder) {
|
||||
|
@ -60,6 +60,9 @@ const FIELD_VALUE_FALLBACK_MAP: {
|
||||
UNetField: undefined,
|
||||
VaeField: undefined,
|
||||
VaeModelField: undefined,
|
||||
Custom: undefined,
|
||||
CustomCollection: [],
|
||||
CustomPolymorphic: undefined,
|
||||
};
|
||||
|
||||
export const buildInputFieldValue = (
|
||||
@ -76,10 +79,9 @@ export const buildInputFieldValue = (
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input',
|
||||
originalType: template.originalType,
|
||||
value: template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type],
|
||||
} as InputFieldValue;
|
||||
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
|
||||
|
||||
return fieldValue;
|
||||
};
|
||||
|
@ -103,14 +103,15 @@ export const parseSchema = (
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldType = property.ui_type ?? getFieldType(property);
|
||||
const fieldTypeResult = property.ui_type
|
||||
? { type: property.ui_type, originalType: property.ui_type }
|
||||
: getFieldType(property);
|
||||
|
||||
if (!fieldType) {
|
||||
if (!fieldTypeResult) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Missing input field type'
|
||||
@ -118,6 +119,9 @@ export const parseSchema = (
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
// stash this for custom types
|
||||
const { type: fieldType, originalType } = fieldTypeResult;
|
||||
|
||||
if (fieldType === 'WorkflowField') {
|
||||
withWorkflow = true;
|
||||
return inputsAccumulator;
|
||||
@ -136,6 +140,18 @@ export const parseSchema = (
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (!isFieldType(originalType)) {
|
||||
logger('nodes').debug(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Fallback handling for unknown input field type: ${fieldType}`
|
||||
);
|
||||
}
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
@ -144,7 +160,7 @@ export const parseSchema = (
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Skipping unknown input field type: ${fieldType}`
|
||||
`Unable to parse field type: ${fieldType}`
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
@ -153,7 +169,8 @@ export const parseSchema = (
|
||||
schema,
|
||||
property,
|
||||
propertyName,
|
||||
fieldType
|
||||
fieldType,
|
||||
originalType
|
||||
);
|
||||
|
||||
if (!field) {
|
||||
@ -162,6 +179,7 @@ export const parseSchema = (
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
originalType,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Skipping input field with no template'
|
||||
@ -220,12 +238,46 @@ export const parseSchema = (
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldType = property.ui_type ?? getFieldType(property);
|
||||
const fieldTypeResult = property.ui_type
|
||||
? { type: property.ui_type, originalType: property.ui_type }
|
||||
: getFieldType(property);
|
||||
|
||||
if (!fieldTypeResult) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Missing output field type'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
const { type: fieldType, originalType } = fieldTypeResult;
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
logger('nodes').debug(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
originalType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Fallback handling for unknown output field type: ${fieldType}`
|
||||
);
|
||||
}
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
logger('nodes').warn(
|
||||
{ fieldName: propertyName, fieldType, field: parseify(property) },
|
||||
'Skipping unknown output field type'
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Unable to parse field type: ${fieldType}`
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
@ -240,6 +292,7 @@ export const parseSchema = (
|
||||
ui_hidden: property.ui_hidden ?? false,
|
||||
ui_type: property.ui_type,
|
||||
ui_order: property.ui_order,
|
||||
originalType,
|
||||
};
|
||||
|
||||
return outputsAccumulator;
|
||||
|
Reference in New Issue
Block a user