feat(ui): add support for custom field types

Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported.

Two notes:
1. Your field type's class name must be unique.

Suggest prefixing fields with something related to the node pack as a kind of namespace.

2. Custom field types function as connection-only fields.

For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type.

This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection.

feat(ui): fix tooltips for custom types

We need to hold onto the original type of the field so they don't all just show up as "Unknown".

fix(ui): fix ts error with custom fields

feat(ui): custom field types connection validation

In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.

*Actually, it was `"Unknown"`, but I changed it to custom for clarity.

Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.

To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.

This ended up needing a bit of fanagling:

- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.

While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.

(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)

- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.

- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.

Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.

This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.

fix(ui): typo

feat(ui): add CustomCollection and CustomPolymorphic field types

feat(ui): add validation for CustomCollection & CustomPolymorphic types

- Update connection validation for custom types
- Use simple string parsing to determine if a field is a collection or polymorphic type.
- No longer need to keep a list of collection and polymorphic types.
- Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing

chore(ui): remove errant console.log

fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType'

This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type.

fix(ui): fix ts error

feat(nodes): add runtime check for custom field names

"Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names.

chore(ui): add TODO for revising field type names

wip refactor fieldtype structured

wip refactor field types

wip refactor types

wip refactor types

fix node layout

refactor field types

chore: mypy

organisation

organisation

organisation

fix(nodes): fix field orig_required, field_kind and input statuses

