mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add support for custom field types
Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported. Two notes: 1. Your field type's class name must be unique. Suggest prefixing fields with something related to the node pack as a kind of namespace. 2. Custom field types function as connection-only fields. For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type. This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection. feat(ui): fix tooltips for custom types We need to hold onto the original type of the field so they don't all just show up as "Unknown". fix(ui): fix ts error with custom fields feat(ui): custom field types connection validation In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent. fix(ui): typo feat(ui): add CustomCollection and CustomPolymorphic field types feat(ui): add validation for CustomCollection & CustomPolymorphic types - Update connection validation for custom types - Use simple string parsing to determine if a field is a collection or polymorphic type. - No longer need to keep a list of collection and polymorphic types. - Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing chore(ui): remove errant console.log fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType' This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type. fix(ui): fix ts error feat(nodes): add runtime check for custom field names "Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names. chore(ui): add TODO for revising field type names wip refactor fieldtype structured wip refactor field types wip refactor types wip refactor types fix node layout refactor field types chore: mypy organisation organisation organisation fix(nodes): fix field orig_required, field_kind and input statuses feat(nodes): remove broken implementation of default_factory on InputField Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args. Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used. Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`. fix(nodes): fix InputField name validation workflow validation validation chore: ruff feat(nodes): fix up baseinvocation comments fix(ui): improve typing & logic of buildFieldInputTemplate improved error handling in parseFieldType fix: back compat for deprecated default_factory and UIType feat(nodes): do not show node packs loaded log if none loaded chore(ui): typegen
This commit is contained in:
@ -0,0 +1,42 @@
|
||||
import { get } from 'lodash-es';
|
||||
import { FieldInputInstance, FieldInputTemplate } from '../types/field';
|
||||
|
||||
const FIELD_VALUE_FALLBACK_MAP = {
|
||||
EnumField: '',
|
||||
BoardField: undefined,
|
||||
BooleanField: false,
|
||||
ClipField: undefined,
|
||||
ColorField: { r: 0, g: 0, b: 0, a: 1 },
|
||||
FloatField: 0,
|
||||
ImageField: undefined,
|
||||
IntegerField: 0,
|
||||
IPAdapterModelField: undefined,
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
ONNXModelField: undefined,
|
||||
SchedulerField: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
T2IAdapterPolymorphic: undefined,
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (
|
||||
id: string,
|
||||
template: FieldInputTemplate
|
||||
): FieldInputInstance => {
|
||||
const fieldInstance: FieldInputInstance = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input' as const,
|
||||
value:
|
||||
template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name),
|
||||
};
|
||||
|
||||
return fieldInstance;
|
||||
};
|
@ -0,0 +1,376 @@
|
||||
import { isNumber, startCase } from 'lodash-es';
|
||||
import {
|
||||
BoardFieldInputTemplate,
|
||||
BooleanFieldInputTemplate,
|
||||
ColorFieldInputTemplate,
|
||||
ControlNetModelFieldInputTemplate,
|
||||
EnumFieldInputTemplate,
|
||||
FieldInputTemplate,
|
||||
FieldType,
|
||||
FloatFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
LoRAModelFieldInputTemplate,
|
||||
MainModelFieldInputTemplate,
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
SchedulerFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
VAEModelFieldInputTemplate,
|
||||
isStatefulFieldType,
|
||||
} from '../types/field';
|
||||
import { InvocationFieldSchema } from '../types/openapi';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
|
||||
(arg: {
|
||||
schemaObject: InvocationFieldSchema;
|
||||
baseField: Omit<T, 'type'>;
|
||||
isCollection: boolean;
|
||||
isPolymorphic: boolean;
|
||||
}) => T;
|
||||
|
||||
const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
IntegerFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: IntegerFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'IntegerField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (
|
||||
schemaObject.exclusiveMaximum !== undefined &&
|
||||
isNumber(schemaObject.exclusiveMaximum)
|
||||
) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (
|
||||
schemaObject.exclusiveMinimum !== undefined &&
|
||||
isNumber(schemaObject.exclusiveMinimum)
|
||||
) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
FloatFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: FloatFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'FloatField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (
|
||||
schemaObject.exclusiveMaximum !== undefined &&
|
||||
isNumber(schemaObject.exclusiveMaximum)
|
||||
) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (
|
||||
schemaObject.exclusiveMinimum !== undefined &&
|
||||
isNumber(schemaObject.exclusiveMinimum)
|
||||
) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
StringFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: StringFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'StringField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
BooleanFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: BooleanFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'BooleanField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
MainModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'MainModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
SDXLMainModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: SDXLMainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'SDXLMainModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
SDXLRefinerModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: SDXLRefinerModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'SDXLRefinerModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
VAEModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: VAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'VAEModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
LoRAModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: LoRAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'LoRAModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
ControlNetModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: ControlNetModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'ControlNetModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
IPAdapterModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: IPAdapterModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'IPAdapterModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
T2IAdapterModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: T2IAdapterModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'T2IAdapterModelField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
BoardFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: BoardFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'BoardField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
ImageFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: ImageFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'ImageField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
EnumFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const options = schemaObject.enum ?? [];
|
||||
const template: EnumFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'EnumField', isCollection, isPolymorphic },
|
||||
options,
|
||||
ui_choice_labels: schemaObject.ui_choice_labels,
|
||||
default: schemaObject.default ?? options[0],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
ColorFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: ColorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'ColorField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
SchedulerFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => {
|
||||
const template: SchedulerFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: { name: 'SchedulerField', isCollection, isPolymorphic },
|
||||
default: schemaObject.default ?? 'euler',
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const TEMPLATE_BUILDER_MAP: Record<
|
||||
StatefulFieldType['name'],
|
||||
FieldInputTemplateBuilder
|
||||
> = {
|
||||
BoardField: buildBoardFieldInputTemplate,
|
||||
BooleanField: buildBooleanFieldInputTemplate,
|
||||
ColorField: buildColorFieldInputTemplate,
|
||||
ControlNetModelField: buildControlNetModelFieldInputTemplate,
|
||||
EnumField: buildEnumFieldInputTemplate,
|
||||
FloatField: buildFloatFieldInputTemplate,
|
||||
ImageField: buildImageFieldInputTemplate,
|
||||
IntegerField: buildIntegerFieldInputTemplate,
|
||||
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
|
||||
LoRAModelField: buildLoRAModelFieldInputTemplate,
|
||||
MainModelField: buildMainModelFieldInputTemplate,
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
};
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
fieldSchema: InvocationFieldSchema,
|
||||
name: string,
|
||||
fieldType: FieldType
|
||||
): FieldInputTemplate => {
|
||||
const {
|
||||
input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
ui_order,
|
||||
ui_choice_labels,
|
||||
orig_required: required,
|
||||
} = fieldSchema;
|
||||
|
||||
// This is the base field template that is common to all fields. The builder function will add all other
|
||||
// properties to this template.
|
||||
const baseField: Omit<FieldInputTemplate, 'type'> = {
|
||||
name,
|
||||
title: fieldSchema.title ?? (name ? startCase(name) : ''),
|
||||
required,
|
||||
description: fieldSchema.description ?? '',
|
||||
fieldKind: 'input' as const,
|
||||
input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
ui_order,
|
||||
ui_choice_labels,
|
||||
};
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
|
||||
return builder({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
isCollection: fieldType.isCollection,
|
||||
isPolymorphic: fieldType.isPolymorphic,
|
||||
});
|
||||
}
|
||||
|
||||
// This is a StatelessField, create it directly.
|
||||
const template: StatelessFieldInputTemplate = {
|
||||
...baseField,
|
||||
input: 'connection', // stateless --> connection only inputs
|
||||
type: fieldType,
|
||||
default: undefined, // stateless --> no default value
|
||||
};
|
||||
return template;
|
||||
};
|
@ -1,13 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { NodesState } from '../store/types';
|
||||
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
|
||||
import { WorkflowV2, zWorkflowEdge, zWorkflowNode } from '../types/workflow';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import i18n from 'i18next';
|
||||
|
||||
export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
export const buildWorkflow = (nodesState: NodesState): WorkflowV2 => {
|
||||
const { workflow: workflowMeta, nodes, edges } = nodesState;
|
||||
const workflow: Workflow = {
|
||||
const workflow: WorkflowV2 = {
|
||||
...workflowMeta,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,85 +0,0 @@
|
||||
import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
|
||||
const FIELD_VALUE_FALLBACK_MAP: {
|
||||
[key in FieldType]: InputFieldValue['value'];
|
||||
} = {
|
||||
Any: undefined,
|
||||
enum: '',
|
||||
BoardField: undefined,
|
||||
boolean: false,
|
||||
BooleanCollection: [],
|
||||
BooleanPolymorphic: false,
|
||||
ClipField: undefined,
|
||||
Collection: [],
|
||||
CollectionItem: undefined,
|
||||
ColorCollection: [],
|
||||
ColorField: undefined,
|
||||
ColorPolymorphic: undefined,
|
||||
ConditioningCollection: [],
|
||||
ConditioningField: undefined,
|
||||
ConditioningPolymorphic: undefined,
|
||||
ControlCollection: [],
|
||||
ControlField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
ControlPolymorphic: undefined,
|
||||
DenoiseMaskField: undefined,
|
||||
float: 0,
|
||||
FloatCollection: [],
|
||||
FloatPolymorphic: 0,
|
||||
ImageCollection: [],
|
||||
ImageField: undefined,
|
||||
ImagePolymorphic: undefined,
|
||||
integer: 0,
|
||||
IntegerCollection: [],
|
||||
IntegerPolymorphic: 0,
|
||||
IPAdapterCollection: [],
|
||||
IPAdapterField: undefined,
|
||||
IPAdapterModelField: undefined,
|
||||
IPAdapterPolymorphic: undefined,
|
||||
LatentsCollection: [],
|
||||
LatentsField: undefined,
|
||||
LatentsPolymorphic: undefined,
|
||||
MetadataItemField: undefined,
|
||||
MetadataItemCollection: [],
|
||||
MetadataItemPolymorphic: undefined,
|
||||
MetadataField: undefined,
|
||||
MetadataCollection: [],
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
ONNXModelField: undefined,
|
||||
Scheduler: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
string: '',
|
||||
StringCollection: [],
|
||||
StringPolymorphic: '',
|
||||
T2IAdapterCollection: [],
|
||||
T2IAdapterField: undefined,
|
||||
T2IAdapterModelField: undefined,
|
||||
T2IAdapterPolymorphic: undefined,
|
||||
UNetField: undefined,
|
||||
VaeField: undefined,
|
||||
VaeModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildInputFieldValue = (
|
||||
id: string,
|
||||
template: InputFieldTemplate
|
||||
): InputFieldValue => {
|
||||
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
|
||||
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
|
||||
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
|
||||
// `InputFieldValue` union, but TS doesn't seem to like it...
|
||||
const fieldValue = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input',
|
||||
} as InputFieldValue;
|
||||
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
|
||||
|
||||
return fieldValue;
|
||||
};
|
@ -1,8 +1,8 @@
|
||||
import { isNil } from 'lodash-es';
|
||||
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
|
||||
import { FieldInputTemplate, FieldOutputTemplate } from '../types/field';
|
||||
|
||||
export const getSortedFilteredFieldNames = (
|
||||
fields: InputFieldTemplate[] | OutputFieldTemplate[]
|
||||
fields: FieldInputTemplate[] | FieldOutputTemplate[]
|
||||
): string[] => {
|
||||
const visibleFields = fields.filter((field) => !field.ui_hidden);
|
||||
|
||||
|
@ -6,8 +6,8 @@ import {
|
||||
ControlField,
|
||||
ControlNetInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CONTROL_NET_COLLECT,
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
DenoiseLatentsInvocation,
|
||||
ESRGANInvocation,
|
||||
Edge,
|
||||
LatentsToImageInvocation,
|
||||
NoiseInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import {
|
||||
DENOISE_LATENTS,
|
||||
|
@ -6,8 +6,8 @@ import {
|
||||
CoreMetadataInvocation,
|
||||
IPAdapterInvocation,
|
||||
IPAdapterMetadataField,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
IP_ADAPTER_COLLECT,
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { LinearUIOutputInvocation } from 'services/api/types';
|
||||
import { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
LATENTS_TO_IMAGE,
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import {
|
||||
CoreMetadataInvocation,
|
||||
LoraLoaderInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageNSFWBlurInvocation,
|
||||
LatentsToImageInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
LoRAMetadataItem,
|
||||
NonNullableGraph,
|
||||
zLoRAMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
} from 'features/nodes/types/metadata';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { SDXLLoraLoaderInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
LORA_LOADER,
|
||||
|
@ -2,9 +2,9 @@ import { RootState } from 'app/store/store';
|
||||
import {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageDTO,
|
||||
NonNullableGraph,
|
||||
SeamlessModeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
INPAINT_IMAGE_RESIZE_UP,
|
||||
|
@ -1,7 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { SeamlessModeInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
@ -16,6 +14,7 @@ import {
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addSeamlessToLinearGraph = (
|
||||
state: RootState,
|
||||
|
@ -4,9 +4,10 @@ import { omit } from 'lodash-es';
|
||||
import {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
T2IAdapterField,
|
||||
T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
T2I_ADAPTER_COLLECT,
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
ImageNSFWBlurInvocation,
|
||||
ImageWatermarkInvocation,
|
||||
LatentsToImageInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { BoardId } from 'features/gallery/store/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
|
||||
import {
|
||||
ESRGANInvocation,
|
||||
Graph,
|
||||
LinearUIOutputInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { ESRGAN, LINEAR_UI_OUTPUT } from './constants';
|
||||
import { addCoreMetadataNode, upsertMetadata } from './metadata';
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ImageDTO, NonNullableGraph } from 'services/api/types';
|
||||
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
|
||||
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
|
||||
import { buildCanvasOutpaintGraph } from './buildCanvasOutpaintGraph';
|
||||
|
@ -1,12 +1,15 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageBlurInvocation,
|
||||
@ -8,12 +7,13 @@ import {
|
||||
ImageToLatentsInvocation,
|
||||
MaskEdgeInvocation,
|
||||
NoiseInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,18 +1,18 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
InfillPatchMatchInvocation,
|
||||
InfillTileInvocation,
|
||||
NoiseInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,14 +1,18 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -26,7 +30,6 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageBlurInvocation,
|
||||
@ -8,13 +7,14 @@ import {
|
||||
ImageToLatentsInvocation,
|
||||
MaskEdgeInvocation,
|
||||
NoiseInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,19 +1,19 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
InfillPatchMatchInvocation,
|
||||
InfillTileInvocation,
|
||||
NoiseInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,16 +1,16 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
DenoiseLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
ONNXTextToLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,15 +1,15 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
DenoiseLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
ONNXTextToLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,10 +1,9 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { generateSeeds } from 'common/util/generateSeeds';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { range } from 'lodash-es';
|
||||
import { components } from 'services/api/schema';
|
||||
import { Batch, BatchConfig } from 'services/api/types';
|
||||
import { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
METADATA,
|
||||
|
@ -1,15 +1,15 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,16 +1,16 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
|
@ -1,17 +1,16 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { NonNullableGraph } from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
NEGATIVE_CONDITIONING,
|
||||
@ -24,6 +23,7 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
export const buildLinearSDXLTextToImageGraph = (
|
||||
state: RootState
|
||||
|
@ -1,21 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
DenoiseLatentsInvocation,
|
||||
NonNullableGraph,
|
||||
ONNXTextToLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addHrfToGraph } from './addHrfToGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
DENOISE_LATENTS,
|
||||
@ -28,6 +27,7 @@ import {
|
||||
SEAMLESS,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
export const buildLinearTextToImageGraph = (
|
||||
state: RootState
|
||||
|
@ -1,16 +1,20 @@
|
||||
import { NodesState } from 'features/nodes/store/types';
|
||||
import { InputFieldValue, isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { buildWorkflow } from '../buildWorkflow';
|
||||
import {
|
||||
FieldInputInstance,
|
||||
isColorFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
|
||||
/**
|
||||
* We need to do special handling for some fields
|
||||
*/
|
||||
export const parseFieldValue = (field: InputFieldValue) => {
|
||||
if (field.type === 'ColorField') {
|
||||
export const parseFieldValue = (field: FieldInputInstance) => {
|
||||
if (isColorFieldInputInstance(field)) {
|
||||
if (field.value) {
|
||||
const clonedValue = cloneDeep(field.value);
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { CoreMetadataInvocation } from 'services/api/types';
|
||||
import { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { METADATA } from './constants';
|
||||
|
||||
|
233
invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts
Normal file
233
invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts
Normal file
@ -0,0 +1,233 @@
|
||||
import { t } from 'i18next';
|
||||
import { isArray } from 'lodash-es';
|
||||
import { OpenAPIV3_1 } from 'openapi-types';
|
||||
import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error';
|
||||
import { FieldType } from '../types/field';
|
||||
import {
|
||||
OpenAPIV3_1SchemaOrRef,
|
||||
isArraySchemaObject,
|
||||
isInvocationFieldSchema,
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
} from '../types/openapi';
|
||||
|
||||
/**
|
||||
* Transforms an invocation output ref object to field type.
|
||||
* @param ref The ref string to transform
|
||||
* @returns The field type.
|
||||
*
|
||||
* @example
|
||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||
*/
|
||||
export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) =>
|
||||
refObject.$ref.split('/').slice(-1)[0];
|
||||
|
||||
const OPENAPI_TO_FIELD_TYPE_MAP: Record<string, string> = {
|
||||
integer: 'IntegerField',
|
||||
number: 'FloatField',
|
||||
string: 'StringField',
|
||||
boolean: 'BooleanField',
|
||||
};
|
||||
|
||||
const isCollectionFieldType = (fieldType: string) => {
|
||||
/**
|
||||
* CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between
|
||||
* it and other `list[Any]` fields, due to its special internal handling.
|
||||
*
|
||||
* In pydantic, it gets an explicit field type of `CollectionField`.
|
||||
*/
|
||||
if (fieldType === 'CollectionField') {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
export const parseFieldType = (
|
||||
schemaObject: OpenAPIV3_1SchemaOrRef
|
||||
): FieldType => {
|
||||
if (isInvocationFieldSchema(schemaObject)) {
|
||||
// Check if this field has an explicit type provided by the node schema
|
||||
const { ui_type } = schemaObject;
|
||||
if (ui_type) {
|
||||
return {
|
||||
name: ui_type,
|
||||
isCollection: isCollectionFieldType(ui_type),
|
||||
isPolymorphic: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
if (isSchemaObject(schemaObject)) {
|
||||
if (!schemaObject.type) {
|
||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||
|
||||
if (schemaObject.allOf) {
|
||||
const allOf = schemaObject.allOf;
|
||||
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||
// This is a single ref type
|
||||
const name = refObjectToSchemaName(allOf[0]);
|
||||
if (!name) {
|
||||
throw new FieldTypeParseError(
|
||||
t('nodes.unableToExtractSchemaNameFromRef')
|
||||
);
|
||||
}
|
||||
return {
|
||||
name,
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
};
|
||||
}
|
||||
} else if (schemaObject.anyOf) {
|
||||
// ignore null types
|
||||
const filteredAnyOf = schemaObject.anyOf.filter((i) => {
|
||||
if (isSchemaObject(i)) {
|
||||
if (i.type === 'null') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (filteredAnyOf.length === 1) {
|
||||
// This is a single ref type
|
||||
if (isRefObject(filteredAnyOf[0])) {
|
||||
const name = refObjectToSchemaName(filteredAnyOf[0]);
|
||||
if (!name) {
|
||||
throw new FieldTypeParseError(
|
||||
t('nodes.unableToExtractSchemaNameFromRef')
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
name,
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
};
|
||||
} else if (isSchemaObject(filteredAnyOf[0])) {
|
||||
return parseFieldType(filteredAnyOf[0]);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
||||
* - an `anyOf` with two items
|
||||
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
||||
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
||||
*
|
||||
* Any other cases we ignore.
|
||||
*/
|
||||
|
||||
let firstType: string | undefined;
|
||||
let secondType: string | undefined;
|
||||
|
||||
if (isArraySchemaObject(filteredAnyOf[0])) {
|
||||
// first is array, second is not
|
||||
const first = filteredAnyOf[0].items;
|
||||
const second = filteredAnyOf[1];
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
} else if (isArraySchemaObject(filteredAnyOf[1])) {
|
||||
// first is not array, second is
|
||||
const first = filteredAnyOf[0];
|
||||
const second = filteredAnyOf[1].items;
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
}
|
||||
if (firstType && firstType === secondType) {
|
||||
return {
|
||||
name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType,
|
||||
isCollection: false,
|
||||
isPolymorphic: true, // <-- don't forget, polymorphic!
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if (schemaObject.enum) {
|
||||
return { name: 'EnumField', isCollection: false, isPolymorphic: false };
|
||||
} else if (schemaObject.type) {
|
||||
if (schemaObject.type === 'array') {
|
||||
// We need to get the type of the items
|
||||
if (isSchemaObject(schemaObject.items)) {
|
||||
const itemType = schemaObject.items.type;
|
||||
if (!itemType || isArray(itemType)) {
|
||||
throw new UnsupportedFieldTypeError(
|
||||
t('nodes.unsupportedArrayItemType', {
|
||||
type: itemType,
|
||||
})
|
||||
);
|
||||
}
|
||||
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
|
||||
const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType];
|
||||
if (!name) {
|
||||
// it's 'null', 'object', or 'array' - skip
|
||||
throw new UnsupportedFieldTypeError(
|
||||
t('nodes.unsupportedArrayItemType', {
|
||||
type: itemType,
|
||||
})
|
||||
);
|
||||
}
|
||||
return {
|
||||
name,
|
||||
isCollection: true, // <-- don't forget, collection!
|
||||
isPolymorphic: false,
|
||||
};
|
||||
}
|
||||
|
||||
// This is a ref object, extract the type name
|
||||
const name = refObjectToSchemaName(schemaObject.items);
|
||||
if (!name) {
|
||||
throw new FieldTypeParseError(
|
||||
t('nodes.unableToExtractSchemaNameFromRef')
|
||||
);
|
||||
}
|
||||
return {
|
||||
name,
|
||||
isCollection: true, // <-- don't forget, collection!
|
||||
isPolymorphic: false,
|
||||
};
|
||||
} else if (!isArray(schemaObject.type)) {
|
||||
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
|
||||
const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type];
|
||||
if (!name) {
|
||||
// it's 'null', 'object', or 'array' - skip
|
||||
throw new UnsupportedFieldTypeError(
|
||||
t('nodes.unsupportedArrayItemType', {
|
||||
type: schemaObject.type,
|
||||
})
|
||||
);
|
||||
}
|
||||
return {
|
||||
name,
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if (isRefObject(schemaObject)) {
|
||||
const name = refObjectToSchemaName(schemaObject);
|
||||
if (!name) {
|
||||
throw new FieldTypeParseError(
|
||||
t('nodes.unableToExtractSchemaNameFromRef')
|
||||
);
|
||||
}
|
||||
return {
|
||||
name,
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
};
|
||||
}
|
||||
throw new FieldTypeParseError(t('nodes.unableToParseFieldType'));
|
||||
};
|
@ -2,24 +2,24 @@ import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { reduce, startCase } from 'lodash-es';
|
||||
import { OpenAPIV3_1 } from 'openapi-types';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { FieldInputTemplate, FieldOutputTemplate } from '../types/field';
|
||||
import { InvocationTemplate } from '../types/invocation';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InvocationSchemaObject,
|
||||
InvocationTemplate,
|
||||
OutputFieldTemplate,
|
||||
isFieldType,
|
||||
isInvocationFieldSchema,
|
||||
isInvocationOutputSchemaObject,
|
||||
isInvocationSchemaObject,
|
||||
} from '../types/types';
|
||||
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
||||
} from '../types/openapi';
|
||||
import { buildFieldInputTemplate } from './buildFieldInputTemplate';
|
||||
import { parseFieldType } from './parseFieldType';
|
||||
import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error';
|
||||
import { t } from 'i18next';
|
||||
|
||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
|
||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
|
||||
|
||||
const invocationDenylist: AnyInvocationType[] = ['graph', 'linear_ui_output'];
|
||||
const invocationDenylist: string[] = ['graph', 'linear_ui_output'];
|
||||
|
||||
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||
@ -83,13 +83,13 @@ export const parseSchema = (
|
||||
const inputs = reduce(
|
||||
schema.properties,
|
||||
(
|
||||
inputsAccumulator: Record<string, InputFieldTemplate>,
|
||||
inputsAccumulator: Record<string, FieldInputTemplate>,
|
||||
property,
|
||||
propertyName
|
||||
) => {
|
||||
if (isReservedInputField(type, propertyName)) {
|
||||
logger('nodes').trace(
|
||||
{ node: type, fieldName: propertyName, field: parseify(property) },
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Skipped reserved input field'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
@ -97,79 +97,53 @@ export const parseSchema = (
|
||||
|
||||
if (!isInvocationFieldSchema(property)) {
|
||||
logger('nodes').warn(
|
||||
{ node: type, propertyName, property: parseify(property) },
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Unhandled input property'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldType = property.ui_type ?? getFieldType(property);
|
||||
try {
|
||||
const fieldType = parseFieldType(property);
|
||||
|
||||
if (!fieldType) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Missing input field type'
|
||||
if (fieldType.name === 'WorkflowField') {
|
||||
// This supports workflows, set the flag and skip to next field
|
||||
withWorkflow = true;
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (isReservedFieldType(fieldType.name)) {
|
||||
// Skip processing this reserved field
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldInputTemplate = buildFieldInputTemplate(
|
||||
property,
|
||||
propertyName,
|
||||
fieldType
|
||||
);
|
||||
return inputsAccumulator;
|
||||
|
||||
inputsAccumulator[propertyName] = fieldInputTemplate;
|
||||
} catch (e) {
|
||||
if (
|
||||
e instanceof FieldTypeParseError ||
|
||||
e instanceof UnsupportedFieldTypeError
|
||||
) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
t('nodes.inputFieldTypeParseError', {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: e.message,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (fieldType === 'WorkflowField') {
|
||||
withWorkflow = true;
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (isReservedFieldType(fieldType)) {
|
||||
logger('nodes').trace(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Skipping reserved input field type: ${fieldType}`
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Skipping unknown input field type: ${fieldType}`
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
const field = buildInputFieldTemplate(
|
||||
schema,
|
||||
property,
|
||||
propertyName,
|
||||
fieldType
|
||||
);
|
||||
|
||||
if (!field) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Skipping input field with no template'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
inputsAccumulator[propertyName] = field;
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{}
|
||||
@ -206,7 +180,7 @@ export const parseSchema = (
|
||||
(outputsAccumulator, property, propertyName) => {
|
||||
if (!isAllowedOutputField(type, propertyName)) {
|
||||
logger('nodes').trace(
|
||||
{ type, propertyName, property: parseify(property) },
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Skipped reserved output field'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
@ -214,37 +188,62 @@ export const parseSchema = (
|
||||
|
||||
if (!isInvocationFieldSchema(property)) {
|
||||
logger('nodes').warn(
|
||||
{ type, propertyName, property: parseify(property) },
|
||||
{ node: type, field: propertyName, schema: parseify(property) },
|
||||
'Unhandled output property'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldType = property.ui_type ?? getFieldType(property);
|
||||
try {
|
||||
const fieldType = parseFieldType(property);
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
logger('nodes').warn(
|
||||
{ fieldName: propertyName, fieldType, field: parseify(property) },
|
||||
'Skipping unknown output field type'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
if (!fieldType) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
'Missing output field type'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldOutputTemplate: FieldOutputTemplate = {
|
||||
fieldKind: 'output',
|
||||
name: propertyName,
|
||||
title:
|
||||
property.title ?? (propertyName ? startCase(propertyName) : ''),
|
||||
description: property.description ?? '',
|
||||
type: fieldType,
|
||||
ui_hidden: property.ui_hidden ?? false,
|
||||
ui_type: property.ui_type,
|
||||
ui_order: property.ui_order,
|
||||
};
|
||||
|
||||
outputsAccumulator[propertyName] = fieldOutputTemplate;
|
||||
} catch (e) {
|
||||
if (
|
||||
e instanceof FieldTypeParseError ||
|
||||
e instanceof UnsupportedFieldTypeError
|
||||
) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
schema: parseify(property),
|
||||
},
|
||||
t('nodes.outputFieldTypeParseError', {
|
||||
node: type,
|
||||
field: propertyName,
|
||||
message: e.message,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
outputsAccumulator[propertyName] = {
|
||||
fieldKind: 'output',
|
||||
name: propertyName,
|
||||
title:
|
||||
property.title ?? (propertyName ? startCase(propertyName) : ''),
|
||||
description: property.description ?? '',
|
||||
type: fieldType,
|
||||
ui_hidden: property.ui_hidden ?? false,
|
||||
ui_type: property.ui_type,
|
||||
ui_order: property.ui_order,
|
||||
};
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldTemplate>
|
||||
{} as Record<string, FieldOutputTemplate>
|
||||
);
|
||||
|
||||
const useCache = schema.properties.use_cache.default;
|
||||
|
@ -1,123 +1,159 @@
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { cloneDeep, keyBy } from 'lodash-es';
|
||||
import {
|
||||
InvocationTemplate,
|
||||
Workflow,
|
||||
WorkflowWarning,
|
||||
isWorkflowInvocationNode,
|
||||
} from '../types/types';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import i18n from 'i18next';
|
||||
import { t } from 'i18next';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { getNeedsUpdate } from '../store/util/nodeUpdate';
|
||||
import { InvocationTemplate } from '../types/invocation';
|
||||
import { parseAndMigrateWorkflow } from '../types/migration/migrations';
|
||||
import { WorkflowV2, isWorkflowInvocationNode } from '../types/workflow';
|
||||
|
||||
type WorkflowWarning = {
|
||||
message: string;
|
||||
issues?: string[];
|
||||
data: JsonObject;
|
||||
};
|
||||
|
||||
type ValidateWorkflowResult = {
|
||||
workflow: WorkflowV2;
|
||||
warnings: WorkflowWarning[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses and validates a workflow:
|
||||
* - Parses the workflow schema, and migrates it to the latest version if necessary.
|
||||
* - Validates the workflow against the node templates, warning if the template is not known.
|
||||
* - Attempts to update nodes which have a mismatched version.
|
||||
* - Removes edges which are invalid.
|
||||
* @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow))
|
||||
* @param invocationTemplates The node templates to validate against.
|
||||
* @throws {WorkflowVersionError} If the workflow version is not recognized.
|
||||
* @throws {z.ZodError} If there is a validation error.
|
||||
*/
|
||||
export const validateWorkflow = (
|
||||
workflow: Workflow,
|
||||
nodeTemplates: Record<string, InvocationTemplate>
|
||||
) => {
|
||||
const clone = cloneDeep(workflow);
|
||||
const { nodes, edges } = clone;
|
||||
const errors: WorkflowWarning[] = [];
|
||||
workflow: unknown,
|
||||
invocationTemplates: Record<string, InvocationTemplate>
|
||||
): ValidateWorkflowResult => {
|
||||
// Parse the raw workflow data & migrate it to the latest version
|
||||
const _workflow = parseAndMigrateWorkflow(workflow);
|
||||
|
||||
// Now we can validate the graph
|
||||
const { nodes, edges } = _workflow;
|
||||
const warnings: WorkflowWarning[] = [];
|
||||
|
||||
// We don't need to validate Note nodes or CurrentImage nodes - only Invocation nodes
|
||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||
nodes.forEach((node) => {
|
||||
if (!isWorkflowInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = nodeTemplates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
errors.push({
|
||||
message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t(
|
||||
'nodes.skipped'
|
||||
)}`,
|
||||
issues: [
|
||||
`${i18n.t('nodes.nodeType')}"${node.data.type}" ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`,
|
||||
],
|
||||
data: node,
|
||||
invocationNodes.forEach((node) => {
|
||||
const template = invocationTemplates[node.data.type];
|
||||
if (!template) {
|
||||
// This node's type template does not exist
|
||||
const message = t('nodes.missingTemplate', {
|
||||
node: node.id,
|
||||
type: node.data.type,
|
||||
});
|
||||
warnings.push({
|
||||
message,
|
||||
data: parseify(node),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
nodeTemplate.version &&
|
||||
node.data.version &&
|
||||
compareVersions(nodeTemplate.version, node.data.version) !== 0
|
||||
) {
|
||||
errors.push({
|
||||
message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t(
|
||||
'nodes.mismatchedVersion'
|
||||
)}`,
|
||||
issues: [
|
||||
`${i18n.t('nodes.node')} "${node.data.type}" v${
|
||||
node.data.version
|
||||
} ${i18n.t('nodes.maybeIncompatible')} v${nodeTemplate.version}`,
|
||||
],
|
||||
data: { node, nodeTemplate: parseify(nodeTemplate) },
|
||||
if (getNeedsUpdate(node, template)) {
|
||||
// This node needs to be updated, based on comparison of its version to the template version
|
||||
const message = t('nodes.mismatchedVersion', {
|
||||
node: node.id,
|
||||
type: node.data.type,
|
||||
});
|
||||
warnings.push({
|
||||
message,
|
||||
data: parseify({ node, nodeTemplate: template }),
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
edges.forEach((edge, i) => {
|
||||
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
|
||||
const sourceNode = keyedNodes[edge.source];
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
const issues: string[] = [];
|
||||
|
||||
if (!sourceNode) {
|
||||
// The edge's source/output node does not exist
|
||||
issues.push(
|
||||
`${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`
|
||||
t('nodes.sourceNodeDoesNotExist', {
|
||||
node: edge.source,
|
||||
})
|
||||
);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||
) {
|
||||
// The edge's source/output node field does not exist
|
||||
issues.push(
|
||||
`${i18n.t('nodes.outputNode')} "${edge.source}.${
|
||||
edge.sourceHandle
|
||||
}" ${i18n.t('nodes.doesNotExist')}`
|
||||
t('nodes.sourceNodeFieldDoesNotExist', {
|
||||
node: edge.source,
|
||||
field: edge.sourceHandle,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!targetNode) {
|
||||
// The edge's target/input node does not exist
|
||||
issues.push(
|
||||
`${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`
|
||||
t('nodes.targetNodeDoesNotExist', {
|
||||
node: edge.target,
|
||||
})
|
||||
);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.targetHandle in targetNode.data.inputs)
|
||||
) {
|
||||
// The edge's target/input node field does not exist
|
||||
issues.push(
|
||||
`${i18n.t('nodes.inputField')} "${edge.target}.${
|
||||
edge.targetHandle
|
||||
}" ${i18n.t('nodes.doesNotExist')}`
|
||||
t('nodes.targetNodeFieldDoesNotExist', {
|
||||
node: edge.target,
|
||||
field: edge.targetHandle,
|
||||
})
|
||||
);
|
||||
}
|
||||
if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||
|
||||
if (!sourceNode?.data.type || !invocationTemplates[sourceNode.data.type]) {
|
||||
// The edge's source/output node template does not exist
|
||||
issues.push(
|
||||
`${i18n.t('nodes.sourceNode')} "${edge.source}" ${i18n.t(
|
||||
'nodes.missingTemplate'
|
||||
)} "${sourceNode?.data.type}"`
|
||||
t('nodes.missingTemplate', {
|
||||
node: edge.source,
|
||||
type: sourceNode?.data.type,
|
||||
})
|
||||
);
|
||||
}
|
||||
if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||
if (!targetNode?.data.type || !invocationTemplates[targetNode?.data.type]) {
|
||||
// The edge's target/input node template does not exist
|
||||
issues.push(
|
||||
`${i18n.t('nodes.sourceNode')}"${edge.target}" ${i18n.t(
|
||||
'nodes.missingTemplate'
|
||||
)} "${targetNode?.data.type}"`
|
||||
t('nodes.missingTemplate', {
|
||||
node: edge.target,
|
||||
type: targetNode?.data.type,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (issues.length) {
|
||||
// This edge has some issues. Remove it.
|
||||
delete edges[i];
|
||||
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
|
||||
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
|
||||
errors.push({
|
||||
message: `Edge "${src} -> ${tgt}" skipped`,
|
||||
const source =
|
||||
edge.type === 'default'
|
||||
? `${edge.source}.${edge.sourceHandle}`
|
||||
: edge.source;
|
||||
const target =
|
||||
edge.type === 'default'
|
||||
? `${edge.source}.${edge.targetHandle}`
|
||||
: edge.target;
|
||||
warnings.push({
|
||||
message: t('nodes.deletedInvalidEdge', { source, target }),
|
||||
issues,
|
||||
data: edge,
|
||||
});
|
||||
}
|
||||
});
|
||||
return { workflow: clone, errors };
|
||||
return { workflow: _workflow, warnings };
|
||||
};
|
||||
|
Reference in New Issue
Block a user