feat(ui): add originalType to FieldType, improved connection validation

We now keep track of the original field type, derived from the python type annotation in addition to the override type provided by `ui_type`.

This makes `ui_type` work more like it sound like it should work - change the UI input component only.

Connection validation is extend to also check the original types. If there is any match between two fields' "final" or original types, we consider the connection valid.This change is backwards-compatible; there is no workflow migration needed.
This commit is contained in:
psychedelicious 2024-05-17 20:08:32 +10:00
parent af3fd26d4e
commit 85a5a7c47a
11 changed files with 502 additions and 346 deletions

View File

@ -4,9 +4,8 @@ import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import { areTypesEqual, validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow';
@ -70,7 +69,7 @@ export const useIsValidConnection = () => {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
return isEqual(sourceFieldTemplate.type, collectItemType);
return areTypesEqual(sourceFieldTemplate.type, collectItemType);
}
}

View File

@ -1,12 +1,12 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, isEqual, map } from 'lodash-es';
import { differenceWith, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getFirstValidConnection = (
templates: Templates,
@ -83,7 +83,7 @@ export const getFirstValidConnection = (
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {

View File

@ -4,12 +4,11 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import i18n from 'i18next';
import { isEqual } from 'lodash-es';
import type { HandleType } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getCollectItemType = (
templates: Templates,
@ -111,7 +110,7 @@ export const makeConnectionErrorSelector = (
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
if (!isEqual(sourceType, collectItemType)) {
if (!areTypesEqual(sourceType, collectItemType)) {
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
}
}

View File

@ -1,5 +1,25 @@
import type { FieldType } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field';
import { isEqual, omit } from 'lodash-es';
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType;
const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType;
const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType;
const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType;
if (isEqual(_sourceType, _targetType)) {
return true;
}
if (isEqual(_sourceType, _targetTypeOriginal)) {
return true;
}
if (isEqual(_sourceTypeOriginal, _targetType)) {
return true;
}
if (isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
return true;
}
return false;
};
/**
* Validates that the source and target types are compatible for a connection.
@ -15,7 +35,7 @@ export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType:
return false;
}
if (isEqual(sourceType, targetType)) {
if (areTypesEqual(sourceType, targetType)) {
return true;
}

View File

@ -66,16 +66,114 @@ export const zFieldIdentifier = z.object({
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
// #endregion
// #region IntegerField
// #region Field Types
const zStatelessFieldType = zFieldTypeBase.extend({
name: z.string().min(1), // stateless --> we accept the field's name as the type
});
const zIntegerFieldType = zFieldTypeBase.extend({
name: z.literal('IntegerField'),
originalType: zStatelessFieldType.optional(),
});
const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
originalType: zStatelessFieldType.optional(),
});
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
});
const zEnumFieldType = zFieldTypeBase.extend({
name: z.literal('EnumField'),
originalType: zStatelessFieldType.optional(),
});
const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
originalType: zStatelessFieldType.optional(),
});
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
});
const zColorFieldType = zFieldTypeBase.extend({
name: z.literal('ColorField'),
originalType: zStatelessFieldType.optional(),
});
const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
originalType: zStatelessFieldType.optional(),
});
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
});
const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType =>
statefulFieldTypeNames.includes(fieldType.name as any);
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
export type FieldType = z.infer<typeof zFieldType>;
// #endregion
// #region IntegerField
export const zIntegerFieldValue = z.number().int();
const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldValue,
});
const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zIntegerFieldType,
originalType: zFieldType.optional(),
default: zIntegerFieldValue,
multipleOf: z.number().int().optional(),
maximum: z.number().int().optional(),
@ -85,6 +183,7 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerFieldType,
originalType: zFieldType.optional(),
});
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
@ -96,15 +195,14 @@ export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldIn
// #endregion
// #region FloatField
const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
});
export const zFloatFieldValue = z.number();
const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldValue,
});
const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFloatFieldType,
originalType: zFieldType.optional(),
default: zFloatFieldValue,
multipleOf: z.number().optional(),
maximum: z.number().optional(),
@ -114,6 +212,7 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatFieldType,
originalType: zFieldType.optional(),
});
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
@ -125,21 +224,21 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT
// #endregion
// #region StringField
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
});
export const zStringFieldValue = z.string();
const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldValue,
});
const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStringFieldType,
originalType: zFieldType.optional(),
default: zStringFieldValue,
maxLength: z.number().int().optional(),
minLength: z.number().int().optional(),
});
const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
originalType: zFieldType.optional(),
});
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
@ -152,19 +251,19 @@ export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInpu
// #endregion
// #region BooleanField
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
});
export const zBooleanFieldValue = z.boolean();
const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBooleanFieldValue,
});
const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBooleanFieldType,
originalType: zFieldType.optional(),
default: zBooleanFieldValue,
});
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBooleanFieldType,
originalType: zFieldType.optional(),
});
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
@ -176,21 +275,21 @@ export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldIn
// #endregion
// #region EnumField
const zEnumFieldType = zFieldTypeBase.extend({
name: z.literal('EnumField'),
});
export const zEnumFieldValue = z.string();
const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
value: zEnumFieldValue,
});
const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zEnumFieldType,
originalType: zFieldType.optional(),
default: zEnumFieldValue,
options: z.array(z.string()),
labels: z.record(z.string()).optional(),
});
const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zEnumFieldType,
originalType: zFieldType.optional(),
});
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
@ -202,19 +301,19 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem
// #endregion
// #region ImageField
const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
});
export const zImageFieldValue = zImageField.optional();
const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImageFieldValue,
});
const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImageFieldType,
originalType: zFieldType.optional(),
default: zImageFieldValue,
});
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zImageFieldType,
originalType: zFieldType.optional(),
});
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
@ -226,19 +325,19 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT
// #endregion
// #region BoardField
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
});
export const zBoardFieldValue = zBoardField.optional();
const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBoardFieldValue,
});
const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBoardFieldType,
originalType: zFieldType.optional(),
default: zBoardFieldValue,
});
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBoardFieldType,
originalType: zFieldType.optional(),
});
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
@ -250,19 +349,19 @@ export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputT
// #endregion
// #region ColorField
const zColorFieldType = zFieldTypeBase.extend({
name: z.literal('ColorField'),
});
export const zColorFieldValue = zColorField.optional();
const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
value: zColorFieldValue,
});
const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zColorFieldType,
originalType: zFieldType.optional(),
default: zColorFieldValue,
});
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zColorFieldType,
originalType: zFieldType.optional(),
});
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
@ -274,19 +373,19 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
// #endregion
// #region MainModelField
const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
});
export const zMainModelFieldValue = zModelIdentifierField.optional();
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zMainModelFieldValue,
});
const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zMainModelFieldType,
originalType: zFieldType.optional(),
default: zMainModelFieldValue,
});
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zMainModelFieldType,
originalType: zFieldType.optional(),
});
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
@ -298,19 +397,19 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
// #endregion
// #region SDXLMainModelField
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
});
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSDXLMainModelFieldValue,
});
const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSDXLMainModelFieldType,
originalType: zFieldType.optional(),
default: zSDXLMainModelFieldValue,
});
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLMainModelFieldType,
originalType: zFieldType.optional(),
});
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
@ -321,9 +420,7 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
// #endregion
// #region SDXLRefinerModelField
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
});
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
@ -331,10 +428,12 @@ const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
});
const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSDXLRefinerModelFieldType,
originalType: zFieldType.optional(),
default: zSDXLRefinerModelFieldValue,
});
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLRefinerModelFieldType,
originalType: zFieldType.optional(),
});
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
@ -346,19 +445,19 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
// #endregion
// #region VAEModelField
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
});
export const zVAEModelFieldValue = zModelIdentifierField.optional();
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zVAEModelFieldValue,
});
const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zVAEModelFieldType,
originalType: zFieldType.optional(),
default: zVAEModelFieldValue,
});
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zVAEModelFieldType,
originalType: zFieldType.optional(),
});
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
@ -370,19 +469,19 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
// #endregion
// #region LoRAModelField
const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
});
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLoRAModelFieldValue,
});
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zLoRAModelFieldValue,
});
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zLoRAModelFieldType,
originalType: zFieldType.optional(),
});
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
@ -394,19 +493,19 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
// #endregion
// #region ControlNetModelField
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
});
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlNetModelFieldValue,
});
const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zControlNetModelFieldType,
originalType: zFieldType.optional(),
default: zControlNetModelFieldValue,
});
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zControlNetModelFieldType,
originalType: zFieldType.optional(),
});
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
@ -418,19 +517,19 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
// #endregion
// #region IPAdapterModelField
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'),
});
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIPAdapterModelFieldValue,
});
const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zIPAdapterModelFieldType,
originalType: zFieldType.optional(),
default: zIPAdapterModelFieldValue,
});
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIPAdapterModelFieldType,
originalType: zFieldType.optional(),
});
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
@ -442,19 +541,19 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
// #endregion
// #region T2IAdapterField
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
});
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT2IAdapterModelFieldValue,
});
const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT2IAdapterModelFieldType,
originalType: zFieldType.optional(),
default: zT2IAdapterModelFieldValue,
});
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zT2IAdapterModelFieldType,
originalType: zFieldType.optional(),
});
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
@ -466,19 +565,19 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
// #endregion
// #region SchedulerField
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
});
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSchedulerFieldValue,
});
const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSchedulerFieldType,
originalType: zFieldType.optional(),
default: zSchedulerFieldValue,
});
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSchedulerFieldType,
originalType: zFieldType.optional(),
});
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
@ -501,20 +600,20 @@ export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFie
* - Reserved fields like IsIntermediate
* - Any other field we don't have full-on schemas for
*/
const zStatelessFieldType = zFieldTypeBase.extend({
name: z.string().min(1), // stateless --> we accept the field's name as the type
});
const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStatelessFieldValue,
});
const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStatelessFieldType,
originalType: zFieldType.optional(),
default: zStatelessFieldValue,
input: z.literal('connection'), // stateless --> only accepts connection inputs
});
const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStatelessFieldType,
originalType: zFieldType.optional(),
});
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
@ -535,34 +634,6 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
* for all other StatelessFields.
*/
// #region StatefulFieldType & FieldType
const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
export const isStatefulFieldType = (val: unknown): val is StatefulFieldType =>
zStatefulFieldType.safeParse(val).success;
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
export type FieldType = z.infer<typeof zFieldType>;
// #endregion
// #region StatefulFieldValue & FieldValue
export const zStatefulFieldValue = z.union([
zIntegerFieldValue,

View File

@ -30,26 +30,16 @@ import { isNumber, startCase } from 'lodash-es';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
(arg: {
schemaObject: InvocationFieldSchema;
baseField: Omit<T, 'type'>;
isCollection: boolean;
isCollectionOrScalar: boolean;
}) => T;
(arg: { schemaObject: InvocationFieldSchema; baseField: Omit<T, 'type'>; fieldType: T['type'] }) => T;
const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: IntegerFieldInputTemplate = {
...baseField,
type: {
name: 'IntegerField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? 0,
};
@ -79,16 +69,11 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInpu
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: FloatFieldInputTemplate = {
...baseField,
type: {
name: 'FloatField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? 0,
};
@ -118,16 +103,11 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: StringFieldInputTemplate = {
...baseField,
type: {
name: 'StringField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? '',
};
@ -145,16 +125,11 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: BooleanFieldInputTemplate = {
...baseField,
type: {
name: 'BooleanField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? false,
};
@ -164,16 +139,11 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInpu
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: MainModelFieldInputTemplate = {
...baseField,
type: {
name: 'MainModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -183,16 +153,11 @@ const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelField
const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: SDXLMainModelFieldInputTemplate = {
...baseField,
type: {
name: 'SDXLMainModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -202,16 +167,11 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: SDXLRefinerModelFieldInputTemplate = {
...baseField,
type: {
name: 'SDXLRefinerModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -221,16 +181,11 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: VAEModelFieldInputTemplate = {
...baseField,
type: {
name: 'VAEModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -240,16 +195,11 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldIn
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: LoRAModelFieldInputTemplate = {
...baseField,
type: {
name: 'LoRAModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -259,16 +209,11 @@ const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelField
const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<ControlNetModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: ControlNetModelFieldInputTemplate = {
...baseField,
type: {
name: 'ControlNetModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -278,16 +223,11 @@ const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<ControlN
const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<IPAdapterModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: IPAdapterModelFieldInputTemplate = {
...baseField,
type: {
name: 'IPAdapterModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -297,16 +237,11 @@ const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<IPAdapter
const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapterModelFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: T2IAdapterModelFieldInputTemplate = {
...baseField,
type: {
name: 'T2IAdapterModelField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -316,16 +251,11 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: BoardFieldInputTemplate = {
...baseField,
type: {
name: 'BoardField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -335,16 +265,11 @@ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTem
const buildImageFieldInputTemplate: FieldInputTemplateBuilder<ImageFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: ImageFieldInputTemplate = {
...baseField,
type: {
name: 'ImageField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? undefined,
};
@ -354,8 +279,7 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder<ImageFieldInputTem
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
let options: EnumFieldInputTemplate['options'] = [];
if (schemaObject.anyOf) {
@ -383,11 +307,7 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTempl
}
const template: EnumFieldInputTemplate = {
...baseField,
type: {
name: 'EnumField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
options,
ui_choice_labels: schemaObject.ui_choice_labels,
default: schemaObject.default ?? options[0],
@ -399,16 +319,11 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTempl
const buildColorFieldInputTemplate: FieldInputTemplateBuilder<ColorFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: ColorFieldInputTemplate = {
...baseField,
type: {
name: 'ColorField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
@ -418,16 +333,11 @@ const buildColorFieldInputTemplate: FieldInputTemplateBuilder<ColorFieldInputTem
const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerFieldInputTemplate> = ({
schemaObject,
baseField,
isCollection,
isCollectionOrScalar,
fieldType,
}) => {
const template: SchedulerFieldInputTemplate = {
...baseField,
type: {
name: 'SchedulerField',
isCollection,
isCollectionOrScalar,
},
type: fieldType,
default: schemaObject.default ?? 'euler',
};
@ -452,7 +362,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,
};
} as const;
export const buildFieldInputTemplate = (
fieldSchema: InvocationFieldSchema,
@ -479,20 +389,22 @@ export const buildFieldInputTemplate = (
if (isStatefulFieldType(fieldType)) {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
return builder({
const template = builder({
schemaObject: fieldSchema,
baseField,
isCollection: fieldType.isCollection,
isCollectionOrScalar: fieldType.isCollectionOrScalar,
fieldType,
});
}
// This is a StatelessField, create it directly.
const template: StatelessFieldInputTemplate = {
...baseField,
input: 'connection', // stateless --> connection only inputs
type: fieldType,
default: undefined, // stateless --> no default value
};
return template;
return template;
} else {
// This is a StatelessField, create it directly.
const template: StatelessFieldInputTemplate = {
...baseField,
input: 'connection', // stateless --> connection only inputs
type: fieldType,
default: undefined, // stateless --> no default value
};
return template;
}
};

View File

@ -9,7 +9,7 @@ export const buildFieldOutputTemplate = (
): FieldOutputTemplate => {
const { title, description, ui_hidden, ui_type, ui_order } = fieldSchema;
const fieldOutputTemplate: FieldOutputTemplate = {
const template: FieldOutputTemplate = {
fieldKind: 'output',
name: fieldName,
title: title ?? (fieldName ? startCase(fieldName) : ''),
@ -20,5 +20,5 @@ export const buildFieldOutputTemplate = (
ui_order,
};
return fieldOutputTemplate;
return template;
};

View File

@ -244,7 +244,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'SchedulerField',
},
expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
},
{
name: 'Explicit ui_type (AnyField)',
@ -253,7 +253,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'AnyField',
},
expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
},
{
name: 'Explicit ui_type (CollectionField)',
@ -262,7 +262,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'CollectionField',
},
expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
},
];

View File

@ -6,14 +6,8 @@ import {
UnsupportedUnionError,
} from 'features/nodes/types/error';
import type { FieldType } from 'features/nodes/types/field';
import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
import {
isArraySchemaObject,
isInvocationFieldSchema,
isNonArraySchemaObject,
isRefObject,
isSchemaObject,
} from 'features/nodes/types/openapi';
import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
import { isArraySchemaObject, isNonArraySchemaObject, isRefObject, isSchemaObject } from 'features/nodes/types/openapi';
import { t } from 'i18next';
import { isArray } from 'lodash-es';
import type { OpenAPIV3_1 } from 'openapi-types';
@ -35,7 +29,7 @@ const OPENAPI_TO_FIELD_TYPE_MAP: Record<string, string> = {
boolean: 'BooleanField',
};
const isCollectionFieldType = (fieldType: string) => {
export const isCollectionFieldType = (fieldType: string) => {
/**
* CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between
* it and other `list[Any]` fields, due to its special internal handling.
@ -48,18 +42,7 @@ const isCollectionFieldType = (fieldType: string) => {
return false;
};
export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => {
if (isInvocationFieldSchema(schemaObject)) {
// Check if this field has an explicit type provided by the node schema
const { ui_type } = schemaObject;
if (ui_type) {
return {
name: ui_type,
isCollection: isCollectionFieldType(ui_type),
isCollectionOrScalar: false,
};
}
}
export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => {
if (isSchemaObject(schemaObject)) {
if (schemaObject.const) {
// Fields with a single const value are defined as `Literal["value"]` in the pydantic schema - it's actually an enum

View File

@ -97,6 +97,11 @@ const expected = {
name: 'SchedulerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
},
},
default: 'euler',
},
@ -111,6 +116,11 @@ const expected = {
name: 'SchedulerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
},
},
ui_hidden: false,
ui_type: 'SchedulerField',
@ -141,6 +151,11 @@ const expected = {
name: 'MainModelField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'ModelIdentifierField',
isCollection: false,
isCollectionOrScalar: false,
},
},
},
},
@ -186,6 +201,48 @@ const expected = {
nodePack: 'invokeai',
classification: 'stable',
},
collect: {
title: 'Collect',
type: 'collect',
version: '1.0.0',
tags: [],
description: 'Collects values into a collection',
outputType: 'collect_output',
inputs: {
item: {
name: 'item',
title: 'Collection Item',
required: false,
description: 'The item to collect (all inputs must be of the same type)',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
ui_type: 'CollectionItemField',
type: {
name: 'CollectionItemField',
isCollection: false,
isCollectionOrScalar: false,
},
},
},
outputs: {
collection: {
fieldKind: 'output',
name: 'collection',
title: 'Collection',
description: 'The collection of input items',
type: {
name: 'CollectionField',
isCollection: true,
isCollectionOrScalar: false,
},
ui_hidden: false,
ui_type: 'CollectionField',
},
},
useCache: true,
classification: 'stable',
},
};
const schema = {
@ -785,6 +842,101 @@ const schema = {
type: 'object',
class: 'output',
},
CollectInvocation: {
properties: {
id: {
type: 'string',
title: 'Id',
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
field_kind: 'node_attribute',
},
is_intermediate: {
type: 'boolean',
title: 'Is Intermediate',
description: 'Whether or not this is an intermediate invocation.',
default: false,
field_kind: 'node_attribute',
ui_type: 'IsIntermediate',
},
use_cache: {
type: 'boolean',
title: 'Use Cache',
description: 'Whether or not to use the cache',
default: true,
field_kind: 'node_attribute',
},
item: {
anyOf: [
{},
{
type: 'null',
},
],
title: 'Collection Item',
description: 'The item to collect (all inputs must be of the same type)',
field_kind: 'input',
input: 'connection',
orig_required: false,
ui_hidden: false,
ui_type: 'CollectionItemField',
},
collection: {
items: {},
type: 'array',
title: 'Collection',
description: 'The collection, will be provided on execution',
default: [],
field_kind: 'input',
input: 'any',
orig_default: [],
orig_required: false,
ui_hidden: true,
},
type: {
type: 'string',
enum: ['collect'],
const: 'collect',
title: 'type',
default: 'collect',
field_kind: 'node_attribute',
},
},
type: 'object',
required: ['type', 'id'],
title: 'CollectInvocation',
description: 'Collects values into a collection',
classification: 'stable',
version: '1.0.0',
output: {
$ref: '#/components/schemas/CollectInvocationOutput',
},
class: 'invocation',
},
CollectInvocationOutput: {
properties: {
collection: {
description: 'The collection of input items',
field_kind: 'output',
items: {},
title: 'Collection',
type: 'array',
ui_hidden: false,
ui_type: 'CollectionField',
},
type: {
const: 'collect_output',
default: 'collect_output',
enum: ['collect_output'],
field_kind: 'node_attribute',
title: 'type',
type: 'string',
},
},
required: ['collection', 'type', 'type'],
title: 'CollectInvocationOutput',
type: 'object',
class: 'output',
},
},
},
} as OpenAPIV3_1.Document;

View File

@ -1,23 +1,29 @@
import { logger } from 'app/logging/logger';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import type { Templates } from 'features/nodes/store/types';
import { FieldParseError } from 'features/nodes/types/error';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import {
type FieldInputTemplate,
type FieldOutputTemplate,
type FieldType,
isStatefulFieldType,
} from 'features/nodes/types/field';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import type { InvocationSchemaObject } from 'features/nodes/types/openapi';
import type { InvocationFieldSchema, InvocationSchemaObject } from 'features/nodes/types/openapi';
import {
isInvocationFieldSchema,
isInvocationOutputSchemaObject,
isInvocationSchemaObject,
} from 'features/nodes/types/openapi';
import { t } from 'i18next';
import { reduce } from 'lodash-es';
import { isEqual, reduce } from 'lodash-es';
import type { OpenAPIV3_1 } from 'openapi-types';
import { serializeError } from 'serialize-error';
import { buildFieldInputTemplate } from './buildFieldInputTemplate';
import { buildFieldOutputTemplate } from './buildFieldOutputTemplate';
import { parseFieldType } from './parseFieldType';
import { isCollectionFieldType, parseFieldType } from './parseFieldType';
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
@ -94,51 +100,43 @@ export const parseSchema = (
return inputsAccumulator;
}
try {
const fieldType = parseFieldType(property);
const fieldTypeOverride = property.ui_type
? {
name: property.ui_type,
isCollection: isCollectionFieldType(property.ui_type),
isCollectionOrScalar: false,
}
: null;
if (isReservedFieldType(fieldType.name)) {
logger('nodes').trace(
{ node: type, field: propertyName, schema: parseify(property) },
'Skipped reserved input field'
);
return inputsAccumulator;
}
const originalFieldType = getFieldType(property, propertyName, type, 'input');
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
inputsAccumulator[propertyName] = fieldInputTemplate;
} catch (e) {
if (e instanceof FieldParseError) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
t('nodes.inputFieldTypeParseError', {
node: type,
field: propertyName,
message: e.message,
})
);
} else {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
error: serializeError(e),
},
t('nodes.inputFieldTypeParseError', {
node: type,
field: propertyName,
message: 'unknown error',
})
);
}
const fieldType = fieldTypeOverride ?? originalFieldType;
if (!fieldType) {
logger('nodes').trace(
{ node: type, field: propertyName, schema: parseify(property) },
'Unable to parse field type'
);
return inputsAccumulator;
}
if (isReservedFieldType(fieldType.name)) {
logger('nodes').trace(
{ node: type, field: propertyName, schema: parseify(property) },
'Skipped reserved input field'
);
return inputsAccumulator;
}
if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) {
console.log('STATEFUL WITH ORIGINAL');
fieldType.originalType = deepClone(originalFieldType);
console.log(fieldType);
}
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
console.log(fieldInputTemplate);
inputsAccumulator[propertyName] = fieldInputTemplate;
return inputsAccumulator;
},
{}
@ -183,54 +181,34 @@ export const parseSchema = (
return outputsAccumulator;
}
try {
const fieldType = parseFieldType(property);
const fieldTypeOverride = property.ui_type
? {
name: property.ui_type,
isCollection: isCollectionFieldType(property.ui_type),
isCollectionOrScalar: false,
}
: null;
if (!fieldType) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
'Missing output field type'
);
return outputsAccumulator;
}
const originalFieldType = getFieldType(property, propertyName, type, 'output');
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
outputsAccumulator[propertyName] = fieldOutputTemplate;
} catch (e) {
if (e instanceof FieldParseError) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
t('nodes.outputFieldTypeParseError', {
node: type,
field: propertyName,
message: e.message,
})
);
} else {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
error: serializeError(e),
},
t('nodes.outputFieldTypeParseError', {
node: type,
field: propertyName,
message: 'unknown error',
})
);
}
const fieldType = fieldTypeOverride ?? originalFieldType;
if (!fieldType) {
logger('nodes').trace(
{ node: type, field: propertyName, schema: parseify(property) },
'Unable to parse field type'
);
return outputsAccumulator;
}
if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) {
console.log('STATEFUL WITH ORIGINAL');
fieldType.originalType = deepClone(originalFieldType);
console.log(fieldType);
}
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
outputsAccumulator[propertyName] = fieldOutputTemplate;
return outputsAccumulator;
},
{} as Record<string, FieldOutputTemplate>
@ -259,3 +237,45 @@ export const parseSchema = (
return invocations;
};
const getFieldType = (
property: InvocationFieldSchema,
propertyName: string,
type: string,
kind: 'input' | 'output'
): FieldType | null => {
try {
return parseFieldType(property);
} catch (e) {
const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError';
if (e instanceof FieldParseError) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
t(tKey, {
node: type,
field: propertyName,
message: e.message,
})
);
} else {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
error: serializeError(e),
},
t(tKey, {
node: type,
field: propertyName,
message: 'unknown error',
})
);
}
return null;
}
};