fix(ui): use new field type cardinality throughout app

Update business logic and tests.
This commit is contained in:
psychedelicious 2024-05-19 22:57:12 +10:00
parent dba8c43ecb
commit 8062a47d16
12 changed files with 239 additions and 290 deletions

@ -4,7 +4,7 @@ import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdge
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -29,11 +29,11 @@ const FieldHandle = (props: FieldHandleProps) => {
const isModelType = MODEL_TYPES.some((t) => t === type.name);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color,
backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0,
borderWidth: !isSingle(type) ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',

@ -1,5 +1,6 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
@ -11,7 +12,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
const fieldNames = useMemo(() => {
const fields = map(template.inputs).filter((field) => {
return (
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});

@ -1,5 +1,6 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
@ -11,7 +12,7 @@ export const useConnectionInputFieldNames = (nodeId: string): string[] => {
// get the visible fields
const fields = map(template.inputs).filter(
(field) =>
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
(field.input === 'connection' && !isSingleOrCollection(field.type)) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);

@ -1,4 +1,4 @@
import type { FieldType } from 'features/nodes/types/field';
import { type FieldType, isCollection, isSingleOrCollection } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -10,10 +10,10 @@ export const useFieldTypeName = (fieldType?: FieldType): string => {
return '';
}
const { name } = fieldType;
if (fieldType.isCollection) {
if (isCollection(fieldType)) {
return t('nodes.collectionFieldType', { name });
}
if (fieldType.isCollectionOrScalar) {
if (isSingleOrCollection(fieldType)) {
return t('nodes.collectionOrScalarFieldType', { name });
}
return name;

@ -1,99 +1,84 @@
import type { FieldType } from 'features/nodes/types/field';
import { describe, expect, it } from 'vitest';
import { areTypesEqual } from './areTypesEqual';
describe(areTypesEqual.name, () => {
it('should handle equal source and target type', () => {
const sourceType = {
const sourceType: FieldType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
const targetType = {
const targetType: FieldType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal source type and original target type', () => {
const sourceType = {
const sourceType: FieldType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
const targetType = {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
const targetType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and target type', () => {
const sourceType = {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
const targetType = {
const targetType: FieldType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and original target type', () => {
const sourceType = {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
const targetType = {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
const targetType: FieldType = {
name: 'LoRAModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);

@ -11,7 +11,7 @@ describe(getCollectItemType.name, () => {
const n2 = buildNode(collect);
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const result = getCollectItemType(templates, [n1, n2], [e1], n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false });
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE' });
});
it('should return null if the collect node does not have any connections', () => {
const n1 = buildNode(collect);

@ -33,8 +33,7 @@ export const add: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
default: 0,
},
@ -48,8 +47,7 @@ export const add: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
default: 0,
},
@ -62,8 +60,7 @@ export const add: InvocationTemplate = {
description: 'The output integer',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -91,8 +88,7 @@ export const sub: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
default: 0,
},
@ -106,8 +102,7 @@ export const sub: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
default: 0,
},
@ -120,8 +115,7 @@ export const sub: InvocationTemplate = {
description: 'The output integer',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -150,8 +144,7 @@ export const collect: InvocationTemplate = {
ui_type: 'CollectionItemField',
type: {
name: 'CollectionItemField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
},
@ -163,8 +156,7 @@ export const collect: InvocationTemplate = {
description: 'The collection of input items',
type: {
name: 'CollectionField',
isCollection: true,
isCollectionOrScalar: false,
cardinality: 'COLLECTION',
},
ui_hidden: false,
ui_type: 'CollectionField',
@ -193,12 +185,11 @@ const scheduler: InvocationTemplate = {
ui_type: 'SchedulerField',
type: {
name: 'SchedulerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
default: 'euler',
@ -212,12 +203,11 @@ const scheduler: InvocationTemplate = {
description: 'Scheduler to use during inference',
type: {
name: 'SchedulerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
ui_hidden: false,
@ -248,12 +238,11 @@ export const main_model_loader: InvocationTemplate = {
ui_type: 'MainModelField',
type: {
name: 'MainModelField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
originalType: {
name: 'ModelIdentifierField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
},
@ -266,8 +255,7 @@ export const main_model_loader: InvocationTemplate = {
description: 'VAE',
type: {
name: 'VAEField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -278,8 +266,7 @@ export const main_model_loader: InvocationTemplate = {
description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
type: {
name: 'CLIPField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -290,8 +277,7 @@ export const main_model_loader: InvocationTemplate = {
description: 'UNet (scheduler, LoRAs)',
type: {
name: 'UNetField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -319,8 +305,7 @@ export const img_resize: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'BoardField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
metadata: {
@ -333,8 +318,7 @@ export const img_resize: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'MetadataField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
image: {
@ -347,8 +331,7 @@ export const img_resize: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'ImageField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
},
width: {
@ -361,8 +344,7 @@ export const img_resize: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
default: 512,
exclusiveMinimum: 0,
@ -377,8 +359,7 @@ export const img_resize: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
default: 512,
exclusiveMinimum: 0,
@ -393,8 +374,7 @@ export const img_resize: InvocationTemplate = {
ui_hidden: false,
type: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
default: 'bicubic',
@ -408,8 +388,7 @@ export const img_resize: InvocationTemplate = {
description: 'The output image',
type: {
name: 'ImageField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -420,8 +399,7 @@ export const img_resize: InvocationTemplate = {
description: 'The width of the image in pixels',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -432,8 +410,7 @@ export const img_resize: InvocationTemplate = {
description: 'The height of the image in pixels',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -462,8 +439,7 @@ const iterate: InvocationTemplate = {
ui_type: 'CollectionField',
type: {
name: 'CollectionField',
isCollection: true,
isCollectionOrScalar: false,
cardinality: 'COLLECTION',
},
},
},
@ -475,8 +451,7 @@ const iterate: InvocationTemplate = {
description: 'The item being iterated over',
type: {
name: 'CollectionItemField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
ui_type: 'CollectionItemField',
@ -488,8 +463,7 @@ const iterate: InvocationTemplate = {
description: 'The index of the item',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},
@ -500,8 +474,7 @@ const iterate: InvocationTemplate = {
description: 'The total number of items',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
},
ui_hidden: false,
},

@ -4,148 +4,148 @@ import { validateConnectionTypes } from './validateConnectionTypes';
describe(validateConnectionTypes.name, () => {
describe('generic cases', () => {
it('should accept Scalar to Scalar of same type', () => {
it('should accept SINGLE to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'FooField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should accept Collection to Collection of same type', () => {
it('should accept COLLECTION to COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept Scalar to CollectionOrScalar of same type', () => {
it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept Collection to CollectionOrScalar of same type', () => {
it('should accept COLLECTION to SINGLE_OR_COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should reject Collection to Scalar of same type', () => {
it('should reject COLLECTION to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
it('should reject CollectionOrScalar to Scalar of same type', () => {
it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
it('should reject mismatched types', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'BarField', isCollection: false, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'BarField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
});
describe('special cases', () => {
it('should reject a collection input to a collection input', () => {
it('should reject a COLLECTION input to a COLLECTION input', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }
{ name: 'CollectionField', cardinality: 'COLLECTION' },
{ name: 'CollectionField', cardinality: 'COLLECTION' }
);
expect(r).toBe(false);
});
it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
describe('CollectionItemField', () => {
it('should accept CollectionItemField to any Scalar target', () => {
it('should accept CollectionItemField to any SINGLE target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should accept CollectionItemField to any CollectionOrScalar target', () => {
it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any non-Collection to CollectionItemField', () => {
it('should accept any SINGLE to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should reject any Collection to CollectionItemField', () => {
it('should reject any COLLECTION to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
{ name: 'IntegerField', cardinality: 'COLLECTION' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
it('should reject any CollectionOrScalar to CollectionItemField', () => {
it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
});
describe('CollectionOrScalar', () => {
it('should accept any Scalar of same type to CollectionOrScalar', () => {
describe('SINGLE_OR_COLLECTION', () => {
it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any Collection of same type to CollectionOrScalar', () => {
it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
{ name: 'IntegerField', cardinality: 'COLLECTION' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => {
it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
});
describe('CollectionField', () => {
it('should accept any CollectionField to any Collection type', () => {
it('should accept any CollectionField to any COLLECTION type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }
{ name: 'CollectionField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any CollectionField to any CollectionOrScalar type', () => {
it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
{ name: 'CollectionField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
@ -158,62 +158,62 @@ describe(validateConnectionTypes.name, () => {
{ t1: 'IntegerField', t2: 'StringField' },
{ t1: 'FloatField', t2: 'StringField' },
];
it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => {
it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' });
expect(r).toBe(true);
});
it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: false }
{ name: t1, cardinality: 'SINGLE' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
{ name: t1, cardinality: 'COLLECTION' },
{ name: t2, cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => {
it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: true, isCollectionOrScalar: false },
{ name: t2, isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: true, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: true },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
{ name: t1, cardinality: 'COLLECTION' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it.each(typePairs)(
'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2',
({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
}
);
});
describe('AnyField', () => {
it('should accept any Scalar type to AnyField', () => {
it('should accept any SINGLE type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should accept any Collection type to AnyField', () => {
it('should accept any COLLECTION type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: true, isCollectionOrScalar: false }
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any CollectionOrScalar type to AnyField', () => {
it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: true }
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});

@ -1,5 +1,5 @@
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import type { FieldType } from 'features/nodes/types/field';
import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'features/nodes/types/field';
/**
* Validates that the source and target types are compatible for a connection.
@ -27,38 +27,37 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
* - Generic Collection can connect to any other Collection or CollectionOrScalar
* - Any Collection can connect to a Generic Collection
*/
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !isCollection(targetType);
const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
const isNonCollectionToCollectionItem = isSingle(sourceType) && targetType.name === 'CollectionItemField';
const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
const isAnythingToSingleOrCollectionOfSameBaseType =
isSingleOrCollection(targetType) && sourceType.name === targetType.name;
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
const isGenericCollectionToAnyCollectionOrSingleOrCollection =
sourceType.name === 'CollectionField' && !isSingle(targetType);
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && isCollection(sourceType);
const isSourceScalar = !sourceType.isCollection && !sourceType.isCollectionOrScalar;
const isTargetScalar = !targetType.isCollection && !targetType.isCollectionOrScalar;
const isScalarToScalar = isSourceScalar && isTargetScalar;
const isScalarToCollectionOrScalar = isSourceScalar && targetType.isCollectionOrScalar;
const isCollectionToCollection = sourceType.isCollection && targetType.isCollection;
const isCollectionToCollectionOrScalar = sourceType.isCollection && targetType.isCollectionOrScalar;
const isCollectionOrScalarToCollectionOrScalar = sourceType.isCollectionOrScalar && targetType.isCollectionOrScalar;
const isPluralityMatch =
isScalarToScalar ||
const isSourceSingle = isSingle(sourceType);
const isTargetSingle = isSingle(targetType);
const isSingleToSingle = isSourceSingle && isTargetSingle;
const isSingleToSingleOrCollection = isSourceSingle && isSingleOrCollection(targetType);
const isCollectionToCollection = isCollection(sourceType) && isCollection(targetType);
const isCollectionToSingleOrCollection = isCollection(sourceType) && isSingleOrCollection(targetType);
const isSingleOrCollectionToSingleOrCollection = isSingleOrCollection(sourceType) && isSingleOrCollection(targetType);
const doesCardinalityMatch =
isSingleToSingle ||
isCollectionToCollection ||
isCollectionToCollectionOrScalar ||
isCollectionOrScalarToCollectionOrScalar ||
isScalarToCollectionOrScalar;
isCollectionToSingleOrCollection ||
isSingleOrCollectionToSingleOrCollection ||
isSingleToSingleOrCollection;
const isIntToFloat = sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
const isIntToString = sourceType.name === 'IntegerField' && targetType.name === 'StringField';
const isFloatToString = sourceType.name === 'FloatField' && targetType.name === 'StringField';
const isSubTypeMatch = isPluralityMatch && (isIntToFloat || isIntToString || isFloatToString);
const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString);
const isTargetAnyType = targetType.name === 'AnyField';
@ -66,8 +65,8 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isAnythingToSingleOrCollectionOfSameBaseType ||
isGenericCollectionToAnyCollectionOrSingleOrCollection ||
isCollectionToGenericCollection ||
isSubTypeMatch ||
isTargetAnyType

@ -4,6 +4,7 @@ import {
UnsupportedPrimitiveTypeError,
UnsupportedUnionError,
} from 'features/nodes/types/error';
import type { FieldType } from 'features/nodes/types/field';
import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType';
import { describe, expect, it } from 'vitest';
@ -11,52 +12,52 @@ import { describe, expect, it } from 'vitest';
type ParseFieldTypeTestCase = {
name: string;
schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema;
expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean };
expected: FieldType;
};
const primitiveTypes: ParseFieldTypeTestCase[] = [
{
name: 'Scalar IntegerField',
name: 'SINGLE IntegerField',
schema: { type: 'integer' },
expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'IntegerField', cardinality: 'SINGLE' },
},
{
name: 'Scalar FloatField',
name: 'SINGLE FloatField',
schema: { type: 'number' },
expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'FloatField', cardinality: 'SINGLE' },
},
{
name: 'Scalar StringField',
name: 'SINGLE StringField',
schema: { type: 'string' },
expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'StringField', cardinality: 'SINGLE' },
},
{
name: 'Scalar BooleanField',
name: 'SINGLE BooleanField',
schema: { type: 'boolean' },
expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'BooleanField', cardinality: 'SINGLE' },
},
{
name: 'Collection IntegerField',
name: 'COLLECTION IntegerField',
schema: { items: { type: 'integer' }, type: 'array' },
expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'IntegerField', cardinality: 'COLLECTION' },
},
{
name: 'Collection FloatField',
name: 'COLLECTION FloatField',
schema: { items: { type: 'number' }, type: 'array' },
expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'FloatField', cardinality: 'COLLECTION' },
},
{
name: 'Collection StringField',
name: 'COLLECTION StringField',
schema: { items: { type: 'string' }, type: 'array' },
expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'StringField', cardinality: 'COLLECTION' },
},
{
name: 'Collection BooleanField',
name: 'COLLECTION BooleanField',
schema: { items: { type: 'boolean' }, type: 'array' },
expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'BooleanField', cardinality: 'COLLECTION' },
},
{
name: 'CollectionOrScalar IntegerField',
name: 'SINGLE_OR_COLLECTION IntegerField',
schema: {
anyOf: [
{
@ -70,10 +71,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
},
{
name: 'CollectionOrScalar FloatField',
name: 'SINGLE_OR_COLLECTION FloatField',
schema: {
anyOf: [
{
@ -87,10 +88,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true },
expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION' },
},
{
name: 'CollectionOrScalar StringField',
name: 'SINGLE_OR_COLLECTION StringField',
schema: {
anyOf: [
{
@ -104,10 +105,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true },
expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION' },
},
{
name: 'CollectionOrScalar BooleanField',
name: 'SINGLE_OR_COLLECTION BooleanField',
schema: {
anyOf: [
{
@ -121,13 +122,13 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true },
expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION' },
},
];
const complexTypes: ParseFieldTypeTestCase[] = [
{
name: 'Scalar ConditioningField',
name: 'SINGLE ConditioningField',
schema: {
allOf: [
{
@ -135,10 +136,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'ConditioningField', cardinality: 'SINGLE' },
},
{
name: 'Nullable Scalar ConditioningField',
name: 'Nullable SINGLE ConditioningField',
schema: {
anyOf: [
{
@ -149,10 +150,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'ConditioningField', cardinality: 'SINGLE' },
},
{
name: 'Collection ConditioningField',
name: 'COLLECTION ConditioningField',
schema: {
anyOf: [
{
@ -163,7 +164,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'ConditioningField', cardinality: 'COLLECTION' },
},
{
name: 'Nullable Collection ConditioningField',
@ -180,10 +181,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false },
expected: { name: 'ConditioningField', cardinality: 'COLLECTION' },
},
{
name: 'CollectionOrScalar ConditioningField',
name: 'SINGLE_OR_COLLECTION ConditioningField',
schema: {
anyOf: [
{
@ -197,10 +198,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true },
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' },
},
{
name: 'Nullable CollectionOrScalar ConditioningField',
name: 'Nullable SINGLE_OR_COLLECTION ConditioningField',
schema: {
anyOf: [
{
@ -217,7 +218,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true },
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' },
},
];
@ -228,14 +229,14 @@ const specialCases: ParseFieldTypeTestCase[] = [
type: 'string',
enum: ['large', 'base', 'small'],
},
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', cardinality: 'SINGLE' },
},
{
name: 'String EnumField with one value',
schema: {
const: 'Some Value',
},
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', cardinality: 'SINGLE' },
},
{
name: 'Explicit ui_type (SchedulerField)',
@ -244,7 +245,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'SchedulerField',
},
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', cardinality: 'SINGLE' },
},
{
name: 'Explicit ui_type (AnyField)',
@ -253,7 +254,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'AnyField',
},
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', cardinality: 'SINGLE' },
},
{
name: 'Explicit ui_type (CollectionField)',
@ -262,7 +263,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'CollectionField',
},
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
expected: { name: 'EnumField', cardinality: 'SINGLE' },
},
];

@ -48,8 +48,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
// Fields with a single const value are defined as `Literal["value"]` in the pydantic schema - it's actually an enum
return {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
};
}
if (!schemaObject.type) {
@ -65,8 +64,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
}
return {
name,
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
};
}
} else if (schemaObject.anyOf) {
@ -89,8 +87,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
};
} else if (isSchemaObject(filteredAnyOf[0])) {
return parseFieldType(filteredAnyOf[0]);
@ -143,8 +140,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
if (firstType && firstType === secondType) {
return {
name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType,
isCollection: false,
isCollectionOrScalar: true, // <-- don't forget, CollectionOrScalar type!
cardinality: 'SINGLE_OR_COLLECTION',
};
}
@ -158,8 +154,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
} else if (schemaObject.enum) {
return {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
};
} else if (schemaObject.type) {
if (schemaObject.type === 'array') {
@ -185,8 +180,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
}
return {
name,
isCollection: true, // <-- don't forget, collection!
isCollectionOrScalar: false,
cardinality: 'COLLECTION',
};
}
@ -197,8 +191,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
}
return {
name,
isCollection: true, // <-- don't forget, collection!
isCollectionOrScalar: false,
cardinality: 'COLLECTION',
};
} else if (!isArray(schemaObject.type)) {
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
@ -213,8 +206,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
}
return {
name,
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
};
}
}
@ -225,8 +217,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
}
return {
name,
isCollection: false,
isCollectionOrScalar: false,
cardinality: 'SINGLE',
};
}
throw new FieldParseError(t('nodes.unableToParseFieldType'));

@ -100,11 +100,10 @@ export const parseSchema = (
return inputsAccumulator;
}
const fieldTypeOverride = property.ui_type
const fieldTypeOverride: FieldType | null = property.ui_type
? {
name: property.ui_type,
isCollection: isCollectionFieldType(property.ui_type),
isCollectionOrScalar: false,
cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE',
}
: null;
@ -178,11 +177,10 @@ export const parseSchema = (
return outputsAccumulator;
}
const fieldTypeOverride = property.ui_type
const fieldTypeOverride: FieldType | null = property.ui_type
? {
name: property.ui_type,
isCollection: isCollectionFieldType(property.ui_type),
isCollectionOrScalar: false,
cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE',
}
: null;