mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add originalType
to FieldType, improved connection validation
We now keep track of the original field type, derived from the python type annotation in addition to the override type provided by `ui_type`. This makes `ui_type` work more like it sound like it should work - change the UI input component only. Connection validation is extend to also check the original types. If there is any match between two fields' "final" or original types, we consider the connection valid.This change is backwards-compatible; there is no workflow migration needed.
This commit is contained in:
parent
af3fd26d4e
commit
85a5a7c47a
@ -4,9 +4,8 @@ import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import { areTypesEqual, validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
|
||||
@ -70,7 +69,7 @@ export const useIsValidConnection = () => {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
return isEqual(sourceFieldTemplate.type, collectItemType);
|
||||
return areTypesEqual(sourceFieldTemplate.type, collectItemType);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||
import { differenceWith, map } from 'lodash-es';
|
||||
import type { Connection } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
export const getFirstValidConnection = (
|
||||
templates: Templates,
|
||||
@ -83,7 +83,7 @@ export const getFirstValidConnection = (
|
||||
// Narrow candidates to same field type as already is connected to the collect node
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
||||
if (collectItemType) {
|
||||
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
|
||||
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
|
||||
}
|
||||
}
|
||||
const candidateField = candidateFields.find((field) => {
|
||||
|
@ -4,12 +4,11 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import i18n from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { HandleType } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
export const getCollectItemType = (
|
||||
templates: Templates,
|
||||
@ -111,7 +110,7 @@ export const makeConnectionErrorSelector = (
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
if (!isEqual(sourceType, collectItemType)) {
|
||||
if (!areTypesEqual(sourceType, collectItemType)) {
|
||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,25 @@
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import { isEqual, omit } from 'lodash-es';
|
||||
|
||||
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
|
||||
const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType;
|
||||
const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType;
|
||||
const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType;
|
||||
const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType;
|
||||
if (isEqual(_sourceType, _targetType)) {
|
||||
return true;
|
||||
}
|
||||
if (isEqual(_sourceType, _targetTypeOriginal)) {
|
||||
return true;
|
||||
}
|
||||
if (isEqual(_sourceTypeOriginal, _targetType)) {
|
||||
return true;
|
||||
}
|
||||
if (isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Validates that the source and target types are compatible for a connection.
|
||||
@ -15,7 +35,7 @@ export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isEqual(sourceType, targetType)) {
|
||||
if (areTypesEqual(sourceType, targetType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -66,16 +66,114 @@ export const zFieldIdentifier = z.object({
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField
|
||||
// #region Field Types
|
||||
const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IntegerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zEnumFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('EnumField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
|
||||
export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType =>
|
||||
statefulFieldTypeNames.includes(fieldType.name as any);
|
||||
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField
|
||||
|
||||
export const zIntegerFieldValue = z.number().int();
|
||||
const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIntegerFieldValue,
|
||||
});
|
||||
const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIntegerFieldValue,
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().int().optional(),
|
||||
@ -85,6 +183,7 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
||||
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
||||
@ -96,15 +195,14 @@ export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldIn
|
||||
// #endregion
|
||||
|
||||
// #region FloatField
|
||||
const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
});
|
||||
|
||||
export const zFloatFieldValue = z.number();
|
||||
const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFloatFieldValue,
|
||||
});
|
||||
const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFloatFieldValue,
|
||||
multipleOf: z.number().optional(),
|
||||
maximum: z.number().optional(),
|
||||
@ -114,6 +212,7 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
@ -125,21 +224,21 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region StringField
|
||||
const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
});
|
||||
|
||||
export const zStringFieldValue = z.string();
|
||||
const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStringFieldValue,
|
||||
});
|
||||
const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStringFieldValue,
|
||||
maxLength: z.number().int().optional(),
|
||||
minLength: z.number().int().optional(),
|
||||
});
|
||||
const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
@ -152,19 +251,19 @@ export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInpu
|
||||
// #endregion
|
||||
|
||||
// #region BooleanField
|
||||
const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
});
|
||||
|
||||
export const zBooleanFieldValue = z.boolean();
|
||||
const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zBooleanFieldValue,
|
||||
});
|
||||
const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zBooleanFieldValue,
|
||||
});
|
||||
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||
@ -176,21 +275,21 @@ export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldIn
|
||||
// #endregion
|
||||
|
||||
// #region EnumField
|
||||
const zEnumFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('EnumField'),
|
||||
});
|
||||
|
||||
export const zEnumFieldValue = z.string();
|
||||
const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zEnumFieldValue,
|
||||
});
|
||||
const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zEnumFieldValue,
|
||||
options: z.array(z.string()),
|
||||
labels: z.record(z.string()).optional(),
|
||||
});
|
||||
const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||
@ -202,19 +301,19 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem
|
||||
// #endregion
|
||||
|
||||
// #region ImageField
|
||||
const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
});
|
||||
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImageFieldValue,
|
||||
});
|
||||
const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zImageFieldValue,
|
||||
});
|
||||
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
@ -226,19 +325,19 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
});
|
||||
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zBoardFieldValue,
|
||||
});
|
||||
const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zBoardFieldValue,
|
||||
});
|
||||
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||
@ -250,19 +349,19 @@ export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region ColorField
|
||||
const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
});
|
||||
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zColorFieldValue,
|
||||
});
|
||||
const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zColorFieldValue,
|
||||
});
|
||||
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||
@ -274,19 +373,19 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region MainModelField
|
||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
});
|
||||
|
||||
export const zMainModelFieldValue = zModelIdentifierField.optional();
|
||||
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zMainModelFieldValue,
|
||||
});
|
||||
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
@ -298,19 +397,19 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
});
|
||||
|
||||
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSDXLMainModelFieldValue,
|
||||
});
|
||||
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
@ -321,9 +420,7 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
});
|
||||
|
||||
/** @alias */ // tells knip to ignore this duplicate export
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@ -331,10 +428,12 @@ const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
});
|
||||
const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
@ -346,19 +445,19 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
});
|
||||
|
||||
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zVAEModelFieldValue,
|
||||
});
|
||||
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
@ -370,19 +469,19 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
|
||||
// #endregion
|
||||
|
||||
// #region LoRAModelField
|
||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
});
|
||||
|
||||
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zLoRAModelFieldValue,
|
||||
});
|
||||
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
@ -394,19 +493,19 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
});
|
||||
|
||||
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zControlNetModelFieldValue,
|
||||
});
|
||||
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
@ -418,19 +517,19 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
|
||||
// #endregion
|
||||
|
||||
// #region IPAdapterModelField
|
||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
});
|
||||
|
||||
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIPAdapterModelFieldValue,
|
||||
});
|
||||
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
@ -442,19 +541,19 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
|
||||
// #endregion
|
||||
|
||||
// #region T2IAdapterField
|
||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
});
|
||||
|
||||
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
@ -466,19 +565,19 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
});
|
||||
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSchedulerFieldValue,
|
||||
});
|
||||
const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSchedulerFieldValue,
|
||||
});
|
||||
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
@ -501,20 +600,20 @@ export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFie
|
||||
* - Reserved fields like IsIntermediate
|
||||
* - Any other field we don't have full-on schemas for
|
||||
*/
|
||||
const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
|
||||
const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStatelessFieldValue,
|
||||
});
|
||||
const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStatelessFieldValue,
|
||||
input: z.literal('connection'), // stateless --> only accepts connection inputs
|
||||
});
|
||||
const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
|
||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||
@ -535,34 +634,6 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
|
||||
* for all other StatelessFields.
|
||||
*/
|
||||
|
||||
// #region StatefulFieldType & FieldType
|
||||
const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
export const isStatefulFieldType = (val: unknown): val is StatefulFieldType =>
|
||||
zStatefulFieldType.safeParse(val).success;
|
||||
|
||||
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldValue & FieldValue
|
||||
export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
|
@ -30,26 +30,16 @@ import { isNumber, startCase } from 'lodash-es';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
|
||||
(arg: {
|
||||
schemaObject: InvocationFieldSchema;
|
||||
baseField: Omit<T, 'type'>;
|
||||
isCollection: boolean;
|
||||
isCollectionOrScalar: boolean;
|
||||
}) => T;
|
||||
(arg: { schemaObject: InvocationFieldSchema; baseField: Omit<T, 'type'>; fieldType: T['type'] }) => T;
|
||||
|
||||
const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IntegerFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
@ -79,16 +69,11 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInpu
|
||||
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FloatFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'FloatField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
@ -118,16 +103,11 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
|
||||
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StringFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'StringField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
@ -145,16 +125,11 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
|
||||
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: BooleanFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'BooleanField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
@ -164,16 +139,11 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInpu
|
||||
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'MainModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -183,16 +153,11 @@ const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelField
|
||||
const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SDXLMainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'SDXLMainModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -202,16 +167,11 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
|
||||
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SDXLRefinerModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'SDXLRefinerModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -221,16 +181,11 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
|
||||
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: VAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'VAEModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -240,16 +195,11 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldIn
|
||||
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: LoRAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'LoRAModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -259,16 +209,11 @@ const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelField
|
||||
const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<ControlNetModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: ControlNetModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'ControlNetModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -278,16 +223,11 @@ const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<ControlN
|
||||
const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<IPAdapterModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IPAdapterModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'IPAdapterModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -297,16 +237,11 @@ const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<IPAdapter
|
||||
const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapterModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: T2IAdapterModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'T2IAdapterModelField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -316,16 +251,11 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
|
||||
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: BoardFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'BoardField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -335,16 +265,11 @@ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTem
|
||||
const buildImageFieldInputTemplate: FieldInputTemplateBuilder<ImageFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: ImageFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'ImageField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -354,8 +279,7 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder<ImageFieldInputTem
|
||||
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
let options: EnumFieldInputTemplate['options'] = [];
|
||||
if (schemaObject.anyOf) {
|
||||
@ -383,11 +307,7 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTempl
|
||||
}
|
||||
const template: EnumFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'EnumField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
options,
|
||||
ui_choice_labels: schemaObject.ui_choice_labels,
|
||||
default: schemaObject.default ?? options[0],
|
||||
@ -399,16 +319,11 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTempl
|
||||
const buildColorFieldInputTemplate: FieldInputTemplateBuilder<ColorFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: ColorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'ColorField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
|
||||
};
|
||||
|
||||
@ -418,16 +333,11 @@ const buildColorFieldInputTemplate: FieldInputTemplateBuilder<ColorFieldInputTem
|
||||
const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SchedulerFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: {
|
||||
name: 'SchedulerField',
|
||||
isCollection,
|
||||
isCollectionOrScalar,
|
||||
},
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? 'euler',
|
||||
};
|
||||
|
||||
@ -452,7 +362,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
};
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
fieldSchema: InvocationFieldSchema,
|
||||
@ -479,20 +389,22 @@ export const buildFieldInputTemplate = (
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
|
||||
return builder({
|
||||
const template = builder({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
isCollection: fieldType.isCollection,
|
||||
isCollectionOrScalar: fieldType.isCollectionOrScalar,
|
||||
fieldType,
|
||||
});
|
||||
}
|
||||
|
||||
// This is a StatelessField, create it directly.
|
||||
const template: StatelessFieldInputTemplate = {
|
||||
...baseField,
|
||||
input: 'connection', // stateless --> connection only inputs
|
||||
type: fieldType,
|
||||
default: undefined, // stateless --> no default value
|
||||
};
|
||||
return template;
|
||||
return template;
|
||||
} else {
|
||||
// This is a StatelessField, create it directly.
|
||||
const template: StatelessFieldInputTemplate = {
|
||||
...baseField,
|
||||
input: 'connection', // stateless --> connection only inputs
|
||||
type: fieldType,
|
||||
default: undefined, // stateless --> no default value
|
||||
};
|
||||
|
||||
return template;
|
||||
}
|
||||
};
|
||||
|
@ -9,7 +9,7 @@ export const buildFieldOutputTemplate = (
|
||||
): FieldOutputTemplate => {
|
||||
const { title, description, ui_hidden, ui_type, ui_order } = fieldSchema;
|
||||
|
||||
const fieldOutputTemplate: FieldOutputTemplate = {
|
||||
const template: FieldOutputTemplate = {
|
||||
fieldKind: 'output',
|
||||
name: fieldName,
|
||||
title: title ?? (fieldName ? startCase(fieldName) : ''),
|
||||
@ -20,5 +20,5 @@ export const buildFieldOutputTemplate = (
|
||||
ui_order,
|
||||
};
|
||||
|
||||
return fieldOutputTemplate;
|
||||
return template;
|
||||
};
|
||||
|
@ -244,7 +244,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'SchedulerField',
|
||||
},
|
||||
expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false },
|
||||
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (AnyField)',
|
||||
@ -253,7 +253,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'AnyField',
|
||||
},
|
||||
expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false },
|
||||
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (CollectionField)',
|
||||
@ -262,7 +262,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'CollectionField',
|
||||
},
|
||||
expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
|
||||
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
];
|
||||
|
||||
|
@ -6,14 +6,8 @@ import {
|
||||
UnsupportedUnionError,
|
||||
} from 'features/nodes/types/error';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
|
||||
import {
|
||||
isArraySchemaObject,
|
||||
isInvocationFieldSchema,
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
} from 'features/nodes/types/openapi';
|
||||
import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
|
||||
import { isArraySchemaObject, isNonArraySchemaObject, isRefObject, isSchemaObject } from 'features/nodes/types/openapi';
|
||||
import { t } from 'i18next';
|
||||
import { isArray } from 'lodash-es';
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
@ -35,7 +29,7 @@ const OPENAPI_TO_FIELD_TYPE_MAP: Record<string, string> = {
|
||||
boolean: 'BooleanField',
|
||||
};
|
||||
|
||||
const isCollectionFieldType = (fieldType: string) => {
|
||||
export const isCollectionFieldType = (fieldType: string) => {
|
||||
/**
|
||||
* CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between
|
||||
* it and other `list[Any]` fields, due to its special internal handling.
|
||||
@ -48,18 +42,7 @@ const isCollectionFieldType = (fieldType: string) => {
|
||||
return false;
|
||||
};
|
||||
|
||||
export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => {
|
||||
if (isInvocationFieldSchema(schemaObject)) {
|
||||
// Check if this field has an explicit type provided by the node schema
|
||||
const { ui_type } = schemaObject;
|
||||
if (ui_type) {
|
||||
return {
|
||||
name: ui_type,
|
||||
isCollection: isCollectionFieldType(ui_type),
|
||||
isCollectionOrScalar: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => {
|
||||
if (isSchemaObject(schemaObject)) {
|
||||
if (schemaObject.const) {
|
||||
// Fields with a single const value are defined as `Literal["value"]` in the pydantic schema - it's actually an enum
|
||||
|
@ -97,6 +97,11 @@ const expected = {
|
||||
name: 'SchedulerField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
originalType: {
|
||||
name: 'EnumField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
default: 'euler',
|
||||
},
|
||||
@ -111,6 +116,11 @@ const expected = {
|
||||
name: 'SchedulerField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
originalType: {
|
||||
name: 'EnumField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
ui_hidden: false,
|
||||
ui_type: 'SchedulerField',
|
||||
@ -141,6 +151,11 @@ const expected = {
|
||||
name: 'MainModelField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
originalType: {
|
||||
name: 'ModelIdentifierField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -186,6 +201,48 @@ const expected = {
|
||||
nodePack: 'invokeai',
|
||||
classification: 'stable',
|
||||
},
|
||||
collect: {
|
||||
title: 'Collect',
|
||||
type: 'collect',
|
||||
version: '1.0.0',
|
||||
tags: [],
|
||||
description: 'Collects values into a collection',
|
||||
outputType: 'collect_output',
|
||||
inputs: {
|
||||
item: {
|
||||
name: 'item',
|
||||
title: 'Collection Item',
|
||||
required: false,
|
||||
description: 'The item to collect (all inputs must be of the same type)',
|
||||
fieldKind: 'input',
|
||||
input: 'connection',
|
||||
ui_hidden: false,
|
||||
ui_type: 'CollectionItemField',
|
||||
type: {
|
||||
name: 'CollectionItemField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
outputs: {
|
||||
collection: {
|
||||
fieldKind: 'output',
|
||||
name: 'collection',
|
||||
title: 'Collection',
|
||||
description: 'The collection of input items',
|
||||
type: {
|
||||
name: 'CollectionField',
|
||||
isCollection: true,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
ui_type: 'CollectionField',
|
||||
},
|
||||
},
|
||||
useCache: true,
|
||||
classification: 'stable',
|
||||
},
|
||||
};
|
||||
|
||||
const schema = {
|
||||
@ -785,6 +842,101 @@ const schema = {
|
||||
type: 'object',
|
||||
class: 'output',
|
||||
},
|
||||
CollectInvocation: {
|
||||
properties: {
|
||||
id: {
|
||||
type: 'string',
|
||||
title: 'Id',
|
||||
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
||||
field_kind: 'node_attribute',
|
||||
},
|
||||
is_intermediate: {
|
||||
type: 'boolean',
|
||||
title: 'Is Intermediate',
|
||||
description: 'Whether or not this is an intermediate invocation.',
|
||||
default: false,
|
||||
field_kind: 'node_attribute',
|
||||
ui_type: 'IsIntermediate',
|
||||
},
|
||||
use_cache: {
|
||||
type: 'boolean',
|
||||
title: 'Use Cache',
|
||||
description: 'Whether or not to use the cache',
|
||||
default: true,
|
||||
field_kind: 'node_attribute',
|
||||
},
|
||||
item: {
|
||||
anyOf: [
|
||||
{},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
title: 'Collection Item',
|
||||
description: 'The item to collect (all inputs must be of the same type)',
|
||||
field_kind: 'input',
|
||||
input: 'connection',
|
||||
orig_required: false,
|
||||
ui_hidden: false,
|
||||
ui_type: 'CollectionItemField',
|
||||
},
|
||||
collection: {
|
||||
items: {},
|
||||
type: 'array',
|
||||
title: 'Collection',
|
||||
description: 'The collection, will be provided on execution',
|
||||
default: [],
|
||||
field_kind: 'input',
|
||||
input: 'any',
|
||||
orig_default: [],
|
||||
orig_required: false,
|
||||
ui_hidden: true,
|
||||
},
|
||||
type: {
|
||||
type: 'string',
|
||||
enum: ['collect'],
|
||||
const: 'collect',
|
||||
title: 'type',
|
||||
default: 'collect',
|
||||
field_kind: 'node_attribute',
|
||||
},
|
||||
},
|
||||
type: 'object',
|
||||
required: ['type', 'id'],
|
||||
title: 'CollectInvocation',
|
||||
description: 'Collects values into a collection',
|
||||
classification: 'stable',
|
||||
version: '1.0.0',
|
||||
output: {
|
||||
$ref: '#/components/schemas/CollectInvocationOutput',
|
||||
},
|
||||
class: 'invocation',
|
||||
},
|
||||
CollectInvocationOutput: {
|
||||
properties: {
|
||||
collection: {
|
||||
description: 'The collection of input items',
|
||||
field_kind: 'output',
|
||||
items: {},
|
||||
title: 'Collection',
|
||||
type: 'array',
|
||||
ui_hidden: false,
|
||||
ui_type: 'CollectionField',
|
||||
},
|
||||
type: {
|
||||
const: 'collect_output',
|
||||
default: 'collect_output',
|
||||
enum: ['collect_output'],
|
||||
field_kind: 'node_attribute',
|
||||
title: 'type',
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
required: ['collection', 'type', 'type'],
|
||||
title: 'CollectInvocationOutput',
|
||||
type: 'object',
|
||||
class: 'output',
|
||||
},
|
||||
},
|
||||
},
|
||||
} as OpenAPIV3_1.Document;
|
||||
|
@ -1,23 +1,29 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { FieldParseError } from 'features/nodes/types/error';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
type FieldInputTemplate,
|
||||
type FieldOutputTemplate,
|
||||
type FieldType,
|
||||
isStatefulFieldType,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import type { InvocationSchemaObject } from 'features/nodes/types/openapi';
|
||||
import type { InvocationFieldSchema, InvocationSchemaObject } from 'features/nodes/types/openapi';
|
||||
import {
|
||||
isInvocationFieldSchema,
|
||||
isInvocationOutputSchemaObject,
|
||||
isInvocationSchemaObject,
|
||||
} from 'features/nodes/types/openapi';
|
||||
import { t } from 'i18next';
|
||||
import { reduce } from 'lodash-es';
|
||||
import { isEqual, reduce } from 'lodash-es';
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
import { buildFieldInputTemplate } from './buildFieldInputTemplate';
|
||||
import { buildFieldOutputTemplate } from './buildFieldOutputTemplate';
|
||||
import { parseFieldType } from './parseFieldType';
|
||||
import { isCollectionFieldType, parseFieldType } from './parseFieldType';
|
||||
|
||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
|
||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||
@ -94,51 +100,43 @@ export const parseSchema = (
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
try {
|
||||
const fieldType = parseFieldType(property);
|
||||
const fieldTypeOverride = property.ui_type
|
||||
? {
|
||||
name: property.ui_type,
|
||||
isCollection: isCollectionFieldType(property.ui_type),
|
||||
isCollectionOrScalar: false,
|
||||
}
|
||||
: null;
|
||||
|
||||
if (isReservedFieldType(fieldType.name)) {
|
||||
logger('nodes').trace(
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Skipped reserved input field'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
const originalFieldType = getFieldType(property, propertyName, type, 'input');
|
||||
|
||||
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
|
||||
|
||||
inputsAccumulator[propertyName] = fieldInputTemplate;
|
||||
} catch (e) {
|
||||
if (e instanceof FieldParseError) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
t('nodes.inputFieldTypeParseError', {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: e.message,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
error: serializeError(e),
|
||||
},
|
||||
t('nodes.inputFieldTypeParseError', {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: 'unknown error',
|
||||
})
|
||||
);
|
||||
}
|
||||
const fieldType = fieldTypeOverride ?? originalFieldType;
|
||||
if (!fieldType) {
|
||||
logger('nodes').trace(
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Unable to parse field type'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (isReservedFieldType(fieldType.name)) {
|
||||
logger('nodes').trace(
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Skipped reserved input field'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) {
|
||||
console.log('STATEFUL WITH ORIGINAL');
|
||||
fieldType.originalType = deepClone(originalFieldType);
|
||||
console.log(fieldType);
|
||||
}
|
||||
|
||||
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
|
||||
console.log(fieldInputTemplate);
|
||||
inputsAccumulator[propertyName] = fieldInputTemplate;
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{}
|
||||
@ -183,54 +181,34 @@ export const parseSchema = (
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
try {
|
||||
const fieldType = parseFieldType(property);
|
||||
const fieldTypeOverride = property.ui_type
|
||||
? {
|
||||
name: property.ui_type,
|
||||
isCollection: isCollectionFieldType(property.ui_type),
|
||||
isCollectionOrScalar: false,
|
||||
}
|
||||
: null;
|
||||
|
||||
if (!fieldType) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
'Missing output field type'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
const originalFieldType = getFieldType(property, propertyName, type, 'output');
|
||||
|
||||
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
|
||||
|
||||
outputsAccumulator[propertyName] = fieldOutputTemplate;
|
||||
} catch (e) {
|
||||
if (e instanceof FieldParseError) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
t('nodes.outputFieldTypeParseError', {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: e.message,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
error: serializeError(e),
|
||||
},
|
||||
t('nodes.outputFieldTypeParseError', {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: 'unknown error',
|
||||
})
|
||||
);
|
||||
}
|
||||
const fieldType = fieldTypeOverride ?? originalFieldType;
|
||||
if (!fieldType) {
|
||||
logger('nodes').trace(
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Unable to parse field type'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) {
|
||||
console.log('STATEFUL WITH ORIGINAL');
|
||||
fieldType.originalType = deepClone(originalFieldType);
|
||||
console.log(fieldType);
|
||||
}
|
||||
|
||||
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
|
||||
|
||||
outputsAccumulator[propertyName] = fieldOutputTemplate;
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, FieldOutputTemplate>
|
||||
@ -259,3 +237,45 @@ export const parseSchema = (
|
||||
|
||||
return invocations;
|
||||
};
|
||||
|
||||
const getFieldType = (
|
||||
property: InvocationFieldSchema,
|
||||
propertyName: string,
|
||||
type: string,
|
||||
kind: 'input' | 'output'
|
||||
): FieldType | null => {
|
||||
try {
|
||||
return parseFieldType(property);
|
||||
} catch (e) {
|
||||
const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError';
|
||||
if (e instanceof FieldParseError) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
t(tKey, {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: e.message,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
error: serializeError(e),
|
||||
},
|
||||
t(tKey, {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: 'unknown error',
|
||||
})
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user