feat(nodes): remove broken implementation of default_factory on InputField

Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args.

Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used.

Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`.

fix(nodes): fix InputField name validation

workflow validation

validation

chore: ruff

feat(nodes): fix up baseinvocation comments

fix(ui): improve typing & logic of buildFieldInputTemplate

improved error handling in parseFieldType

fix: back compat for deprecated default_factory and UIType

feat(nodes): do not show node packs loaded log if none loaded

chore(ui): typegen
This commit is contained in:
psychedelicious
2023-11-17 11:32:35 +11:00
parent 0d52430481
commit 86a74e929a
186 changed files with 5713 additions and 5704 deletions

View File

@ -0,0 +1,42 @@
import { get } from 'lodash-es';
import { FieldInputInstance, FieldInputTemplate } from '../types/field';
const FIELD_VALUE_FALLBACK_MAP = {
EnumField: '',
BoardField: undefined,
BooleanField: false,
ClipField: undefined,
ColorField: { r: 0, g: 0, b: 0, a: 1 },
FloatField: 0,
ImageField: undefined,
IntegerField: 0,
IPAdapterModelField: undefined,
LoRAModelField: undefined,
MainModelField: undefined,
ONNXModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
T2IAdapterPolymorphic: undefined,
VAEModelField: undefined,
ControlNetModelField: undefined,
};
export const buildFieldInputInstance = (
id: string,
template: FieldInputTemplate
): FieldInputInstance => {
const fieldInstance: FieldInputInstance = {
id,
name: template.name,
type: template.type,
label: '',
fieldKind: 'input' as const,
value:
template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name),
};
return fieldInstance;
};

View File

@ -0,0 +1,376 @@
import { isNumber, startCase } from 'lodash-es';
import {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
ColorFieldInputTemplate,
ControlNetModelFieldInputTemplate,
EnumFieldInputTemplate,
FieldInputTemplate,
FieldType,
FloatFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
ImageFieldInputTemplate,
IntegerFieldInputTemplate,
LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SchedulerFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
VAEModelFieldInputTemplate,
isStatefulFieldType,
} from '../types/field';
import { InvocationFieldSchema } from '../types/openapi';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
(arg: {
schemaObject: InvocationFieldSchema;
baseField: Omit<T, 'type'>;
isCollection: boolean;
isPolymorphic: boolean;
}) => T;
const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<
IntegerFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: IntegerFieldInputTemplate = {
...baseField,
type: { name: 'IntegerField', isCollection, isPolymorphic },
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (
schemaObject.exclusiveMaximum !== undefined &&
isNumber(schemaObject.exclusiveMaximum)
) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (
schemaObject.exclusiveMinimum !== undefined &&
isNumber(schemaObject.exclusiveMinimum)
) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<
FloatFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: FloatFieldInputTemplate = {
...baseField,
type: { name: 'FloatField', isCollection, isPolymorphic },
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (
schemaObject.exclusiveMaximum !== undefined &&
isNumber(schemaObject.exclusiveMaximum)
) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (
schemaObject.exclusiveMinimum !== undefined &&
isNumber(schemaObject.exclusiveMinimum)
) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<
StringFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: StringFieldInputTemplate = {
...baseField,
type: { name: 'StringField', isCollection, isPolymorphic },
default: schemaObject.default ?? '',
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
return template;
};
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<
BooleanFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: BooleanFieldInputTemplate = {
...baseField,
type: { name: 'BooleanField', isCollection, isPolymorphic },
default: schemaObject.default ?? false,
};
return template;
};
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<
MainModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: MainModelFieldInputTemplate = {
...baseField,
type: { name: 'MainModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<
SDXLMainModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: SDXLMainModelFieldInputTemplate = {
...baseField,
type: { name: 'SDXLMainModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<
SDXLRefinerModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: SDXLRefinerModelFieldInputTemplate = {
...baseField,
type: { name: 'SDXLRefinerModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<
VAEModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: VAEModelFieldInputTemplate = {
...baseField,
type: { name: 'VAEModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<
LoRAModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: LoRAModelFieldInputTemplate = {
...baseField,
type: { name: 'LoRAModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<
ControlNetModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: ControlNetModelFieldInputTemplate = {
...baseField,
type: { name: 'ControlNetModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<
IPAdapterModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: IPAdapterModelFieldInputTemplate = {
...baseField,
type: { name: 'IPAdapterModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<
T2IAdapterModelFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: T2IAdapterModelFieldInputTemplate = {
...baseField,
type: { name: 'T2IAdapterModelField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<
BoardFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: BoardFieldInputTemplate = {
...baseField,
type: { name: 'BoardField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageFieldInputTemplate: FieldInputTemplateBuilder<
ImageFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: ImageFieldInputTemplate = {
...baseField,
type: { name: 'ImageField', isCollection, isPolymorphic },
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<
EnumFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const options = schemaObject.enum ?? [];
const template: EnumFieldInputTemplate = {
...baseField,
type: { name: 'EnumField', isCollection, isPolymorphic },
options,
ui_choice_labels: schemaObject.ui_choice_labels,
default: schemaObject.default ?? options[0],
};
return template;
};
const buildColorFieldInputTemplate: FieldInputTemplateBuilder<
ColorFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: ColorFieldInputTemplate = {
...baseField,
type: { name: 'ColorField', isCollection, isPolymorphic },
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
return template;
};
const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<
SchedulerFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
const template: SchedulerFieldInputTemplate = {
...baseField,
type: { name: 'SchedulerField', isCollection, isPolymorphic },
default: schemaObject.default ?? 'euler',
};
return template;
};
export const TEMPLATE_BUILDER_MAP: Record<
StatefulFieldType['name'],
FieldInputTemplateBuilder
> = {
BoardField: buildBoardFieldInputTemplate,
BooleanField: buildBooleanFieldInputTemplate,
ColorField: buildColorFieldInputTemplate,
ControlNetModelField: buildControlNetModelFieldInputTemplate,
EnumField: buildEnumFieldInputTemplate,
FloatField: buildFloatFieldInputTemplate,
ImageField: buildImageFieldInputTemplate,
IntegerField: buildIntegerFieldInputTemplate,
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
LoRAModelField: buildLoRAModelFieldInputTemplate,
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,
};
export const buildFieldInputTemplate = (
fieldSchema: InvocationFieldSchema,
name: string,
fieldType: FieldType
): FieldInputTemplate => {
const {
input,
ui_hidden,
ui_component,
ui_type,
ui_order,
ui_choice_labels,
orig_required: required,
} = fieldSchema;
// This is the base field template that is common to all fields. The builder function will add all other
// properties to this template.
const baseField: Omit<FieldInputTemplate, 'type'> = {
name,
title: fieldSchema.title ?? (name ? startCase(name) : ''),
required,
description: fieldSchema.description ?? '',
fieldKind: 'input' as const,
input,
ui_hidden,
ui_component,
ui_type,
ui_order,
ui_choice_labels,
};
if (isStatefulFieldType(fieldType)) {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
return builder({
schemaObject: fieldSchema,
baseField,
isCollection: fieldType.isCollection,
isPolymorphic: fieldType.isPolymorphic,
});
}
// This is a StatelessField, create it directly.
const template: StatelessFieldInputTemplate = {
...baseField,
input: 'connection', // stateless --> connection only inputs
type: fieldType,
default: undefined, // stateless --> no default value
};
return template;
};

View File

@ -1,13 +1,13 @@
import { logger } from 'app/logging/logger';
import { NodesState } from '../store/types';
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
import { WorkflowV2, zWorkflowEdge, zWorkflowNode } from '../types/workflow';
import { fromZodError } from 'zod-validation-error';
import { parseify } from 'common/util/serialize';
import i18n from 'i18next';
export const buildWorkflow = (nodesState: NodesState): Workflow => {
export const buildWorkflow = (nodesState: NodesState): WorkflowV2 => {
const { workflow: workflowMeta, nodes, edges } = nodesState;
const workflow: Workflow = {
const workflow: WorkflowV2 = {
...workflowMeta,
nodes: [],
edges: [],

View File

@ -1,85 +0,0 @@
import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types';
const FIELD_VALUE_FALLBACK_MAP: {
[key in FieldType]: InputFieldValue['value'];
} = {
Any: undefined,
enum: '',
BoardField: undefined,
boolean: false,
BooleanCollection: [],
BooleanPolymorphic: false,
ClipField: undefined,
Collection: [],
CollectionItem: undefined,
ColorCollection: [],
ColorField: undefined,
ColorPolymorphic: undefined,
ConditioningCollection: [],
ConditioningField: undefined,
ConditioningPolymorphic: undefined,
ControlCollection: [],
ControlField: undefined,
ControlNetModelField: undefined,
ControlPolymorphic: undefined,
DenoiseMaskField: undefined,
float: 0,
FloatCollection: [],
FloatPolymorphic: 0,
ImageCollection: [],
ImageField: undefined,
ImagePolymorphic: undefined,
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
IPAdapterCollection: [],
IPAdapterField: undefined,
IPAdapterModelField: undefined,
IPAdapterPolymorphic: undefined,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,
MetadataItemField: undefined,
MetadataItemCollection: [],
MetadataItemPolymorphic: undefined,
MetadataField: undefined,
MetadataCollection: [],
LoRAModelField: undefined,
MainModelField: undefined,
ONNXModelField: undefined,
Scheduler: 'euler',
SDXLMainModelField: undefined,
SDXLRefinerModelField: undefined,
string: '',
StringCollection: [],
StringPolymorphic: '',
T2IAdapterCollection: [],
T2IAdapterField: undefined,
T2IAdapterModelField: undefined,
T2IAdapterPolymorphic: undefined,
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,
};
export const buildInputFieldValue = (
id: string,
template: InputFieldTemplate
): InputFieldValue => {
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
// `InputFieldValue` union, but TS doesn't seem to like it...
const fieldValue = {
id,
name: template.name,
type: template.type,
label: '',
fieldKind: 'input',
} as InputFieldValue;
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
return fieldValue;
};

View File

@ -1,8 +1,8 @@
import { isNil } from 'lodash-es';
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
import { FieldInputTemplate, FieldOutputTemplate } from '../types/field';
export const getSortedFilteredFieldNames = (
fields: InputFieldTemplate[] | OutputFieldTemplate[]
fields: FieldInputTemplate[] | FieldOutputTemplate[]
): string[] => {
const visibleFields = fields.filter((field) => !field.ui_hidden);

View File

@ -6,8 +6,8 @@ import {
ControlField,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
CONTROL_NET_COLLECT,

View File

@ -1,13 +1,13 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DenoiseLatentsInvocation,
ESRGANInvocation,
Edge,
LatentsToImageInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import {
DENOISE_LATENTS,

View File

@ -6,8 +6,8 @@ import {
CoreMetadataInvocation,
IPAdapterInvocation,
IPAdapterMetadataField,
NonNullableGraph,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
IP_ADAPTER_COLLECT,

View File

@ -1,7 +1,6 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { LinearUIOutputInvocation } from 'services/api/types';
import { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,

View File

@ -1,9 +1,9 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import {
CoreMetadataInvocation,
LoraLoaderInvocation,
NonNullableGraph,
} from 'services/api/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,

View File

@ -1,8 +1,8 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageNSFWBlurInvocation,
LatentsToImageInvocation,
NonNullableGraph,
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';

View File

@ -1,11 +1,10 @@
import { RootState } from 'app/store/store';
import {
LoRAMetadataItem,
NonNullableGraph,
zLoRAMetadataItem,
} from 'features/nodes/types/types';
} from 'features/nodes/types/metadata';
import { forEach, size } from 'lodash-es';
import { SDXLLoraLoaderInvocation } from 'services/api/types';
import { NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
LORA_LOADER,

View File

@ -2,9 +2,9 @@ import { RootState } from 'app/store/store';
import {
CreateDenoiseMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_OUTPUT,
INPAINT_IMAGE_RESIZE_UP,

View File

@ -1,7 +1,5 @@
import { RootState } from 'app/store/store';
import { SeamlessModeInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { upsertMetadata } from './metadata';
import { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
CANVAS_INPAINT_GRAPH,
@ -16,6 +14,7 @@ import {
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
} from './constants';
import { upsertMetadata } from './metadata';
export const addSeamlessToLinearGraph = (
state: RootState,

View File

@ -4,9 +4,10 @@ import { omit } from 'lodash-es';
import {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
T2IAdapterField,
T2IAdapterInvocation,
} from 'services/api/types';
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
T2I_ADAPTER_COLLECT,

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { NonNullableGraph } from 'services/api/types';
import {
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
CANVAS_IMAGE_TO_IMAGE_GRAPH,

View File

@ -1,10 +1,10 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
NonNullableGraph,
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';

View File

@ -1,10 +1,10 @@
import { BoardId } from 'features/gallery/store/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import {
ESRGANInvocation,
Graph,
LinearUIOutputInvocation,
NonNullableGraph,
} from 'services/api/types';
import { ESRGAN, LINEAR_UI_OUTPUT } from './constants';
import { addCoreMetadataNode, upsertMetadata } from './metadata';

View File

@ -1,6 +1,5 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types';
import { ImageDTO, NonNullableGraph } from 'services/api/types';
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
import { buildCanvasOutpaintGraph } from './buildCanvasOutpaintGraph';

View File

@ -1,12 +1,15 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
import {
ImageDTO,
ImageToLatentsInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
CreateDenoiseMaskInvocation,
ImageBlurInvocation,
@ -8,12 +7,13 @@ import {
ImageToLatentsInvocation,
MaskEdgeInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,18 +1,18 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
ImageToLatentsInvocation,
InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,14 +1,18 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
import {
ImageDTO,
ImageToLatentsInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -26,7 +30,6 @@ import {
SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addCoreMetadataNode } from './metadata';
/**

View File

@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
CreateDenoiseMaskInvocation,
ImageBlurInvocation,
@ -8,13 +7,14 @@ import {
ImageToLatentsInvocation,
MaskEdgeInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,19 +1,19 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
ImageToLatentsInvocation,
InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,16 +1,16 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DenoiseLatentsInvocation,
NonNullableGraph,
ONNXTextToLatentsInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,15 +1,15 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DenoiseLatentsInvocation,
NonNullableGraph,
ONNXTextToLatentsInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,10 +1,9 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import { NonNullableGraph } from 'features/nodes/types/types';
import { range } from 'lodash-es';
import { components } from 'services/api/schema';
import { Batch, BatchConfig } from 'services/api/types';
import { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
import {
CANVAS_COHERENCE_NOISE,
METADATA,

View File

@ -1,15 +1,15 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageResizeInvocation,
ImageToLatentsInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,16 +1,16 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageResizeInvocation,
ImageToLatentsInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';

View File

@ -1,17 +1,16 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { addCoreMetadataNode } from './metadata';
import {
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
@ -24,6 +23,7 @@ import {
SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addCoreMetadataNode } from './metadata';
export const buildLinearSDXLTextToImageGraph = (
state: RootState

View File

@ -1,21 +1,20 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DenoiseLatentsInvocation,
NonNullableGraph,
ONNXTextToLatentsInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addHrfToGraph } from './addHrfToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { addCoreMetadataNode } from './metadata';
import {
CLIP_SKIP,
DENOISE_LATENTS,
@ -28,6 +27,7 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addCoreMetadataNode } from './metadata';
export const buildLinearTextToImageGraph = (
state: RootState

View File

@ -1,16 +1,20 @@
import { NodesState } from 'features/nodes/store/types';
import { InputFieldValue, isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import { buildWorkflow } from '../buildWorkflow';
import {
FieldInputInstance,
isColorFieldInputInstance,
} from 'features/nodes/types/field';
/**
* We need to do special handling for some fields
*/
export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'ColorField') {
export const parseFieldValue = (field: FieldInputInstance) => {
if (isColorFieldInputInstance(field)) {
if (field.value) {
const clonedValue = cloneDeep(field.value);

View File

@ -1,5 +1,4 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import { CoreMetadataInvocation } from 'services/api/types';
import { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
import { JsonObject } from 'type-fest';
import { METADATA } from './constants';

View File

@ -0,0 +1,233 @@
import { t } from 'i18next';
import { isArray } from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error';
import { FieldType } from '../types/field';
import {
OpenAPIV3_1SchemaOrRef,
isArraySchemaObject,
isInvocationFieldSchema,
isNonArraySchemaObject,
isRefObject,
isSchemaObject,
} from '../types/openapi';
/**
* Transforms an invocation output ref object to field type.
* @param ref The ref string to transform
* @returns The field type.
*
* @example
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
*/
export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) =>
refObject.$ref.split('/').slice(-1)[0];
const OPENAPI_TO_FIELD_TYPE_MAP: Record<string, string> = {
integer: 'IntegerField',
number: 'FloatField',
string: 'StringField',
boolean: 'BooleanField',
};
const isCollectionFieldType = (fieldType: string) => {
/**
* CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between
* it and other `list[Any]` fields, due to its special internal handling.
*
* In pydantic, it gets an explicit field type of `CollectionField`.
*/
if (fieldType === 'CollectionField') {
return true;
}
return false;
};
export const parseFieldType = (
schemaObject: OpenAPIV3_1SchemaOrRef
): FieldType => {
if (isInvocationFieldSchema(schemaObject)) {
// Check if this field has an explicit type provided by the node schema
const { ui_type } = schemaObject;
if (ui_type) {
return {
name: ui_type,
isCollection: isCollectionFieldType(ui_type),
isPolymorphic: false,
};
}
}
if (isSchemaObject(schemaObject)) {
if (!schemaObject.type) {
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) {
const allOf = schemaObject.allOf;
if (allOf && allOf[0] && isRefObject(allOf[0])) {
// This is a single ref type
const name = refObjectToSchemaName(allOf[0]);
if (!name) {
throw new FieldTypeParseError(
t('nodes.unableToExtractSchemaNameFromRef')
);
}
return {
name,
isCollection: false,
isPolymorphic: false,
};
}
} else if (schemaObject.anyOf) {
// ignore null types
const filteredAnyOf = schemaObject.anyOf.filter((i) => {
if (isSchemaObject(i)) {
if (i.type === 'null') {
return false;
}
}
return true;
});
if (filteredAnyOf.length === 1) {
// This is a single ref type
if (isRefObject(filteredAnyOf[0])) {
const name = refObjectToSchemaName(filteredAnyOf[0]);
if (!name) {
throw new FieldTypeParseError(
t('nodes.unableToExtractSchemaNameFromRef')
);
}
return {
name,
isCollection: false,
isPolymorphic: false,
};
} else if (isSchemaObject(filteredAnyOf[0])) {
return parseFieldType(filteredAnyOf[0]);
}
}
/**
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
* - an `anyOf` with two items
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
* - the other is a `SchemaObject` or `ReferenceObject` of type T
*
* Any other cases we ignore.
*/
let firstType: string | undefined;
let secondType: string | undefined;
if (isArraySchemaObject(filteredAnyOf[0])) {
// first is array, second is not
const first = filteredAnyOf[0].items;
const second = filteredAnyOf[1];
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
} else if (isArraySchemaObject(filteredAnyOf[1])) {
// first is not array, second is
const first = filteredAnyOf[0];
const second = filteredAnyOf[1].items;
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
}
if (firstType && firstType === secondType) {
return {
name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType,
isCollection: false,
isPolymorphic: true, // <-- don't forget, polymorphic!
};
}
}
} else if (schemaObject.enum) {
return { name: 'EnumField', isCollection: false, isPolymorphic: false };
} else if (schemaObject.type) {
if (schemaObject.type === 'array') {
// We need to get the type of the items
if (isSchemaObject(schemaObject.items)) {
const itemType = schemaObject.items.type;
if (!itemType || isArray(itemType)) {
throw new UnsupportedFieldTypeError(
t('nodes.unsupportedArrayItemType', {
type: itemType,
})
);
}
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType];
if (!name) {
// it's 'null', 'object', or 'array' - skip
throw new UnsupportedFieldTypeError(
t('nodes.unsupportedArrayItemType', {
type: itemType,
})
);
}
return {
name,
isCollection: true, // <-- don't forget, collection!
isPolymorphic: false,
};
}
// This is a ref object, extract the type name
const name = refObjectToSchemaName(schemaObject.items);
if (!name) {
throw new FieldTypeParseError(
t('nodes.unableToExtractSchemaNameFromRef')
);
}
return {
name,
isCollection: true, // <-- don't forget, collection!
isPolymorphic: false,
};
} else if (!isArray(schemaObject.type)) {
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type];
if (!name) {
// it's 'null', 'object', or 'array' - skip
throw new UnsupportedFieldTypeError(
t('nodes.unsupportedArrayItemType', {
type: schemaObject.type,
})
);
}
return {
name,
isCollection: false,
isPolymorphic: false,
};
}
}
} else if (isRefObject(schemaObject)) {
const name = refObjectToSchemaName(schemaObject);
if (!name) {
throw new FieldTypeParseError(
t('nodes.unableToExtractSchemaNameFromRef')
);
}
return {
name,
isCollection: false,
isPolymorphic: false,
};
}
throw new FieldTypeParseError(t('nodes.unableToParseFieldType'));
};

View File

@ -2,24 +2,24 @@ import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { reduce, startCase } from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
import { AnyInvocationType } from 'services/events/types';
import { FieldInputTemplate, FieldOutputTemplate } from '../types/field';
import { InvocationTemplate } from '../types/invocation';
import {
InputFieldTemplate,
InvocationSchemaObject,
InvocationTemplate,
OutputFieldTemplate,
isFieldType,
isInvocationFieldSchema,
isInvocationOutputSchemaObject,
isInvocationSchemaObject,
} from '../types/types';
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
} from '../types/openapi';
import { buildFieldInputTemplate } from './buildFieldInputTemplate';
import { parseFieldType } from './parseFieldType';
import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error';
import { t } from 'i18next';
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
const invocationDenylist: AnyInvocationType[] = ['graph', 'linear_ui_output'];
const invocationDenylist: string[] = ['graph', 'linear_ui_output'];
const isReservedInputField = (nodeType: string, fieldName: string) => {
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
@ -83,13 +83,13 @@ export const parseSchema = (
const inputs = reduce(
schema.properties,
(
inputsAccumulator: Record<string, InputFieldTemplate>,
inputsAccumulator: Record<string, FieldInputTemplate>,
property,
propertyName
) => {
if (isReservedInputField(type, propertyName)) {
logger('nodes').trace(
{ node: type, fieldName: propertyName, field: parseify(property) },
{ node: type, field: propertyName, schema: parseify(property) },
'Skipped reserved input field'
);
return inputsAccumulator;
@ -97,79 +97,53 @@ export const parseSchema = (
if (!isInvocationFieldSchema(property)) {
logger('nodes').warn(
{ node: type, propertyName, property: parseify(property) },
{ node: type, field: propertyName, schema: parseify(property) },
'Unhandled input property'
);
return inputsAccumulator;
}
const fieldType = property.ui_type ?? getFieldType(property);
try {
const fieldType = parseFieldType(property);
if (!fieldType) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
'Missing input field type'
if (fieldType.name === 'WorkflowField') {
// This supports workflows, set the flag and skip to next field
withWorkflow = true;
return inputsAccumulator;
}
if (isReservedFieldType(fieldType.name)) {
// Skip processing this reserved field
return inputsAccumulator;
}
const fieldInputTemplate = buildFieldInputTemplate(
property,
propertyName,
fieldType
);
return inputsAccumulator;
inputsAccumulator[propertyName] = fieldInputTemplate;
} catch (e) {
if (
e instanceof FieldTypeParseError ||
e instanceof UnsupportedFieldTypeError
) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
t('nodes.inputFieldTypeParseError', {
node: type,
field: propertyName,
message: e.message,
})
);
}
}
if (fieldType === 'WorkflowField') {
withWorkflow = true;
return inputsAccumulator;
}
if (isReservedFieldType(fieldType)) {
logger('nodes').trace(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
`Skipping reserved input field type: ${fieldType}`
);
return inputsAccumulator;
}
if (!isFieldType(fieldType)) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
`Skipping unknown input field type: ${fieldType}`
);
return inputsAccumulator;
}
const field = buildInputFieldTemplate(
schema,
property,
propertyName,
fieldType
);
if (!field) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
'Skipping input field with no template'
);
return inputsAccumulator;
}
inputsAccumulator[propertyName] = field;
return inputsAccumulator;
},
{}
@ -206,7 +180,7 @@ export const parseSchema = (
(outputsAccumulator, property, propertyName) => {
if (!isAllowedOutputField(type, propertyName)) {
logger('nodes').trace(
{ type, propertyName, property: parseify(property) },
{ node: type, field: propertyName, schema: parseify(property) },
'Skipped reserved output field'
);
return outputsAccumulator;
@ -214,37 +188,62 @@ export const parseSchema = (
if (!isInvocationFieldSchema(property)) {
logger('nodes').warn(
{ type, propertyName, property: parseify(property) },
{ node: type, field: propertyName, schema: parseify(property) },
'Unhandled output property'
);
return outputsAccumulator;
}
const fieldType = property.ui_type ?? getFieldType(property);
try {
const fieldType = parseFieldType(property);
if (!isFieldType(fieldType)) {
logger('nodes').warn(
{ fieldName: propertyName, fieldType, field: parseify(property) },
'Skipping unknown output field type'
);
return outputsAccumulator;
if (!fieldType) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
'Missing output field type'
);
return outputsAccumulator;
}
const fieldOutputTemplate: FieldOutputTemplate = {
fieldKind: 'output',
name: propertyName,
title:
property.title ?? (propertyName ? startCase(propertyName) : ''),
description: property.description ?? '',
type: fieldType,
ui_hidden: property.ui_hidden ?? false,
ui_type: property.ui_type,
ui_order: property.ui_order,
};
outputsAccumulator[propertyName] = fieldOutputTemplate;
} catch (e) {
if (
e instanceof FieldTypeParseError ||
e instanceof UnsupportedFieldTypeError
) {
logger('nodes').warn(
{
node: type,
field: propertyName,
schema: parseify(property),
},
t('nodes.outputFieldTypeParseError', {
node: type,
field: propertyName,
message: e.message,
})
);
}
}
outputsAccumulator[propertyName] = {
fieldKind: 'output',
name: propertyName,
title:
property.title ?? (propertyName ? startCase(propertyName) : ''),
description: property.description ?? '',
type: fieldType,
ui_hidden: property.ui_hidden ?? false,
ui_type: property.ui_type,
ui_order: property.ui_order,
};
return outputsAccumulator;
},
{} as Record<string, OutputFieldTemplate>
{} as Record<string, FieldOutputTemplate>
);
const useCache = schema.properties.use_cache.default;

View File

@ -1,123 +1,159 @@
import { compareVersions } from 'compare-versions';
import { cloneDeep, keyBy } from 'lodash-es';
import {
InvocationTemplate,
Workflow,
WorkflowWarning,
isWorkflowInvocationNode,
} from '../types/types';
import { parseify } from 'common/util/serialize';
import i18n from 'i18next';
import { t } from 'i18next';
import { keyBy } from 'lodash-es';
import { JsonObject } from 'type-fest';
import { getNeedsUpdate } from '../store/util/nodeUpdate';
import { InvocationTemplate } from '../types/invocation';
import { parseAndMigrateWorkflow } from '../types/migration/migrations';
import { WorkflowV2, isWorkflowInvocationNode } from '../types/workflow';
type WorkflowWarning = {
message: string;
issues?: string[];
data: JsonObject;
};
type ValidateWorkflowResult = {
workflow: WorkflowV2;
warnings: WorkflowWarning[];
};
/**
* Parses and validates a workflow:
* - Parses the workflow schema, and migrates it to the latest version if necessary.
* - Validates the workflow against the node templates, warning if the template is not known.
* - Attempts to update nodes which have a mismatched version.
* - Removes edges which are invalid.
* @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow))
* @param invocationTemplates The node templates to validate against.
* @throws {WorkflowVersionError} If the workflow version is not recognized.
* @throws {z.ZodError} If there is a validation error.
*/
export const validateWorkflow = (
workflow: Workflow,
nodeTemplates: Record<string, InvocationTemplate>
) => {
const clone = cloneDeep(workflow);
const { nodes, edges } = clone;
const errors: WorkflowWarning[] = [];
workflow: unknown,
invocationTemplates: Record<string, InvocationTemplate>
): ValidateWorkflowResult => {
// Parse the raw workflow data & migrate it to the latest version
const _workflow = parseAndMigrateWorkflow(workflow);
// Now we can validate the graph
const { nodes, edges } = _workflow;
const warnings: WorkflowWarning[] = [];
// We don't need to validate Note nodes or CurrentImage nodes - only Invocation nodes
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id');
nodes.forEach((node) => {
if (!isWorkflowInvocationNode(node)) {
return;
}
const nodeTemplate = nodeTemplates[node.data.type];
if (!nodeTemplate) {
errors.push({
message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t(
'nodes.skipped'
)}`,
issues: [
`${i18n.t('nodes.nodeType')}"${node.data.type}" ${i18n.t(
'nodes.doesNotExist'
)}`,
],
data: node,
invocationNodes.forEach((node) => {
const template = invocationTemplates[node.data.type];
if (!template) {
// This node's type template does not exist
const message = t('nodes.missingTemplate', {
node: node.id,
type: node.data.type,
});
warnings.push({
message,
data: parseify(node),
});
return;
}
if (
nodeTemplate.version &&
node.data.version &&
compareVersions(nodeTemplate.version, node.data.version) !== 0
) {
errors.push({
message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t(
'nodes.mismatchedVersion'
)}`,
issues: [
`${i18n.t('nodes.node')} "${node.data.type}" v${
node.data.version
} ${i18n.t('nodes.maybeIncompatible')} v${nodeTemplate.version}`,
],
data: { node, nodeTemplate: parseify(nodeTemplate) },
if (getNeedsUpdate(node, template)) {
// This node needs to be updated, based on comparison of its version to the template version
const message = t('nodes.mismatchedVersion', {
node: node.id,
type: node.data.type,
});
warnings.push({
message,
data: parseify({ node, nodeTemplate: template }),
});
return;
}
});
edges.forEach((edge, i) => {
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target];
const issues: string[] = [];
if (!sourceNode) {
// The edge's source/output node does not exist
issues.push(
`${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t(
'nodes.doesNotExist'
)}`
t('nodes.sourceNodeDoesNotExist', {
node: edge.source,
})
);
} else if (
edge.type === 'default' &&
!(edge.sourceHandle in sourceNode.data.outputs)
) {
// The edge's source/output node field does not exist
issues.push(
`${i18n.t('nodes.outputNode')} "${edge.source}.${
edge.sourceHandle
}" ${i18n.t('nodes.doesNotExist')}`
t('nodes.sourceNodeFieldDoesNotExist', {
node: edge.source,
field: edge.sourceHandle,
})
);
}
if (!targetNode) {
// The edge's target/input node does not exist
issues.push(
`${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t(
'nodes.doesNotExist'
)}`
t('nodes.targetNodeDoesNotExist', {
node: edge.target,
})
);
} else if (
edge.type === 'default' &&
!(edge.targetHandle in targetNode.data.inputs)
) {
// The edge's target/input node field does not exist
issues.push(
`${i18n.t('nodes.inputField')} "${edge.target}.${
edge.targetHandle
}" ${i18n.t('nodes.doesNotExist')}`
t('nodes.targetNodeFieldDoesNotExist', {
node: edge.target,
field: edge.targetHandle,
})
);
}
if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
if (!sourceNode?.data.type || !invocationTemplates[sourceNode.data.type]) {
// The edge's source/output node template does not exist
issues.push(
`${i18n.t('nodes.sourceNode')} "${edge.source}" ${i18n.t(
'nodes.missingTemplate'
)} "${sourceNode?.data.type}"`
t('nodes.missingTemplate', {
node: edge.source,
type: sourceNode?.data.type,
})
);
}
if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
if (!targetNode?.data.type || !invocationTemplates[targetNode?.data.type]) {
// The edge's target/input node template does not exist
issues.push(
`${i18n.t('nodes.sourceNode')}"${edge.target}" ${i18n.t(
'nodes.missingTemplate'
)} "${targetNode?.data.type}"`
t('nodes.missingTemplate', {
node: edge.target,
type: targetNode?.data.type,
})
);
}
if (issues.length) {
// This edge has some issues. Remove it.
delete edges[i];
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
errors.push({
message: `Edge "${src} -> ${tgt}" skipped`,
const source =
edge.type === 'default'
? `${edge.source}.${edge.sourceHandle}`
: edge.source;
const target =
edge.type === 'default'
? `${edge.source}.${edge.targetHandle}`
: edge.target;
warnings.push({
message: t('nodes.deletedInvalidEdge', { source, target }),
issues,
data: edge,
});
}
});
return { workflow: clone, errors };
return { workflow: _workflow, warnings };
};