feat(ui): add support for custom field types

Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported.

Two notes:
1. Your field type's class name must be unique.

Suggest prefixing fields with something related to the node pack as a kind of namespace.

2. Custom field types function as connection-only fields.

For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type.

This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection.

feat(ui): fix tooltips for custom types

We need to hold onto the original type of the field so they don't all just show up as "Unknown".

fix(ui): fix ts error with custom fields

feat(ui): custom field types connection validation

In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.

*Actually, it was `"Unknown"`, but I changed it to custom for clarity.

Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.

To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.

This ended up needing a bit of fanagling:

- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.

While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.

(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)

- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.

- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.

Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.

This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.

fix(ui): typo

feat(ui): add CustomCollection and CustomPolymorphic field types

feat(ui): add validation for CustomCollection & CustomPolymorphic types

- Update connection validation for custom types
- Use simple string parsing to determine if a field is a collection or polymorphic type.
- No longer need to keep a list of collection and polymorphic types.
- Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing

chore(ui): remove errant console.log

fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType'

This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type.

fix(ui): fix ts error

feat(nodes): add runtime check for custom field names

"Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names.

chore(ui): add TODO for revising field type names

wip refactor fieldtype structured

wip refactor field types

wip refactor types

wip refactor types

fix node layout

refactor field types

chore: mypy

organisation

organisation

organisation

fix(nodes): fix field orig_required, field_kind and input statuses

feat(nodes): remove broken implementation of default_factory on InputField

Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args.

Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used.

Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`.

fix(nodes): fix InputField name validation

workflow validation

validation

chore: ruff

feat(nodes): fix up baseinvocation comments

fix(ui): improve typing & logic of buildFieldInputTemplate

improved error handling in parseFieldType

fix: back compat for deprecated default_factory and UIType

feat(nodes): do not show node packs loaded log if none loaded

chore(ui): typegen
This commit is contained in:
psychedelicious
2023-11-17 11:32:35 +11:00
parent 0d52430481
commit 86a74e929a
186 changed files with 5713 additions and 5704 deletions

View File

@ -805,6 +805,8 @@
"clipField": "Clip",
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",
"polymorphicFieldType": "{{name}} Polymorphic",
"collectionDescription": "TODO",
"collectionItem": "Collection Item",
"collectionItemDescription": "TODO",
@ -891,10 +893,15 @@
"mainModelField": "Model",
"mainModelFieldDescription": "TODO",
"maybeIncompatible": "May be Incompatible With Installed",
"mismatchedVersion": "Has Mismatched Version",
"mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)",
"missingCanvaInitImage": "Missing canvas init image",
"missingCanvaInitMaskImages": "Missing canvas init and mask images",
"missingTemplate": "Missing Template",
"missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)",
"sourceNodeDoesNotExist": "Invalid edge: source/output node {{node}} does not exist",
"targetNodeDoesNotExist": "Invalid edge: target/input node {{node}} does not exist",
"sourceNodeFieldDoesNotExist": "Invalid edge: source/output field {{node}}.{{field}} does not exist",
"targetNodeFieldDoesNotExist": "Invalid edge: target/input field {{node}}.{{field}} does not exist",
"deletedInvalidEdge": "Deleted invalid edge {{source}} -> {{target}}",
"noConnectionData": "No connection data",
"noConnectionInProgress": "No connection in progress",
"node": "Node",
@ -954,10 +961,17 @@
"stringDescription": "Strings are text.",
"stringPolymorphic": "String Polymorphic",
"stringPolymorphicDescription": "A collection of strings.",
"unableToLoadWorkflow": "Unable to Validate Workflow",
"unableToLoadWorkflow": "Unable to Load Workflow",
"unableToParseEdge": "Unable to parse edge",
"unableToParseNode": "Unable to parse node",
"unableToUpdateNode": "Unable to update node",
"unableToValidateWorkflow": "Unable to Validate Workflow",
"unknownErrorValidatingWorkflow": "Unknown error validating workflow",
"inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})",
"unableToExtractSchemaNameFromRef": "unable to extract schema name from ref",
"unsupportedArrayItemType": "unsupported array item type \"{{type}}\"",
"unableToParseFieldType": "unable to parse field type",
"uNetField": "UNet",
"uNetFieldDescription": "UNet submodel.",
"unhandledInputProperty": "Unhandled input property",
@ -971,8 +985,9 @@
"unkownInvocation": "Unknown Invocation type",
"unknownOutput": "Unknown output",
"updateNode": "Update Node",
"updateAllNodes": "Update All Nodes",
"updateApp": "Update App",
"updateAllNodes": "Update All Nodes",
"allNodesUpdated": "All Nodes Updated",
"unableToUpdateNodes_one": "Unable to update {{count}} node",
"unableToUpdateNodes_other": "Unable to update {{count}} nodes",
"vaeField": "Vae",
@ -981,6 +996,8 @@
"vaeModelFieldDescription": "TODO",
"validateConnections": "Validate Connections and Graph",
"validateConnectionsHelp": "Prevent invalid connections from being made, and invalid graphs from being invoked",
"unableToGetWorkflowVersion": "Unable to get workflow schema version",
"unrecognizedWorkflowVersion": "Unrecognized workflow schema version {{version}}",
"version": "Version",
"versionUnknown": " Version Unknown",
"workflow": "Workflow",

View File

@ -71,7 +71,7 @@ import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } f
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addWorkflowLoadRequestedListener } from './listeners/workflowLoadRequested';
import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested';
export const listenerMiddleware = createListenerMiddleware();
@ -178,7 +178,7 @@ addBoardIdSelectedListener();
addReceivedOpenAPISchemaListener();
// Workflows
addWorkflowLoadedListener();
addWorkflowLoadRequestedListener();
addUpdateAllNodesRequestedListener();
// DND

View File

@ -12,10 +12,10 @@ import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import { isImageOutput } from 'services/api/guards';
import { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
import { isImageOutput } from 'features/nodes/types/common';
export const addControlNetImageProcessedListener = () => {
startAppListening({

View File

@ -5,19 +5,20 @@ import {
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/types';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addRequestedSingleImageDeletionListener = () => {
startAppListening({
@ -121,7 +122,7 @@ export const addRequestedSingleImageDeletionListener = () => {
forEach(node.data.inputs, (input) => {
if (
input.type === 'ImageField' &&
isImageFieldInputInstance(input) &&
input.value?.image_name === imageDTO.image_name
) {
dispatch(
@ -241,7 +242,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
forEach(node.data.inputs, (input) => {
if (
input.type === 'ImageField' &&
isImageFieldInputInstance(input) &&
input.value?.image_name === imageDTO.image_name
) {
dispatch(

View File

@ -12,12 +12,12 @@ import {
setWidth,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
export const addModelSelectedListener = () => {
startAppListening({
@ -26,7 +26,7 @@ export const addModelSelectedListener = () => {
const log = logger('models');
const state = getState();
const result = zMainOrOnnxModel.safeParse(action.payload);
const result = zParameterModel.safeParse(action.payload);
if (!result.success) {
log.error(

View File

@ -11,9 +11,9 @@ import {
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
zMainOrOnnxModel,
zSDXLRefinerModel,
zVaeModel,
zParameterModel,
zParameterSDXLRefinerModel,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
@ -67,7 +67,7 @@ export const addModelsLoadedListener = () => {
return;
}
const result = zMainOrOnnxModel.safeParse(models[0]);
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error(
@ -119,7 +119,7 @@ export const addModelsLoadedListener = () => {
return;
}
const result = zSDXLRefinerModel.safeParse(models[0]);
const result = zParameterSDXLRefinerModel.safeParse(models[0]);
if (!result.success) {
log.error(
@ -170,7 +170,7 @@ export const addModelsLoadedListener = () => {
return;
}
const result = zVaeModel.safeParse(firstModel);
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error(

View File

@ -15,6 +15,7 @@ export const addReceivedOpenAPISchemaListener = () => {
log.debug({ schemaJSON }, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(
schemaJSON,
nodesAllowlist,

View File

@ -13,13 +13,13 @@ import {
} from 'features/nodes/util/graphBuilders/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { imagesAdapter } from 'services/api/util';
import {
appSocketInvocationComplete,
socketInvocationComplete,
} from 'services/events/actions';
import { startAppListening } from '../..';
import { isImageOutput } from 'features/nodes/types/common';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image'];

View File

@ -1,14 +1,16 @@
import { logger } from 'app/logging/logger';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import {
getNeedsUpdate,
updateNode,
} from 'features/nodes/hooks/useNodeVersion';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import { startAppListening } from '..';
import { logger } from 'app/logging/logger';
} from 'features/nodes/store/util/nodeUpdate';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { startAppListening } from '..';
export const addUpdateAllNodesRequestedListener = () => {
startAppListening({
@ -20,22 +22,31 @@ export const addUpdateAllNodesRequestedListener = () => {
let unableToUpdateCount = 0;
nodes.forEach((node) => {
nodes.filter(isInvocationNode).forEach((node) => {
const template = templates[node.data.type];
const needsUpdate = getNeedsUpdate(node, template);
const updatedNode = updateNode(node, template);
if (!updatedNode) {
if (needsUpdate) {
unableToUpdateCount++;
}
if (!template) {
unableToUpdateCount++;
return;
}
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
if (!getNeedsUpdate(node, template)) {
// No need to increment the count here, since we're not actually updating
return;
}
try {
const updatedNode = updateNode(node, template);
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
} catch (e) {
if (e instanceof NodeUpdateError) {
unableToUpdateCount++;
}
}
});
if (unableToUpdateCount) {
log.warn(
`Unable to update ${unableToUpdateCount} nodes. Please report this issue.`
t('nodes.unableToUpdateNodes', {
count: unableToUpdateCount,
})
);
dispatch(
addToast(
@ -46,6 +57,15 @@ export const addUpdateAllNodesRequestedListener = () => {
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: t('nodes.allNodesUpdated'),
status: 'success',
})
)
);
}
},
});

View File

@ -0,0 +1,105 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { WorkflowVersionError } from 'features/nodes/types/error';
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { t } from 'i18next';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
import { startAppListening } from '..';
export const addWorkflowLoadRequestedListener = () => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(
workflow,
nodeTemplates
);
dispatch(workflowLoaded(validatedWorkflow));
if (!warnings.length) {
dispatch(
addToast(
makeToast({
title: t('toast.workflowLoaded'),
status: 'success',
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: t('toast.loadedWithWarnings'),
status: 'warning',
})
)
);
warnings.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
}
dispatch(setActiveTab('nodes'));
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
} catch (e) {
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions
log.error({ error: parseify(e) }, e.message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
})
)
);
} else if (e instanceof z.ZodError) {
// There was a problem validating the workflow itself
const { message } = fromZodError(e, {
prefix: t('nodes.workflowValidation'),
});
log.error({ error: parseify(e) }, message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: message,
})
)
);
} else {
// Some other error occurred
console.log(e);
log.error(
{ error: parseify(e) },
t('nodes.unknownErrorValidatingWorkflow')
);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: t('nodes.unknownErrorValidatingWorkflow'),
})
)
);
}
}
},
});
};

View File

@ -1,56 +0,0 @@
import { logger } from 'app/logging/logger';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { startAppListening } from '..';
import { t } from 'i18next';
export const addWorkflowLoadedListener = () => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
const { workflow: validatedWorkflow, errors } = validateWorkflow(
workflow,
nodeTemplates
);
dispatch(workflowLoaded(validatedWorkflow));
if (!errors.length) {
dispatch(
addToast(
makeToast({
title: t('toast.workflowLoaded'),
status: 'success',
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: t('toast.loadedWithWarnings'),
status: 'warning',
})
)
);
errors.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
}
dispatch(setActiveTab('nodes'));
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
},
});
};

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach } from 'lodash-es';

View File

@ -6,9 +6,9 @@ import {
isAnyOf,
} from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { appSocketInvocationError } from 'services/events/actions';
@ -243,9 +243,9 @@ export const controlAdaptersSlice = createSlice({
action: PayloadAction<{
id: string;
model:
| ControlNetModelParam
| T2IAdapterModelParam
| IPAdapterModelParam;
| ParameterControlNetModel
| ParameterT2IAdapterModel
| ParameterIPAdapterModel;
}>
) => {
const { id, model } = action.payload;

View File

@ -1,8 +1,8 @@
import { EntityState } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { isObject } from 'lodash-es';
import { components } from 'services/api/schema';
@ -378,7 +378,7 @@ export type ControlNetConfig = {
type: 'controlnet';
id: string;
isEnabled: boolean;
model: ControlNetModelParam | null;
model: ParameterControlNetModel | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -395,7 +395,7 @@ export type T2IAdapterConfig = {
type: 't2i_adapter';
id: string;
isEnabled: boolean;
model: T2IAdapterModelParam | null;
model: ParameterT2IAdapterModel | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -412,7 +412,7 @@ export type IPAdapterConfig = {
id: string;
isEnabled: boolean;
controlImage: string | null;
model: IPAdapterModelParam | null;
model: ParameterIPAdapterModel | null;
weight: number;
beginStepPct: number;
endStepPct: number;

View File

@ -1,11 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlAdapters } = state;
@ -19,7 +20,8 @@ export const getImageUsage = (state: RootState, image_name: string) => {
return some(
node.data.inputs,
(input) =>
input.type === 'ImageField' && input.value?.image_name === image_name
isImageFieldInputInstance(input) &&
input.value?.image_name === image_name
);
});

View File

@ -11,9 +11,9 @@ import {
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
FieldInputTemplate,
FieldInputInstance,
} from 'features/nodes/types/field';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@ -93,8 +93,8 @@ export type NodeFieldDraggableData = BaseDragData & {
payloadType: 'NODE_FIELD';
payload: {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
field: FieldInputInstance;
fieldTemplate: FieldInputTemplate;
};
};

View File

@ -4,14 +4,14 @@ import {
LoRAMetadataItem,
IPAdapterMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/types';
} from 'features/nodes/types/metadata';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useMemo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
isValidControlNetModel,
isValidLoRAModel,
isValidT2IAdapterModel,
isParameterControlNetModel,
isParameterLoRAModel,
isParameterT2IAdapterModel,
} from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
@ -132,7 +132,7 @@ const ImageMetadataActions = (props: Props) => {
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
isValidControlNetModel(controlnet.control_model)
isParameterControlNetModel(controlnet.control_model)
)
: [];
}, [metadata?.controlnets]);
@ -140,7 +140,7 @@ const ImageMetadataActions = (props: Props) => {
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters
? metadata.ipAdapters.filter((ipAdapter) =>
isValidControlNetModel(ipAdapter.ip_adapter_model)
isParameterControlNetModel(ipAdapter.ip_adapter_model)
)
: [];
}, [metadata?.ipAdapters]);
@ -148,7 +148,7 @@ const ImageMetadataActions = (props: Props) => {
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters
? metadata.t2iAdapters.filter((t2iAdapter) =>
isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model)
isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)
)
: [];
}, [metadata?.t2iAdapters]);
@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
return null;
}
console.log(metadata);
return (
<>
{metadata.created_by && (
@ -275,7 +273,7 @@ const ImageMetadataActions = (props: Props) => {
)}
{metadata.loras &&
metadata.loras.map((lora, index) => {
if (isValidLoRAModel(lora.lora)) {
if (isParameterLoRAModel(lora.lora)) {
return (
<ImageMetadataItem
key={index}

View File

@ -1,8 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
import { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type LoRA = LoRAModelParam & {
export type LoRA = ParameterLoRAModel & {
id: string;
weight: number;
};

View File

@ -24,7 +24,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types';
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
type NodeTemplate = {
@ -57,7 +56,7 @@ const AddNodePopover = () => {
const { t } = useTranslation();
const fieldFilter = useAppSelector(
(state) => state.nodes.currentConnectionFieldType
(state) => state.nodes.connectionStartFieldType
);
const handleFilter = useAppSelector(
(state) => state.nodes.connectionStartParams?.handleType
@ -111,7 +110,7 @@ const AddNodePopover = () => {
data.sort((a, b) => a.label.localeCompare(b.label));
return { data, t };
return { data };
},
defaultSelectorOptions
);
@ -121,7 +120,7 @@ const AddNodePopover = () => {
const inputRef = useRef<HTMLInputElement>(null);
const addNode = useCallback(
(nodeType: AnyInvocationType) => {
(nodeType: string) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const errorMessage = t('nodes.unknownNode', {
@ -145,7 +144,7 @@ const AddNodePopover = () => {
return;
}
addNode(v as AnyInvocationType);
addNode(v);
},
[addNode]
);

View File

@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { memo } from 'react';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { getFieldColor } from '../edges/util/getEdgeColor';
const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
nodes;
const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
: colorTokenToCssVar('base.500');
const stroke = shouldColorEdges
? getFieldColor(connectionStartFieldType)
: colorTokenToCssVar('base.500');
let className = 'react-flow__custom_connection-path';

View File

@ -0,0 +1,12 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELD_COLORS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/field';
export const getFieldColor = (fieldType: FieldType | null): string => {
if (!fieldType) {
return colorTokenToCssVar('base.500');
}
const color = FIELD_COLORS[fieldType.name];
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
export const makeEdgeSelector = (
source: string,
@ -29,7 +29,7 @@ export const makeEdgeSelector = (
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
? getFieldColor(sourceType)
: colorTokenToCssVar('base.500');
return {

View File

@ -1,7 +1,7 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, Position } from 'reactflow';

View File

@ -2,8 +2,8 @@ import { Flex, Icon, Text, Tooltip } from '@chakra-ui/react';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaInfoCircle } from 'react-icons/fa';
@ -13,7 +13,7 @@ interface Props {
}
const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
const { needsUpdate } = useNodeVersion(nodeId);
const needsUpdate = useNodeNeedsUpdate(nodeId);
return (
<Tooltip

View File

@ -11,7 +11,10 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
import {
NodeExecutionState,
zNodeStatus,
} from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
@ -74,10 +77,10 @@ type TooltipLabelProps = {
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState;
const { t } = useTranslation();
if (status === NodeStatus.PENDING) {
if (status === zNodeStatus.enum.PENDING) {
return <Text>{t('queue.pending')}</Text>;
}
if (status === NodeStatus.IN_PROGRESS) {
if (status === zNodeStatus.enum.IN_PROGRESS) {
if (progressImage) {
return (
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
@ -108,11 +111,11 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
return <Text>{t('nodes.executionStateInProgress')}</Text>;
}
if (status === NodeStatus.COMPLETED) {
if (status === zNodeStatus.enum.COMPLETED) {
return <Text>{t('nodes.executionStateCompleted')}</Text>;
}
if (status === NodeStatus.FAILED) {
if (status === zNodeStatus.enum.FAILED) {
return <Text>{t('nodes.executionStateError')}</Text>;
}
@ -127,7 +130,7 @@ type StatusIconProps = {
const StatusIcon = memo((props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) {
if (status === zNodeStatus.enum.PENDING) {
return (
<Icon
as={FaEllipsisH}
@ -139,7 +142,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
/>
);
}
if (status === NodeStatus.IN_PROGRESS) {
if (status === zNodeStatus.enum.IN_PROGRESS) {
return progress === null ? (
<CircularProgress
isIndeterminate
@ -158,7 +161,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
/>
);
}
if (status === NodeStatus.COMPLETED) {
if (status === zNodeStatus.enum.COMPLETED) {
return (
<Icon
as={FaCheck}
@ -170,7 +173,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
/>
);
}
if (status === NodeStatus.FAILED) {
if (status === zNodeStatus.enum.FAILED) {
return (
<Icon
as={FaExclamation}

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { InvocationNodeData } from 'features/nodes/types/types';
import { InvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import InvocationNode from '../Invocation/InvocationNode';

View File

@ -3,7 +3,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

View File

@ -56,7 +56,7 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
);
const mayExpose = useMemo(
() => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'),
() => input && ['any', 'direct'].includes(input),
[input]
);

View File

@ -1,18 +1,17 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import {
COLLECTION_TYPES,
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import {
InputFieldTemplate,
OutputFieldTemplate,
} from 'features/nodes/types/types';
FieldInputTemplate,
FieldOutputTemplate,
} from 'features/nodes/types/field';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { getFieldColor } from '../../../edges/util/getEdgeColor';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
@ -32,11 +31,11 @@ export const outputHandleStyles: CSSProperties = {
};
type FieldHandleProps = {
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
connectionError: string | null;
connectionError?: string;
};
const FieldHandle = (props: FieldHandleProps) => {
@ -47,23 +46,21 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionStartField,
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color: typeColor, title } = FIELDS[type];
const { name } = fieldTemplate;
const type = fieldTemplate.type;
const fieldTypeName = useFieldTypeName(type);
const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const isModelType = MODEL_TYPES.some((t) => t === type.name);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
type.isCollection || type.isPolymorphic
? colorTokenToCssVar('base.900')
: color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderWidth: type.isCollection || type.isPolymorphic ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
@ -97,18 +94,14 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return title;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? title;
return connectionError;
}
return title;
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
return fieldTypeName;
}, [connectionError, fieldTypeName, isConnectionInProgress]);
return (
<Tooltip

View File

@ -1,15 +1,14 @@
import { Flex, Text } from '@chakra-ui/react';
import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { FIELDS } from 'features/nodes/types/constants';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import {
isInputFieldTemplate,
isInputFieldValue,
} from 'features/nodes/types/types';
isFieldInputInstance,
isFieldInputTemplate,
} from 'features/nodes/types/field';
import { startCase } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
interface Props {
nodeId: string;
fieldName: string;
@ -17,12 +16,13 @@ interface Props {
}
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
const field = useFieldData(nodeId, fieldName);
const field = useFieldInstance(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
const { t } = useTranslation();
const fieldTitle = useMemo(() => {
if (isInputFieldValue(field)) {
if (isFieldInputInstance(field)) {
if (field.label && fieldTemplate?.title) {
return `${field.label} (${fieldTemplate.title})`;
}
@ -49,9 +49,9 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
{fieldTemplate.description}
</Text>
)}
{fieldTemplate && (
{fieldTypeName && (
<Text>
{t('parameters.type')}: {FIELDS[fieldTemplate.type].title}
{t('parameters.type')}: {fieldTypeName}
</Text>
)}
{isInputTemplate && (

View File

@ -77,10 +77,10 @@ const InputField = ({ nodeId, fieldName }: Props) => {
sx={{
display: 'flex',
alignItems: 'center',
h: 'full',
mb: 0,
px: 1,
gap: 2,
h: 'full',
}}
>
<EditableFieldTitle

View File

@ -1,24 +1,60 @@
import { Box, Text } from '@chakra-ui/react';
import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import {
isBoardFieldInputInstance,
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo } from 'react';
import BooleanInputField from './inputs/BooleanInputField';
import ColorInputField from './inputs/ColorInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageInputField from './inputs/ImageInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField';
import NumberInputField from './inputs/NumberInputField';
import RefinerModelInputField from './inputs/RefinerModelInputField';
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
import BoardInputField from './inputs/BoardInputField';
import { useTranslation } from 'react-i18next';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
@ -27,220 +63,227 @@ type InputFieldProps = {
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const { t } = useTranslation();
const field = useFieldData(nodeId, fieldName);
const fieldInstance = useFieldInstance(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
if (fieldTemplate?.fieldKind === 'output') {
return (
<Box p={2}>
{t('nodes.outputFieldInInput')}: {field?.type}
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
</Box>
);
}
if (
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
(field?.type === 'StringPolymorphic' &&
fieldTemplate?.type === 'StringPolymorphic')
isStringFieldInputInstance(fieldInstance) &&
isStringFieldInputTemplate(fieldTemplate)
) {
return (
<StringInputField
<StringFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
(field?.type === 'BooleanPolymorphic' &&
fieldTemplate?.type === 'BooleanPolymorphic')
isBooleanFieldInputInstance(fieldInstance) &&
isBooleanFieldInputTemplate(fieldTemplate)
) {
return (
<BooleanInputField
<BooleanFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(field?.type === 'float' && fieldTemplate?.type === 'float') ||
(field?.type === 'FloatPolymorphic' &&
fieldTemplate?.type === 'FloatPolymorphic') ||
(field?.type === 'IntegerPolymorphic' &&
fieldTemplate?.type === 'IntegerPolymorphic')
(isIntegerFieldInputInstance(fieldInstance) &&
isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) &&
isFloatFieldInputTemplate(fieldTemplate))
) {
return (
<NumberInputField
<NumberFieldInputComponent
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
return (
<EnumInputField
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
(field?.type === 'ImagePolymorphic' &&
fieldTemplate?.type === 'ImagePolymorphic')
isEnumFieldInputInstance(fieldInstance) &&
isEnumFieldInputTemplate(fieldTemplate)
) {
return (
<ImageInputField
<EnumFieldInputComponent
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'BoardField' && fieldTemplate?.type === 'BoardField') {
return (
<BoardInputField
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
isImageFieldInputInstance(fieldInstance) &&
isImageFieldInputTemplate(fieldTemplate)
) {
return (
<MainModelInputField
<ImageFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLRefinerModelField' &&
fieldTemplate?.type === 'SDXLRefinerModelField'
isBoardFieldInputInstance(fieldInstance) &&
isBoardFieldInputTemplate(fieldTemplate)
) {
return (
<RefinerModelInputField
<BoardFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'VaeModelField' &&
fieldTemplate?.type === 'VaeModelField'
isMainModelFieldInputInstance(fieldInstance) &&
isMainModelFieldInputTemplate(fieldTemplate)
) {
return (
<VaeModelInputField
<MainModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LoRAModelField' &&
fieldTemplate?.type === 'LoRAModelField'
isSDXLRefinerModelFieldInputInstance(fieldInstance) &&
isSDXLRefinerModelFieldInputTemplate(fieldTemplate)
) {
return (
<LoRAModelInputField
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlNetModelField' &&
fieldTemplate?.type === 'ControlNetModelField'
isVAEModelFieldInputInstance(fieldInstance) &&
isVAEModelFieldInputTemplate(fieldTemplate)
) {
return (
<ControlNetModelInputField
<VAEModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'IPAdapterModelField' &&
fieldTemplate?.type === 'IPAdapterModelField'
isLoRAModelFieldInputInstance(fieldInstance) &&
isLoRAModelFieldInputTemplate(fieldTemplate)
) {
return (
<IPAdapterModelInputField
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'T2IAdapterModelField' &&
fieldTemplate?.type === 'T2IAdapterModelField'
isControlNetModelFieldInputInstance(fieldInstance) &&
isControlNetModelFieldInputTemplate(fieldTemplate)
) {
return (
<T2IAdapterModelInputField
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
isIPAdapterModelFieldInputInstance(fieldInstance) &&
isIPAdapterModelFieldInputTemplate(fieldTemplate)
) {
return (
<SDXLMainModelInputField
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
if (
isT2IAdapterModelFieldInputInstance(fieldInstance) &&
isT2IAdapterModelFieldInputTemplate(fieldTemplate)
) {
return (
<SchedulerInputField
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={field}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
isColorFieldInputInstance(fieldInstance) &&
isColorFieldInputTemplate(fieldTemplate)
) {
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (field && fieldTemplate) {
if (
isSDXLMainModelFieldInputInstance(fieldInstance) &&
isSDXLMainModelFieldInputTemplate(fieldTemplate)
) {
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (
isSchedulerFieldInputInstance(fieldInstance) &&
isSchedulerFieldInputTemplate(fieldTemplate)
) {
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (fieldInstance && fieldTemplate) {
// Fallback for when there is no component for the type
return null;
}
@ -255,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
_dark: { color: 'error.300' },
}}
>
{t('nodes.unknownFieldType')}: {field?.type}
{t('nodes.unknownFieldType')}: {fieldInstance?.type.name}
</Text>
</Box>
);

View File

@ -3,15 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
import {
BoardInputFieldTemplate,
BoardInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
BoardFieldInputTemplate,
BoardFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
const BoardInputFieldComponent = (
props: FieldComponentProps<BoardInputFieldValue, BoardInputFieldTemplate>
const BoardFieldInputComponent = (
props: FieldComponentProps<BoardFieldInputInstance, BoardFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
@ -61,4 +61,4 @@ const BoardInputFieldComponent = (
);
};
export default memo(BoardInputFieldComponent);
export default memo(BoardFieldInputComponent);

View File

@ -2,18 +2,16 @@ import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
BooleanPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
BooleanFieldInputInstance,
BooleanFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ChangeEvent, memo, useCallback } from 'react';
const BooleanInputFieldComponent = (
const BooleanFieldInputComponent = (
props: FieldComponentProps<
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
BooleanFieldInputInstance,
BooleanFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -42,4 +40,4 @@ const BooleanInputFieldComponent = (
);
};
export default memo(BooleanInputFieldComponent);
export default memo(BooleanFieldInputComponent);

View File

@ -1,15 +1,15 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice';
import {
ColorInputFieldTemplate,
ColorInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
ColorFieldInputTemplate,
ColorFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback } from 'react';
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
const ColorFieldInputComponent = (
props: FieldComponentProps<ColorFieldInputInstance, ColorFieldInputTemplate>
) => {
const { nodeId, field } = props;
@ -37,4 +37,4 @@ const ColorInputFieldComponent = (
);
};
export default memo(ColorInputFieldComponent);
export default memo(ColorFieldInputComponent);

View File

@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
ControlNetModelFieldInputTemplate,
ControlNetModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
const ControlNetModelInputFieldComponent = (
const ControlNetModelFieldInputComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
ControlNetModelFieldInputInstance,
ControlNetModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -97,4 +97,4 @@ const ControlNetModelInputFieldComponent = (
);
};
export default memo(ControlNetModelInputFieldComponent);
export default memo(ControlNetModelFieldInputComponent);

View File

@ -2,14 +2,14 @@ import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
EnumFieldInputInstance,
EnumFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ChangeEvent, memo, useCallback } from 'react';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
const EnumFieldInputComponent = (
props: FieldComponentProps<EnumFieldInputInstance, EnumFieldInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
@ -45,4 +45,4 @@ const EnumInputFieldComponent = (
);
};
export default memo(EnumInputFieldComponent);
export default memo(EnumFieldInputComponent);

View File

@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
IPAdapterModelInputFieldTemplate,
IPAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
IPAdapterModelFieldInputTemplate,
IPAdapterModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const IPAdapterModelInputFieldComponent = (
const IPAdapterModelFieldInputComponent = (
props: FieldComponentProps<
IPAdapterModelInputFieldValue,
IPAdapterModelInputFieldTemplate
IPAdapterModelFieldInputInstance,
IPAdapterModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -97,4 +97,4 @@ const IPAdapterModelInputFieldComponent = (
);
};
export default memo(IPAdapterModelInputFieldComponent);
export default memo(IPAdapterModelFieldInputComponent);

View File

@ -9,23 +9,18 @@ import {
} from 'features/dnd/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
ImageInputFieldTemplate,
ImageInputFieldValue,
ImagePolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldValue,
} from 'features/nodes/types/types';
ImageFieldInputInstance,
ImageFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
const ImageInputFieldComponent = (
props: FieldComponentProps<
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
>
const ImageFieldInputComponent = (
props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
@ -102,7 +97,7 @@ const ImageInputFieldComponent = (
);
};
export default memo(ImageInputFieldComponent);
export default memo(ImageFieldInputComponent);
const UploadElement = memo(() => {
const { t } = useTranslation();

View File

@ -5,10 +5,10 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
LoRAModelInputFieldTemplate,
LoRAModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
LoRAModelFieldInputTemplate,
LoRAModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
import { forEach } from 'lodash-es';
@ -16,10 +16,10 @@ import { memo, useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
const LoRAModelInputFieldComponent = (
const LoRAModelFieldInputComponent = (
props: FieldComponentProps<
LoRAModelInputFieldValue,
LoRAModelInputFieldTemplate
LoRAModelFieldInputInstance,
LoRAModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -121,4 +121,4 @@ const LoRAModelInputFieldComponent = (
);
};
export default memo(LoRAModelInputFieldComponent);
export default memo(LoRAModelFieldInputComponent);

View File

@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
MainModelInputFieldTemplate,
MainModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
MainModelFieldInputTemplate,
MainModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -21,10 +21,10 @@ import {
} from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
const MainModelInputFieldComponent = (
const MainModelFieldInputComponent = (
props: FieldComponentProps<
MainModelInputFieldValue,
MainModelInputFieldTemplate
MainModelFieldInputInstance,
MainModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -149,4 +149,4 @@ const MainModelInputFieldComponent = (
);
};
export default memo(MainModelInputFieldComponent);
export default memo(MainModelFieldInputComponent);

View File

@ -9,28 +9,18 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
FloatInputFieldTemplate,
FloatInputFieldValue,
FloatPolymorphicInputFieldTemplate,
FloatPolymorphicInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
IntegerPolymorphicInputFieldTemplate,
IntegerPolymorphicInputFieldValue,
} from 'features/nodes/types/types';
FloatFieldInputInstance,
FloatFieldInputTemplate,
IntegerFieldInputInstance,
IntegerFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
const NumberInputFieldComponent = (
const NumberFieldInputComponent = (
props: FieldComponentProps<
| IntegerInputFieldValue
| IntegerPolymorphicInputFieldValue
| FloatInputFieldValue
| FloatPolymorphicInputFieldValue,
| IntegerInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
IntegerFieldInputInstance | FloatFieldInputInstance,
IntegerFieldInputTemplate | FloatFieldInputTemplate
>
) => {
const { nodeId, field, fieldTemplate } = props;
@ -39,7 +29,7 @@ const NumberInputFieldComponent = (
String(field.value)
);
const isIntegerField = useMemo(
() => fieldTemplate.type === 'integer',
() => fieldTemplate.type.name === 'IntegerField',
[fieldTemplate.type]
);
@ -86,4 +76,4 @@ const NumberInputFieldComponent = (
);
};
export default memo(NumberInputFieldComponent);
export default memo(NumberFieldInputComponent);

View File

@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
SDXLRefinerModelInputFieldTemplate,
SDXLRefinerModelInputFieldValue,
} from 'features/nodes/types/types';
SDXLRefinerModelFieldInputTemplate,
SDXLRefinerModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -18,10 +18,10 @@ import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const RefinerModelInputFieldComponent = (
const RefinerModelFieldInputComponent = (
props: FieldComponentProps<
SDXLRefinerModelInputFieldValue,
SDXLRefinerModelInputFieldTemplate
SDXLRefinerModelFieldInputInstance,
SDXLRefinerModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -120,4 +120,4 @@ const RefinerModelInputFieldComponent = (
);
};
export default memo(RefinerModelInputFieldComponent);
export default memo(RefinerModelFieldInputComponent);

View File

@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
SDXLMainModelInputFieldTemplate,
SDXLMainModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
SDXLMainModelFieldInputTemplate,
SDXLMainModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -21,10 +21,10 @@ import {
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
const ModelInputFieldComponent = (
const SDXLMainModelFieldInputComponent = (
props: FieldComponentProps<
SDXLMainModelInputFieldValue,
SDXLMainModelInputFieldTemplate
SDXLMainModelFieldInputInstance,
SDXLMainModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -147,4 +147,4 @@ const ModelInputFieldComponent = (
);
};
export default memo(ModelInputFieldComponent);
export default memo(SDXLMainModelFieldInputComponent);

View File

@ -5,14 +5,12 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
import {
SchedulerInputFieldTemplate,
SchedulerInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
SchedulerFieldInputTemplate,
SchedulerFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
@ -24,7 +22,7 @@ const selector = createSelector(
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
group: enabledSchedulers.includes(name as ParameterScheduler)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
@ -36,10 +34,10 @@ const selector = createSelector(
defaultSelectorOptions
);
const SchedulerInputField = (
const SchedulerFieldInputComponent = (
props: FieldComponentProps<
SchedulerInputFieldValue,
SchedulerInputFieldTemplate
SchedulerFieldInputInstance,
SchedulerFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -55,7 +53,7 @@ const SchedulerInputField = (
fieldSchedulerValueChanged({
nodeId,
fieldName: field.name,
value: value as SchedulerParam,
value: value as ParameterScheduler,
})
);
},
@ -72,4 +70,4 @@ const SchedulerInputField = (
);
};
export default memo(SchedulerInputField);
export default memo(SchedulerFieldInputComponent);

View File

@ -3,19 +3,14 @@ import IAIInput from 'common/components/IAIInput';
import IAITextarea from 'common/components/IAITextarea';
import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice';
import {
StringInputFieldTemplate,
StringInputFieldValue,
FieldComponentProps,
StringPolymorphicInputFieldValue,
StringPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
StringFieldInputInstance,
StringFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { ChangeEvent, memo, useCallback } from 'react';
const StringInputFieldComponent = (
props: FieldComponentProps<
StringInputFieldValue | StringPolymorphicInputFieldValue,
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
>
const StringFieldInputComponent = (
props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();
@ -48,4 +43,4 @@ const StringInputFieldComponent = (
return <IAIInput onChange={handleValueChanged} value={field.value} />;
};
export default memo(StringInputFieldComponent);
export default memo(StringFieldInputComponent);

View File

@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
T2IAdapterModelInputFieldTemplate,
T2IAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
T2IAdapterModelFieldInputInstance,
T2IAdapterModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
const T2IAdapterModelInputFieldComponent = (
const T2IAdapterModelFieldInputComponent = (
props: FieldComponentProps<
T2IAdapterModelInputFieldValue,
T2IAdapterModelInputFieldTemplate
T2IAdapterModelFieldInputInstance,
T2IAdapterModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -97,4 +97,4 @@ const T2IAdapterModelInputFieldComponent = (
);
};
export default memo(T2IAdapterModelInputFieldComponent);
export default memo(T2IAdapterModelFieldInputComponent);

View File

@ -4,20 +4,20 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
VAEModelFieldInputTemplate,
VAEModelFieldInputInstance,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
const VaeModelInputFieldComponent = (
const VAEModelFieldInputComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
VAEModelFieldInputInstance,
VAEModelFieldInputTemplate
>
) => {
const { nodeId, field } = props;
@ -105,4 +105,4 @@ const VaeModelInputFieldComponent = (
);
};
export default memo(VaeModelInputFieldComponent);
export default memo(VAEModelFieldInputComponent);

View File

@ -0,0 +1,13 @@
import {
FieldInputInstance,
FieldInputTemplate,
} from 'features/nodes/types/field';
export type FieldComponentProps<
V extends FieldInputInstance,
T extends FieldInputTemplate,
> = {
nodeId: string;
field: V;
fieldTemplate: T;
};

View File

@ -2,7 +2,7 @@ import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice';
import { NotesNodeData } from 'features/nodes/types/types';
import { NotesNodeData } from 'features/nodes/types/invocation';
import { ChangeEvent, memo, useCallback } from 'react';
import { NodeProps } from 'reactflow';
import NodeWrapper from '../common/NodeWrapper';

View File

@ -14,7 +14,7 @@ import {
DRAG_HANDLE_CLASSNAME,
NODE_WIDTH,
} from 'features/nodes/types/constants';
import { NodeStatus } from 'features/nodes/types/types';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import {
MouseEvent,
@ -40,7 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
createSelector(
stateSelector,
({ nodes }) =>
nodes.nodeExecutionStates[nodeId]?.status === NodeStatus.IN_PROGRESS
nodes.nodeExecutionStates[nodeId]?.status ===
zNodeStatus.enum.IN_PROGRESS
),
[nodeId]
);

View File

@ -8,7 +8,7 @@ import { FaUpload } from 'react-icons/fa';
const LoadWorkflowButton = () => {
const { t } = useTranslation();
const resetRef = useRef<() => void>(null);
const loadWorkflowFromFile = useLoadWorkflowFromFile();
const loadWorkflowFromFile = useLoadWorkflowFromFile(resetRef);
return (
<FileButton
resetRef={resetRef}

View File

@ -1,31 +0,0 @@
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
import { FIELDS } from 'features/nodes/types/constants';
import { map } from 'lodash-es';
import { memo } from 'react';
import 'reactflow/dist/style.css';
const FieldTypeLegend = () => {
return (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge
sx={{
userSelect: 'none',
color:
parseInt(color.split('.')[1] ?? '0', 10) < 500
? 'base.800'
: 'base.50',
bg: color,
}}
textAlign="center"
>
{title}
</Badge>
</Tooltip>
))}
</Flex>
);
};
export default memo(FieldTypeLegend);

View File

@ -1,18 +1,11 @@
import { Flex } from '@chakra-ui/layout';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import FieldTypeLegend from './FieldTypeLegend';
import WorkflowEditorSettings from './WorkflowEditorSettings';
const TopRightPanel = () => {
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
return (
<Flex sx={{ gap: 2, position: 'absolute', top: 2, insetInlineEnd: 2 }}>
<WorkflowEditorSettings />
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
</Flex>
);
};

View File

@ -10,17 +10,15 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate';
import {
InvocationNodeData,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/types';
import { memo } from 'react';
} from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSync } from 'react-icons/fa';
import { Node } from 'reactflow';
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
import ScrollableContent from '../ScrollableContent';
@ -63,12 +61,17 @@ const InspectorDetailsTab = () => {
export default memo(InspectorDetailsTab);
const Content = (props: {
type ContentProps = {
node: Node<InvocationNodeData>;
template: InvocationTemplate;
}) => {
};
const Content = memo(({ node, template }: ContentProps) => {
const { t } = useTranslation();
const { needsUpdate, updateNode } = useNodeVersion(props.node.id);
const needsUpdate = useMemo(
() => getNeedsUpdate(node, template),
[node, template]
);
return (
<Box
sx={{
@ -87,12 +90,12 @@ const Content = (props: {
w: 'full',
}}
>
<EditableNodeTitle nodeId={props.node.data.id} />
<EditableNodeTitle nodeId={node.data.id} />
<HStack>
<FormControl>
<FormLabel>{t('nodes.nodeType')}</FormLabel>
<Text fontSize="sm" fontWeight={600}>
{props.template.title}
{template.title}
</Text>
</FormControl>
<Flex
@ -104,22 +107,16 @@ const Content = (props: {
<FormControl isInvalid={needsUpdate}>
<FormLabel>{t('nodes.nodeVersion')}</FormLabel>
<Text fontSize="sm" fontWeight={600}>
{props.node.data.version}
{node.data.version}
</Text>
</FormControl>
{needsUpdate && (
<IAIIconButton
aria-label={t('nodes.updateNode')}
tooltip={t('nodes.updateNode')}
icon={<FaSync />}
onClick={updateNode}
/>
)}
</Flex>
</HStack>
<NotesTextarea nodeId={props.node.data.id} />
<NotesTextarea nodeId={node.data.id} />
</Flex>
</ScrollableContent>
</Box>
);
};
});
Content.displayName = 'Content';

View File

@ -5,7 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { isInvocationNode } from 'features/nodes/types/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { ImageOutput } from 'services/api/types';
import { AnyResult } from 'services/events/types';

View File

@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { isInvocationNode } from '../types/invocation';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const selector = useMemo(
@ -28,8 +25,8 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(['any', 'direct'].includes(field.input) ||
POLYMORPHIC_TYPES.includes(field.type)) &&
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
field.type.isPolymorphic) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
},

View File

@ -3,10 +3,13 @@ import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { Node, useReactFlow } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { buildNodeData } from '../store/util/buildNodeData';
import {
buildCurrentImageNode,
buildInvocationNode,
buildNotesNode,
} from '../store/util/buildNodeData';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants';
import { AnyNodeData, InvocationTemplate } from '../types/invocation';
const templatesSelector = createSelector(
[(state: RootState) => state.nodes],
(nodes) => nodes.nodeTemplates
@ -22,7 +25,8 @@ export const useBuildNodeData = () => {
const flow = useReactFlow();
return useCallback(
(type: AnyInvocationType | 'current_image' | 'notes') => {
// string here is "any invocation type"
(type: string | 'current_image' | 'notes'): Node<AnyNodeData> => {
let _x = window.innerWidth / 2;
let _y = window.innerHeight / 2;
@ -41,9 +45,19 @@ export const useBuildNodeData = () => {
y: _y,
});
const template = nodeTemplates[type];
if (type === 'current_image') {
return buildCurrentImageNode(position);
}
return buildNodeData(type, position, template);
if (type === 'notes') {
return buildNotesNode(position);
}
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = nodeTemplates[type] as InvocationTemplate;
return buildInvocationNode(position, template);
},
[nodeTemplates, flow]
);

View File

@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
export const useConnectionInputFieldNames = (nodeId: string) => {
const selector = useMemo(
@ -29,9 +26,8 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
// get the visible fields
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(field.input === 'connection' &&
!POLYMORPHIC_TYPES.includes(field.type)) ||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
(field.input === 'connection' && !field.type.isPolymorphic) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);

View File

@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
stateSelector,
({ nodes }) =>
nodes.currentConnectionFieldType !== null &&
nodes.connectionStartFieldType !== null &&
nodes.connectionStartParams !== null
);

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { compareVersions } from 'compare-versions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useDoNodeVersionsMatch = (nodeId: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useEmbedWorkflow = (nodeId: string) => {
const selector = useMemo(

View File

@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldData = (nodeId: string, fieldName: string) => {
export const useFieldInstance = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldLabel = (nodeId: string, fieldName: string) => {
const selector = useMemo(

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { KIND_MAP } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldTemplate = (
nodeId: string,

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
import { KIND_MAP } from '../types/constants';
export const useFieldTemplateTitle = (

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { KIND_MAP } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useFieldType = (
nodeId: string,
@ -20,7 +20,8 @@ export const useFieldType = (
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
const field = node.data[KIND_MAP[kind]][fieldName];
return field?.type;
},
defaultSelectorOptions
),

View File

@ -2,7 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { getNeedsUpdate } from './useNodeVersion';
import { getNeedsUpdate } from '../store/util/nodeUpdate';
import { isInvocationNode } from '../types/invocation';
const selector = createSelector(
stateSelector,
@ -10,8 +11,11 @@ const selector = createSelector(
const nodes = state.nodes.nodes;
const templates = state.nodes.nodeTemplates;
const needsUpdate = nodes.some((node) => {
const needsUpdate = nodes.filter(isInvocationNode).some((node) => {
const template = templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node, template);
});
return needsUpdate;

View File

@ -4,8 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { useMemo } from 'react';
import { IMAGE_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useHasImageOutput = (nodeId: string) => {
const selector = useMemo(
@ -20,8 +19,8 @@ export const useHasImageOutput = (nodeId: string) => {
return some(
node.data.outputs,
(output) =>
IMAGE_FIELDS.includes(output.type) &&
// the image primitive node does not actually save the image, do not show the image-saving checkboxes
output.type.name === 'ImageField' &&
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
node.data.type !== 'image'
);
},

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useIsIntermediate = (nodeId: string) => {
const selector = useMemo(

View File

@ -4,7 +4,7 @@ import { useCallback } from 'react';
import { Connection, Node, useReactFlow } from 'reactflow';
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
import { InvocationNodeData } from '../types/types';
import { InvocationNodeData } from '../types/invocation';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
return false;
}
const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
const targetType = targetNode.data.inputs[targetHandle]?.type;
const sourceField = sourceNode.data.outputs[sourceHandle];
const targetField = targetNode.data.inputs[targetHandle];
if (!sourceType || !targetType) {
if (!sourceField || !targetField) {
// something has gone terribly awry
return false;
}
@ -70,12 +70,13 @@ export const useIsValidConnection = () => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType !== 'CollectionItem'
targetField.type.name !== 'CollectionItemField'
) {
return false;
}
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
// Must use the originalType here if it exists
if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
return false;
}

View File

@ -1,17 +1,15 @@
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { zWorkflow } from 'features/nodes/types/types';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { ZodError } from 'zod';
import { fromZodError, fromZodIssue } from 'zod-validation-error';
import { workflowLoadRequested } from '../store/actions';
import { RefObject, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ZodError } from 'zod';
import { fromZodIssue } from 'zod-validation-error';
import { workflowLoadRequested } from '../store/actions';
export const useLoadWorkflowFromFile = () => {
export const useLoadWorkflowFromFile = (resetRef: RefObject<() => void>) => {
const dispatch = useAppDispatch();
const logger = useLogger('nodes');
const { t } = useTranslation();
@ -26,33 +24,10 @@ export const useLoadWorkflowFromFile = () => {
try {
const parsedJSON = JSON.parse(String(rawJSON));
const result = zWorkflow.safeParse(parsedJSON);
if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: t('nodes.workflowValidation'),
});
logger.error({ error: parseify(result.error) }, message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
duration: 5000,
})
)
);
reader.abort();
return;
}
dispatch(workflowLoadRequested(result.data));
reader.abort();
} catch {
// file reader error
dispatch(workflowLoadRequested(parsedJSON));
} catch (e) {
// There was a problem reading the file
logger.error(t('nodes.unableToLoadWorkflow'));
dispatch(
addToast(
makeToast({
@ -61,12 +36,15 @@ export const useLoadWorkflowFromFile = () => {
})
)
);
reader.abort();
}
};
reader.readAsText(file);
// Reset the file picker internal state so that the same file can be loaded again
resetRef.current?.();
},
[dispatch, logger, t]
[dispatch, logger, resetRef, t]
);
return loadWorkflowFromFile;

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useNodeLabel = (nodeId: string) => {
const selector = useMemo(

View File

@ -0,0 +1,35 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/invocation';
import { getNeedsUpdate } from '../store/util/nodeUpdate';
export const useNodeNeedsUpdate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const template = nodes.nodeTemplates[node?.data.type ?? ''];
return { node, template };
},
defaultSelectorOptions
),
[nodeId]
);
const { node, template } = useAppSelector(selector);
const needsUpdate = useMemo(
() =>
isInvocationNode(node) && template
? getNeedsUpdate(node, template)
: false,
[node, template]
);
return needsUpdate;
};

View File

@ -3,16 +3,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { AnyInvocationType } from 'services/events/types';
import { InvocationTemplate } from '../types/invocation';
export const useNodeTemplateByType = (
type: AnyInvocationType | 'current_image' | 'notes'
) => {
export const useNodeTemplateByType = (type: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
({ nodes }): InvocationTemplate | undefined => {
const nodeTemplate = nodes.nodeTemplates[type];
return nodeTemplate;
},

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo(

View File

@ -1,119 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { satisfies } from 'compare-versions';
import { cloneDeep, defaultsDeep } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { Node } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { nodeReplaced } from '../store/nodesSlice';
import { buildNodeData } from '../store/util/buildNodeData';
import {
InvocationNodeData,
InvocationTemplate,
NodeData,
isInvocationNode,
zParsedSemver,
} from '../types/types';
import { useAppToaster } from 'app/components/Toaster';
import { useTranslation } from 'react-i18next';
export const getNeedsUpdate = (
node?: Node<NodeData>,
template?: InvocationTemplate
) => {
if (!isInvocationNode(node) || !template) {
return false;
}
return node.data.version !== template.version;
};
export const getMayUpdateNode = (
node?: Node<NodeData>,
template?: InvocationTemplate
) => {
const needsUpdate = getNeedsUpdate(node, template);
if (
!needsUpdate ||
!isInvocationNode(node) ||
!template ||
!node.data.version
) {
return false;
}
const templateMajor = zParsedSemver.parse(template.version).major;
return satisfies(node.data.version, `^${templateMajor}`);
};
export const updateNode = (
node?: Node<NodeData>,
template?: InvocationTemplate
) => {
const mayUpdate = getMayUpdateNode(node, template);
if (
!mayUpdate ||
!isInvocationNode(node) ||
!template ||
!node.data.version
) {
return;
}
const defaults = buildNodeData(
node.data.type as AnyInvocationType,
node.position,
template
) as Node<InvocationNodeData>;
const clone = cloneDeep(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults);
return clone;
};
export const useNodeVersion = (nodeId: string) => {
const dispatch = useAppDispatch();
const toast = useAppToaster();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return { node, nodeTemplate };
},
defaultSelectorOptions
),
[nodeId]
);
const { node, nodeTemplate } = useAppSelector(selector);
const needsUpdate = useMemo(
() => getNeedsUpdate(node, nodeTemplate),
[node, nodeTemplate]
);
const mayUpdate = useMemo(
() => getMayUpdateNode(node, nodeTemplate),
[node, nodeTemplate]
);
const _updateNode = useCallback(() => {
const needsUpdate = getNeedsUpdate(node, nodeTemplate);
const updatedNode = updateNode(node, nodeTemplate);
if (!updatedNode) {
if (needsUpdate) {
toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) });
}
return;
}
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
}, [dispatch, node, nodeTemplate, t, toast]);
return { needsUpdate, mayUpdate, updateNode: _updateNode };
};

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
export const useOutputFieldNames = (nodeId: string) => {

View File

@ -0,0 +1,23 @@
import { useTranslation } from 'react-i18next';
import { FieldType } from '../types/field';
import { useMemo } from 'react';
export const useFieldTypeName = (fieldType?: FieldType): string => {
const { t } = useTranslation();
const name = useMemo(() => {
if (!fieldType) {
return '';
}
const { name } = fieldType;
if (fieldType.isCollection) {
return t('nodes.collectionFieldType', { name });
}
if (fieldType.isPolymorphic) {
return t('nodes.polymorphicFieldType', { name });
}
return name;
}, [fieldType, t]);
return name;
};

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useUseCache = (nodeId: string) => {
const selector = useMemo(

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import { isInvocationNode } from '../types/invocation';
export const useWithWorkflow = (nodeId: string) => {
const selector = useMemo(

View File

@ -1,6 +1,5 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api/types';
import { Workflow } from '../types/types';
export const textToImageGraphBuilt = createAction<Graph>(
'nodes/textToImageGraphBuilt'
@ -18,7 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
nodesGraphBuilt
);
export const workflowLoadRequested = createAction<Workflow>(
export const workflowLoadRequested = createAction<unknown>(
'nodes/workflowLoadRequested'
);

View File

@ -6,7 +6,7 @@ import { NodesState } from './types';
export const nodesPersistDenylist: (keyof NodesState)[] = [
'nodeTemplates',
'connectionStartParams',
'currentConnectionFieldType',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'isReady',

View File

@ -20,7 +20,6 @@ import {
XYPosition,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { ImageField } from 'services/api/types';
import {
appSocketGeneratorProgress,
appSocketInvocationComplete,
@ -31,60 +30,58 @@ import {
import { v4 as uuidv4 } from 'uuid';
import { DRAG_HANDLE_CLASSNAME } from '../types/constants';
import {
BoardInputFieldValue,
BooleanInputFieldValue,
ColorInputFieldValue,
ControlNetModelInputFieldValue,
CurrentImageNodeData,
EnumInputFieldValue,
BoardFieldValue,
BooleanFieldValue,
ColorFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
FieldIdentifier,
FloatInputFieldValue,
ImageInputFieldValue,
InputFieldValue,
IntegerInputFieldValue,
InvocationNodeData,
FieldValue,
FloatFieldValue,
ImageFieldValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
StringFieldValue,
T2IAdapterModelFieldValue,
VAEModelFieldValue,
} from '../types/field';
import {
AnyNodeData,
InvocationTemplate,
IPAdapterModelInputFieldValue,
isInvocationNode,
isNotesNode,
LoRAModelInputFieldValue,
MainModelInputFieldValue,
NodeExecutionState,
NodeStatus,
NotesNodeData,
SchedulerInputFieldValue,
SDXLRefinerModelInputFieldValue,
StringInputFieldValue,
T2IAdapterModelInputFieldValue,
VaeModelInputFieldValue,
Workflow,
} from '../types/types';
zNodeStatus,
} from '../types/invocation';
import { WorkflowV2 } from '../types/workflow';
import { NodesState } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: NodeStatus.PENDING,
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
export const initialWorkflow = {
meta: {
version: WORKFLOW_FORMAT_VERSION,
},
const INITIAL_WORKFLOW: WorkflowV2 = {
name: '',
author: '',
description: '',
notes: '',
tags: '',
contact: '',
version: '',
contact: '',
tags: '',
notes: '',
nodes: [],
edges: [],
exposedFields: [],
meta: { version: '2.0.0' },
};
export const initialNodesState: NodesState = {
@ -93,11 +90,10 @@ export const initialNodesState: NodesState = {
nodeTemplates: {},
isReady: false,
connectionStartParams: null,
currentConnectionFieldType: null,
connectionStartFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true,
shouldValidateGraph: true,
shouldAnimateEdges: true,
@ -107,7 +103,7 @@ export const initialNodesState: NodesState = {
nodeOpacity: 1,
selectedNodes: [],
selectedEdges: [],
workflow: initialWorkflow,
workflow: INITIAL_WORKFLOW,
nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 },
mouseOverField: null,
@ -117,13 +113,13 @@ export const initialNodesState: NodesState = {
selectionMode: SelectionMode.Partial,
};
type FieldValueAction<T extends InputFieldValue> = PayloadAction<{
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T['value'];
value: T;
}>;
const fieldValueReducer = <T extends InputFieldValue>(
const fieldValueReducer = <T extends FieldValue>(
state: NodesState,
action: FieldValueAction<T>
) => {
@ -161,12 +157,7 @@ const nodesSlice = createSlice({
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (
state,
action: PayloadAction<
Node<InvocationNodeData | CurrentImageNodeData | NotesNodeData>
>
) => {
nodeAdded: (state, action: PayloadAction<Node<AnyNodeData>>) => {
const node = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
@ -203,7 +194,7 @@ const nodesSlice = createSlice({
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
state.connectionStartFieldType
) {
const newConnection = findConnectionToValidHandle(
node,
@ -212,7 +203,7 @@ const nodesSlice = createSlice({
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge(
@ -224,7 +215,7 @@ const nodesSlice = createSlice({
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
@ -258,10 +249,10 @@ const nodesSlice = createSlice({
handleType === 'source'
? node.data.outputs[handleId]
: node.data.inputs[handleId];
state.currentConnectionFieldType = field?.type ?? null;
state.connectionStartFieldType = field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.currentConnectionFieldType;
const fieldType = state.connectionStartFieldType;
if (!fieldType) {
return;
}
@ -286,7 +277,7 @@ const nodesSlice = createSlice({
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
state.connectionStartFieldType
) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
@ -295,7 +286,7 @@ const nodesSlice = createSlice({
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge(
@ -306,14 +297,14 @@ const nodesSlice = createSlice({
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
} else {
state.addNewNodePosition = action.payload.cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
}
state.modifyingEdge = false;
},
@ -529,12 +520,7 @@ const nodesSlice = createSlice({
state.edges = applyEdgeChanges(edgeChanges, state.edges);
}
},
nodesDeleted: (
state,
action: PayloadAction<
Node<InvocationNodeData | NotesNodeData | CurrentImageNodeData>[]
>
) => {
nodesDeleted: (state, action: PayloadAction<Node<AnyNodeData>[]>) => {
action.payload.forEach((node) => {
state.workflow.exposedFields = state.workflow.exposedFields.filter(
(f) => f.nodeId !== node.id
@ -588,132 +574,94 @@ const nodesSlice = createSlice({
},
fieldStringValueChanged: (
state,
action: FieldValueAction<StringInputFieldValue>
action: FieldValueAction<StringFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldNumberValueChanged: (
state,
action: FieldValueAction<IntegerInputFieldValue | FloatInputFieldValue>
action: FieldValueAction<IntegerFieldValue | FloatFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldBooleanValueChanged: (
state,
action: FieldValueAction<BooleanInputFieldValue>
action: FieldValueAction<BooleanFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldBoardValueChanged: (
state,
action: FieldValueAction<BoardInputFieldValue>
action: FieldValueAction<BoardFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldImageValueChanged: (
state,
action: FieldValueAction<ImageInputFieldValue>
action: FieldValueAction<ImageFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldColorValueChanged: (
state,
action: FieldValueAction<ColorInputFieldValue>
action: FieldValueAction<ColorFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldMainModelValueChanged: (
state,
action: FieldValueAction<MainModelInputFieldValue>
action: FieldValueAction<MainModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldRefinerModelValueChanged: (
state,
action: FieldValueAction<SDXLRefinerModelInputFieldValue>
action: FieldValueAction<SDXLRefinerModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldVaeModelValueChanged: (
state,
action: FieldValueAction<VaeModelInputFieldValue>
action: FieldValueAction<VAEModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldLoRAModelValueChanged: (
state,
action: FieldValueAction<LoRAModelInputFieldValue>
action: FieldValueAction<LoRAModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldControlNetModelValueChanged: (
state,
action: FieldValueAction<ControlNetModelInputFieldValue>
action: FieldValueAction<ControlNetModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldIPAdapterModelValueChanged: (
state,
action: FieldValueAction<IPAdapterModelInputFieldValue>
action: FieldValueAction<IPAdapterModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldT2IAdapterModelValueChanged: (
state,
action: FieldValueAction<T2IAdapterModelInputFieldValue>
action: FieldValueAction<T2IAdapterModelFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: (
state,
action: FieldValueAction<EnumInputFieldValue>
action: FieldValueAction<EnumFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldSchedulerValueChanged: (
state,
action: FieldValueAction<SchedulerInputFieldValue>
action: FieldValueAction<SchedulerFieldValue>
) => {
fieldValueReducer(state, action);
},
imageCollectionFieldValueChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
value: ImageField[];
}>
) => {
const { nodeId, fieldName, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
if (nodeIndex === -1) {
return;
}
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
const input = node.data?.inputs[fieldName];
if (!input) {
return;
}
const currentValue = cloneDeep(input.value);
if (!currentValue) {
input.value = value;
return;
}
input.value = uniqBy(
(currentValue as ImageField[]).concat(value),
'image_name'
);
},
notesNodeValueChanged: (
state,
action: PayloadAction<{ nodeId: string; value: string }>
@ -726,12 +674,6 @@ const nodesSlice = createSlice({
}
node.data.notes = value;
},
shouldShowFieldTypeLegendChanged: (
state,
action: PayloadAction<boolean>
) => {
state.shouldShowFieldTypeLegend = action.payload;
},
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
@ -745,7 +687,7 @@ const nodesSlice = createSlice({
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
state.workflow = cloneDeep(initialWorkflow);
state.workflow = cloneDeep(INITIAL_WORKFLOW);
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
@ -783,7 +725,7 @@ const nodesSlice = createSlice({
workflowContactChanged: (state, action: PayloadAction<string>) => {
state.workflow.contact = action.payload;
},
workflowLoaded: (state, action: PayloadAction<Workflow>) => {
workflowLoaded: (state, action: PayloadAction<WorkflowV2>) => {
const { nodes, edges, ...workflow } = action.payload;
state.workflow = workflow;
@ -810,7 +752,7 @@ const nodesSlice = createSlice({
}, {});
},
workflowReset: (state) => {
state.workflow = cloneDeep(initialWorkflow);
state.workflow = cloneDeep(INITIAL_WORKFLOW);
},
viewportChanged: (state, action: PayloadAction<Viewport>) => {
state.viewport = action.payload;
@ -942,7 +884,7 @@ const nodesSlice = createSlice({
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
state.connectionStartFieldType = null;
},
addNodePopoverToggled: (state) => {
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
@ -961,14 +903,14 @@ const nodesSlice = createSlice({
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = NodeStatus.IN_PROGRESS;
node.status = zNodeStatus.enum.IN_PROGRESS;
}
});
builder.addCase(appSocketInvocationComplete, (state, action) => {
const { source_node_id, result } = action.payload.data;
const nes = state.nodeExecutionStates[source_node_id];
if (nes) {
nes.status = NodeStatus.COMPLETED;
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
@ -979,7 +921,7 @@ const nodesSlice = createSlice({
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = NodeStatus.FAILED;
node.status = zNodeStatus.enum.FAILED;
node.error = action.payload.data.error;
node.progress = null;
node.progressImage = null;
@ -990,7 +932,7 @@ const nodesSlice = createSlice({
action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = NodeStatus.IN_PROGRESS;
node.status = zNodeStatus.enum.IN_PROGRESS;
node.progress = (step + 1) / total_steps;
node.progressImage = progress_image ?? null;
}
@ -998,7 +940,7 @@ const nodesSlice = createSlice({
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = NodeStatus.PENDING;
nes.status = zNodeStatus.enum.PENDING;
nes.error = null;
nes.progress = null;
nes.progressImage = null;
@ -1037,7 +979,6 @@ export const {
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
imageCollectionFieldValueChanged,
mouseOverFieldChanged,
mouseOverNodeChanged,
nodeAdded,
@ -1063,7 +1004,6 @@ export const {
selectionPasted,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,

View File

@ -6,25 +6,23 @@ import {
Viewport,
XYPosition,
} from 'reactflow';
import { FieldIdentifier, FieldType } from '../types/field';
import {
FieldIdentifier,
FieldType,
AnyNodeData,
InvocationEdgeExtra,
InvocationTemplate,
NodeData,
NodeExecutionState,
Workflow,
} from '../types/types';
} from '../types/invocation';
import { WorkflowV2 } from '../types/workflow';
export type NodesState = {
nodes: Node<NodeData>[];
nodes: Node<AnyNodeData>[];
edges: Edge<InvocationEdgeExtra>[];
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
currentConnectionFieldType: FieldType | null;
connectionStartFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
@ -33,13 +31,13 @@ export type NodesState = {
shouldColorEdges: boolean;
selectedNodes: string[];
selectedEdges: string[];
workflow: Omit<Workflow, 'nodes' | 'edges'>;
workflow: Omit<WorkflowV2, 'nodes' | 'edges'>;
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
isReady: boolean;
mouseOverField: FieldIdentifier | null;
mouseOverNode: string | null;
nodesToCopy: Node<NodeData>[];
nodesToCopy: Node<AnyNodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;

View File

@ -1,78 +1,73 @@
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
FieldInputInstance,
FieldOutputInstance,
} from 'features/nodes/types/field';
import {
CurrentImageNodeData,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
NotesNodeData,
OutputFieldValue,
} from 'features/nodes/types/types';
import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders';
} from 'features/nodes/types/invocation';
import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance';
import { reduce } from 'lodash-es';
import { Node, XYPosition } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
};
export const buildNodeData = (
type: AnyInvocationType | 'current_image' | 'notes',
position: XYPosition,
template?: InvocationTemplate
):
| Node<CurrentImageNodeData>
| Node<NotesNodeData>
| Node<InvocationNodeData>
| undefined => {
const nodeId = uuidv4();
if (type === 'current_image') {
const node: Node<CurrentImageNodeData> = {
...SHARED_NODE_PROPERTIES,
export const buildNotesNode = (position: XYPosition): Node<NotesNodeData> => {
const nodeId = uuidv4();
const node: Node<NotesNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'notes',
position,
data: {
id: nodeId,
isOpen: true,
label: 'Notes',
notes: '',
type: 'notes',
},
};
return node;
};
export const buildCurrentImageNode = (
position: XYPosition
): Node<CurrentImageNodeData> => {
const nodeId = uuidv4();
const node: Node<CurrentImageNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'current_image',
position,
data: {
id: nodeId,
type: 'current_image',
position,
data: {
id: nodeId,
type: 'current_image',
isOpen: true,
label: 'Current Image',
},
};
isOpen: true,
label: 'Current Image',
},
};
return node;
};
return node;
}
if (type === 'notes') {
const node: Node<NotesNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'notes',
position,
data: {
id: nodeId,
isOpen: true,
label: 'Notes',
notes: '',
type: 'notes',
},
};
return node;
}
if (template === undefined) {
console.error(`Unable to find template ${type}.`);
return;
}
export const buildInvocationNode = (
position: XYPosition,
template: InvocationTemplate
): Node<InvocationNodeData> => {
const nodeId = uuidv4();
const { type } = template;
const inputs = reduce(
template.inputs,
(inputsAccumulator, inputTemplate, inputName) => {
const fieldId = uuidv4();
const inputFieldValue: InputFieldValue = buildInputFieldValue(
const inputFieldValue: FieldInputInstance = buildFieldInputInstance(
fieldId,
inputTemplate
);
@ -81,7 +76,7 @@ export const buildNodeData = (
return inputsAccumulator;
},
{} as Record<string, InputFieldValue>
{} as Record<string, FieldInputInstance>
);
const outputs = reduce(
@ -89,7 +84,7 @@ export const buildNodeData = (
(outputsAccumulator, outputTemplate, outputName) => {
const fieldId = uuidv4();
const outputFieldValue: OutputFieldValue = {
const outputFieldValue: FieldOutputInstance = {
id: fieldId,
name: outputName,
type: outputTemplate.type,
@ -100,10 +95,10 @@ export const buildNodeData = (
return outputsAccumulator;
},
{} as Record<string, OutputFieldValue>
{} as Record<string, FieldOutputInstance>
);
const invocation: Node<InvocationNodeData> = {
const node: Node<InvocationNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'invocation',
@ -117,11 +112,11 @@ export const buildNodeData = (
isOpen: true,
embedWorkflow: false,
isIntermediate: type === 'save_image' ? false : true,
useCache: template.useCache,
inputs,
outputs,
useCache: template.useCache,
},
};
return invocation;
return node;
};

View File

@ -1,20 +1,19 @@
import { Connection, HandleType } from 'reactflow';
import { Node, Edge } from 'reactflow';
import {
FieldType,
InputFieldValue,
OutputFieldValue,
} from 'features/nodes/types/types';
import { Connection, Edge, HandleType, Node } from 'reactflow';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
import {
FieldInputInstance,
FieldOutputInstance,
FieldType,
} from 'features/nodes/types/field';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
node: Node,
handle: InputFieldValue | OutputFieldValue
handle: FieldInputInstance | FieldOutputInstance
) => {
let isValidConnection = true;
if (handleCurrentType === 'source') {

View File

@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { FieldType } from 'features/nodes/types/types';
import { FieldType } from 'features/nodes/types/field';
import i18n from 'i18next';
import { HandleType } from 'reactflow';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
/**
@ -17,15 +17,15 @@ export const makeConnectionErrorSelector = (
handleType: HandleType,
fieldType?: FieldType
) => {
return createSelector(stateSelector, (state) => {
return createSelector(stateSelector, (state): string | undefined => {
if (!fieldType) {
return i18n.t('nodes.noFieldType');
}
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
const { connectionStartFieldType, connectionStartParams, nodes, edges } =
state.nodes;
if (!connectionStartParams || !currentConnectionFieldType) {
if (!connectionStartParams || !connectionStartFieldType) {
return i18n.t('nodes.noConnectionInProgress');
}
@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = (
}
const targetType =
handleType === 'target' ? fieldType : currentConnectionFieldType;
handleType === 'target' ? fieldType : connectionStartFieldType;
const sourceType =
handleType === 'source' ? fieldType : currentConnectionFieldType;
handleType === 'source' ? fieldType : connectionStartFieldType;
if (nodeId === connectionNodeId) {
return i18n.t('nodes.cannotConnectToSelf');
@ -80,7 +80,7 @@ export const makeConnectionErrorSelector = (
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType !== 'CollectionItem'
targetType.name !== 'CollectionItemField'
) {
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
}
@ -100,6 +100,6 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.connectionWouldCreateCycle');
}
return null;
return;
});
};

View File

@ -0,0 +1,68 @@
import { satisfies } from 'compare-versions';
import { NodeUpdateError } from 'features/nodes/types/error';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/invocation';
import { zParsedSemver } from 'features/nodes/types/semver';
import { cloneDeep, defaultsDeep } from 'lodash-es';
import { Node } from 'reactflow';
import { buildInvocationNode } from './buildNodeData';
export const getNeedsUpdate = (
node: Node<InvocationNodeData>,
template: InvocationTemplate
): boolean => {
if (node.data.type !== template.type) {
return true;
}
return node.data.version !== template.version;
}; /**
* Checks if a node may be updated by comparing its major version with the template's major version.
* @param node The node to check.
* @param template The invocation template to check against.
*/
export const getMayUpdateNode = (
node: Node<InvocationNodeData>,
template: InvocationTemplate
): boolean => {
const needsUpdate = getNeedsUpdate(node, template);
if (!needsUpdate || node.data.type !== template.type) {
return false;
}
const templateMajor = zParsedSemver.parse(template.version).major;
return satisfies(node.data.version, `^${templateMajor}`);
}; /**
* Updates a node to the latest version of its template:
* - Create a new node data object with the latest version of the template.
* - Recursively merge new node data object into the node to be updated.
*
* @param node The node to updated.
* @param template The invocation template to update to.
* @throws {NodeUpdateError} If the node is not an invocation node.
*/
export const updateNode = (
node: Node<InvocationNodeData>,
template: InvocationTemplate
): Node<InvocationNodeData> => {
const mayUpdate = getMayUpdateNode(node, template);
if (!mayUpdate || node.data.type !== template.type) {
throw new NodeUpdateError(`Unable to update node ${node.id}`);
}
// Start with a "fresh" node - just as if the user created a new node of this type
const defaults = buildInvocationNode(node.position, template);
// The updateability of a node, via semver comparison, relies on the this kind of recursive merge
// being valid. We rely on the template's major version to be majorly incremented if this kind of
// merge would result in an invalid node.
const clone = cloneDeep(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults); // mutates!
return clone;
};

View File

@ -1,11 +1,12 @@
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
import { FieldType } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the connection is valid, false otherwise.
*/
export const validateSourceAndTargetTypes = (
sourceType: FieldType,
targetType: FieldType
@ -13,11 +14,14 @@ export const validateSourceAndTargetTypes = (
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType === 'Collection' && targetType === 'Collection') {
if (
sourceType.name === 'CollectionField' &&
targetType.name === 'CollectionField'
) {
return false;
}
if (sourceType === targetType) {
if (isEqual(sourceType, targetType)) {
return true;
}
@ -31,46 +35,42 @@ export const validateSourceAndTargetTypes = (
*/
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);
sourceType.name === 'CollectionItemField' && !targetType.isCollection;
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
targetType.name === 'CollectionItemField' &&
!sourceType.isCollection &&
!sourceType.isPolymorphic;
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
targetType.isPolymorphic && sourceType.name === targetType.name;
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
sourceType.name === 'CollectionField' &&
(targetType.isCollection || targetType.isPolymorphic);
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
targetType.name === 'CollectionField' && sourceType.isCollection;
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isPolymorphic &&
!targetType.isCollection &&
!targetType.isPolymorphic;
const isIntToFloat =
areBothTypesSingle &&
sourceType.name === 'IntegerField' &&
targetType.name === 'FloatField';
const isIntOrFloatToString =
(sourceType === 'integer' || sourceType === 'float') &&
targetType === 'string';
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';
const isTargetAnyType = targetType === 'Any';
const isTargetAnyType = targetType.name === 'AnyField';
// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||

View File

@ -0,0 +1,216 @@
import { z } from 'zod';
// #region Field data schemas
export const zImageField = z.object({
image_name: z.string().trim().min(1),
});
export type ImageField = z.infer<typeof zImageField>;
export const zBoardField = z.object({
board_id: z.string().trim().min(1),
});
export type BoardField = z.infer<typeof zBoardField>;
export const zColorField = z.object({
r: z.number().int().min(0).max(255),
g: z.number().int().min(0).max(255),
b: z.number().int().min(0).max(255),
a: z.number().int().min(0).max(255),
});
export type ColorField = z.infer<typeof zColorField>;
export const zSchedulerField = z.enum([
'euler',
'deis',
'ddim',
'ddpm',
'dpmpp_2s',
'dpmpp_2m',
'dpmpp_2m_sde',
'dpmpp_sde',
'heun',
'kdpm_2',
'lms',
'pndm',
'unipc',
'euler_k',
'dpmpp_2s_k',
'dpmpp_2m_k',
'dpmpp_2m_sde_k',
'dpmpp_sde_k',
'heun_k',
'lms_k',
'euler_a',
'kdpm_2_a',
'lcm',
]);
export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
export const zBaseModel = z.enum([
'any',
'sd-1',
'sd-2',
'sdxl',
'sdxl-refiner',
]);
export const zModelType = z.enum([
'onnx',
'main',
'vae',
'lora',
'controlnet',
'embedding',
]);
export const zModelName = z.string().trim().min(1);
export const zModelIdentifier = z.object({
model_name: zModelName,
base_model: zBaseModel,
});
export type BaseModel = z.infer<typeof zBaseModel>;
export type ModelType = z.infer<typeof zModelType>;
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
export const zMainModelField = z.object({
model_name: zModelName,
base_model: zBaseModel,
model_type: z.literal('main'),
});
export const zONNXModelField = z.object({
model_name: zModelName,
base_model: zBaseModel,
model_type: z.literal('onnx'),
});
export const zMainOrONNXModelField = z.union([
zMainModelField,
zONNXModelField,
]);
export const zSDXLRefinerModelField = z.object({
model_name: z.string().min(1),
base_model: z.literal('sdxl-refiner'),
model_type: z.literal('main'),
});
export type MainModelField = z.infer<typeof zMainModelField>;
export type ONNXModelField = z.infer<typeof zONNXModelField>;
export type MainOrONNXModelField = z.infer<typeof zMainOrONNXModelField>;
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
export const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'tokenizer',
'tokenizer_2',
'vae',
'vae_decoder',
'vae_encoder',
'scheduler',
'safety_checker',
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zVAEModelField = zModelIdentifier;
export const zModelInfo = zModelIdentifier.extend({
model_type: zModelType,
submodel: zSubModelType.optional(),
});
export type ModelInfo = z.infer<typeof zModelInfo>;
export const zLoRAModelField = zModelIdentifier;
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
export const zControlNetModelField = zModelIdentifier;
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
export const zIPAdapterModelField = zModelIdentifier;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
export const zT2IAdapterModelField = zModelIdentifier;
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
export const zLoraInfo = zModelInfo.extend({
weight: z.number().optional(),
});
export type LoraInfo = z.infer<typeof zLoraInfo>;
export const zUNetField = z.object({
unet: zModelInfo,
scheduler: zModelInfo,
loras: z.array(zLoraInfo),
});
export type UNetField = z.infer<typeof zUNetField>;
export const zCLIPField = z.object({
tokenizer: zModelInfo,
text_encoder: zModelInfo,
skipped_layers: z.number(),
loras: z.array(zLoraInfo),
});
export type CLIPField = z.infer<typeof zCLIPField>;
export const zVAEField = z.object({
vae: zModelInfo,
});
export type VAEField = z.infer<typeof zVAEField>;
// #endregion
// #region Control Adapters
export const zControlField = z.object({
image: zImageField,
control_model: zControlNetModelField,
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
control_mode: z
.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced'])
.optional(),
resize_mode: z
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
.optional(),
});
export type ControlField = z.infer<typeof zControlField>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zIPAdapterModelField,
weight: z.number(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zT2IAdapterField = z.object({
image: zImageField,
t2i_adapter_model: zT2IAdapterModelField,
weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
resize_mode: z
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
.optional(),
});
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
// #endregion
// #region ProgressImage
export const zProgressImage = z.object({
dataURL: z.string(),
width: z.number().int(),
height: z.number().int(),
});
export type ProgressImage = z.infer<typeof zProgressImage>;
// #endregion
// #region ImageOutput
export const zImageOutput = z.object({
image: zImageField,
width: z.number().int(),
height: z.number().int(),
type: z.literal('image_output'),
});
export type ImageOutput = z.infer<typeof zImageOutput>;
export const isImageOutput = (output: unknown): output is ImageOutput =>
zImageOutput.safeParse(output).success;
// #endregion

View File

@ -1,58 +1,31 @@
import {
FieldType,
FieldTypeMap,
FieldTypeMapWithNumber,
FieldUIConfig,
} from './types';
import { t } from 'i18next';
/**
* How long to wait before showing a tooltip when hovering a field handle.
*/
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const COLOR_TOKEN_VALUE = 500;
/**
* The width of a node in the UI in pixels.
*/
export const NODE_WIDTH = 320;
export const NODE_MIN_WIDTH = 320;
/**
* This class name is special - reactflow uses it to identify the drag handle of a node,
* applying the appropriate listeners to it.
*/
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
/**
* Helper for getting the kind of a field.
*/
export const KIND_MAP = {
input: 'inputs' as const,
output: 'outputs' as const,
};
export const COLLECTION_TYPES: FieldType[] = [
'Collection',
'IntegerCollection',
'BooleanCollection',
'FloatCollection',
'StringCollection',
'ImageCollection',
'LatentsCollection',
'ConditioningCollection',
'ControlCollection',
'ColorCollection',
'T2IAdapterCollection',
'IPAdapterCollection',
'MetadataItemCollection',
'MetadataCollection',
];
export const POLYMORPHIC_TYPES: FieldType[] = [
'IntegerPolymorphic',
'BooleanPolymorphic',
'FloatPolymorphic',
'StringPolymorphic',
'ImagePolymorphic',
'LatentsPolymorphic',
'ConditioningPolymorphic',
'ControlPolymorphic',
'ColorPolymorphic',
'T2IAdapterPolymorphic',
'IPAdapterPolymorphic',
'MetadataItemPolymorphic',
];
export const MODEL_TYPES: FieldType[] = [
/**
* Model types' handles are rendered as squares in the UI.
*/
export const MODEL_TYPES = [
'IPAdapterModelField',
'ControlNetModelField',
'LoRAModelField',
@ -68,373 +41,33 @@ export const MODEL_TYPES: FieldType[] = [
'IPAdapterModelField',
];
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
integer: 'IntegerCollection',
boolean: 'BooleanCollection',
number: 'FloatCollection',
float: 'FloatCollection',
string: 'StringCollection',
ImageField: 'ImageCollection',
LatentsField: 'LatentsCollection',
ConditioningField: 'ConditioningCollection',
ControlField: 'ControlCollection',
ColorField: 'ColorCollection',
T2IAdapterField: 'T2IAdapterCollection',
IPAdapterField: 'IPAdapterCollection',
MetadataItemField: 'MetadataItemCollection',
MetadataField: 'MetadataCollection',
};
export const isCollectionItemType = (
itemType: string | undefined
): itemType is keyof typeof COLLECTION_MAP =>
Boolean(itemType && itemType in COLLECTION_MAP);
export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
integer: 'IntegerPolymorphic',
boolean: 'BooleanPolymorphic',
number: 'FloatPolymorphic',
float: 'FloatPolymorphic',
string: 'StringPolymorphic',
ImageField: 'ImagePolymorphic',
LatentsField: 'LatentsPolymorphic',
ConditioningField: 'ConditioningPolymorphic',
ControlField: 'ControlPolymorphic',
ColorField: 'ColorPolymorphic',
T2IAdapterField: 'T2IAdapterPolymorphic',
IPAdapterField: 'IPAdapterPolymorphic',
MetadataItemField: 'MetadataItemPolymorphic',
};
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
IntegerPolymorphic: 'integer',
BooleanPolymorphic: 'boolean',
FloatPolymorphic: 'float',
StringPolymorphic: 'string',
ImagePolymorphic: 'ImageField',
LatentsPolymorphic: 'LatentsField',
ConditioningPolymorphic: 'ConditioningField',
ControlPolymorphic: 'ControlField',
ColorPolymorphic: 'ColorField',
T2IAdapterPolymorphic: 'T2IAdapterField',
IPAdapterPolymorphic: 'IPAdapterField',
MetadataItemPolymorphic: 'MetadataItemField',
};
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
'string',
'StringPolymorphic',
'boolean',
'BooleanPolymorphic',
'integer',
'float',
'FloatPolymorphic',
'IntegerPolymorphic',
'enum',
'ImageField',
'ImagePolymorphic',
'MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'LoRAModelField',
'ControlNetModelField',
'ColorField',
'SDXLMainModelField',
'Scheduler',
'IPAdapterModelField',
'BoardField',
'T2IAdapterModelField',
];
export const isPolymorphicItemType = (
itemType: string | undefined
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
export const FIELDS: Record<FieldType, FieldUIConfig> = {
Any: {
color: 'gray.500',
description: 'Any field type is accepted.',
title: 'Any',
},
MetadataField: {
color: 'gray.500',
description: 'A metadata dict.',
title: 'Metadata Dict',
},
MetadataCollection: {
color: 'gray.500',
description: 'A collection of metadata dicts.',
title: 'Metadata Dict Collection',
},
MetadataItemField: {
color: 'gray.500',
description: 'A metadata item.',
title: 'Metadata Item',
},
MetadataItemCollection: {
color: 'gray.500',
description: 'Any field type is accepted.',
title: 'Metadata Item Collection',
},
MetadataItemPolymorphic: {
color: 'gray.500',
description:
'MetadataItem or MetadataItemCollection field types are accepted.',
title: 'Metadata Item Polymorphic',
},
boolean: {
color: 'green.500',
description: t('nodes.booleanDescription'),
title: t('nodes.boolean'),
},
BooleanCollection: {
color: 'green.500',
description: t('nodes.booleanCollectionDescription'),
title: t('nodes.booleanCollection'),
},
BooleanPolymorphic: {
color: 'green.500',
description: t('nodes.booleanPolymorphicDescription'),
title: t('nodes.booleanPolymorphic'),
},
ClipField: {
color: 'green.500',
description: t('nodes.clipFieldDescription'),
title: t('nodes.clipField'),
},
Collection: {
color: 'base.500',
description: t('nodes.collectionDescription'),
title: t('nodes.collection'),
},
CollectionItem: {
color: 'base.500',
description: t('nodes.collectionItemDescription'),
title: t('nodes.collectionItem'),
},
ColorCollection: {
color: 'pink.300',
description: t('nodes.colorCollectionDescription'),
title: t('nodes.colorCollection'),
},
ColorField: {
color: 'pink.300',
description: t('nodes.colorFieldDescription'),
title: t('nodes.colorField'),
},
ColorPolymorphic: {
color: 'pink.300',
description: t('nodes.colorPolymorphicDescription'),
title: t('nodes.colorPolymorphic'),
},
ConditioningCollection: {
color: 'cyan.500',
description: t('nodes.conditioningCollectionDescription'),
title: t('nodes.conditioningCollection'),
},
ConditioningField: {
color: 'cyan.500',
description: t('nodes.conditioningFieldDescription'),
title: t('nodes.conditioningField'),
},
ConditioningPolymorphic: {
color: 'cyan.500',
description: t('nodes.conditioningPolymorphicDescription'),
title: t('nodes.conditioningPolymorphic'),
},
ControlCollection: {
color: 'teal.500',
description: t('nodes.controlCollectionDescription'),
title: t('nodes.controlCollection'),
},
ControlField: {
color: 'teal.500',
description: t('nodes.controlFieldDescription'),
title: t('nodes.controlField'),
},
ControlNetModelField: {
color: 'teal.500',
description: 'TODO',
title: 'ControlNet',
},
ControlPolymorphic: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Polymorphic',
},
DenoiseMaskField: {
color: 'blue.300',
description: t('nodes.denoiseMaskFieldDescription'),
title: t('nodes.denoiseMaskField'),
},
enum: {
color: 'blue.500',
description: t('nodes.enumDescription'),
title: t('nodes.enum'),
},
float: {
color: 'orange.500',
description: t('nodes.floatDescription'),
title: t('nodes.float'),
},
FloatCollection: {
color: 'orange.500',
description: t('nodes.floatCollectionDescription'),
title: t('nodes.floatCollection'),
},
FloatPolymorphic: {
color: 'orange.500',
description: t('nodes.floatPolymorphicDescription'),
title: t('nodes.floatPolymorphic'),
},
ImageCollection: {
color: 'purple.500',
description: t('nodes.imageCollectionDescription'),
title: t('nodes.imageCollection'),
},
ImageField: {
color: 'purple.500',
description: t('nodes.imageFieldDescription'),
title: t('nodes.imageField'),
},
BoardField: {
color: 'purple.500',
description: t('nodes.imageFieldDescription'),
title: t('nodes.imageField'),
},
ImagePolymorphic: {
color: 'purple.500',
description: t('nodes.imagePolymorphicDescription'),
title: t('nodes.imagePolymorphic'),
},
integer: {
color: 'red.500',
description: t('nodes.integerDescription'),
title: t('nodes.integer'),
},
IntegerCollection: {
color: 'red.500',
description: t('nodes.integerCollectionDescription'),
title: t('nodes.integerCollection'),
},
IntegerPolymorphic: {
color: 'red.500',
description: t('nodes.integerPolymorphicDescription'),
title: t('nodes.integerPolymorphic'),
},
IPAdapterCollection: {
color: 'teal.500',
description: t('nodes.ipAdapterCollectionDescription'),
title: t('nodes.ipAdapterCollection'),
},
IPAdapterField: {
color: 'teal.500',
description: t('nodes.ipAdapterDescription'),
title: t('nodes.ipAdapter'),
},
IPAdapterModelField: {
color: 'teal.500',
description: t('nodes.ipAdapterModelDescription'),
title: t('nodes.ipAdapterModel'),
},
IPAdapterPolymorphic: {
color: 'teal.500',
description: t('nodes.ipAdapterPolymorphicDescription'),
title: t('nodes.ipAdapterPolymorphic'),
},
LatentsCollection: {
color: 'pink.500',
description: t('nodes.latentsCollectionDescription'),
title: t('nodes.latentsCollection'),
},
LatentsField: {
color: 'pink.500',
description: t('nodes.latentsFieldDescription'),
title: t('nodes.latentsField'),
},
LatentsPolymorphic: {
color: 'pink.500',
description: t('nodes.latentsPolymorphicDescription'),
title: t('nodes.latentsPolymorphic'),
},
LoRAModelField: {
color: 'teal.500',
description: t('nodes.loRAModelFieldDescription'),
title: t('nodes.loRAModelField'),
},
MainModelField: {
color: 'teal.500',
description: t('nodes.mainModelFieldDescription'),
title: t('nodes.mainModelField'),
},
ONNXModelField: {
color: 'teal.500',
description: t('nodes.oNNXModelFieldDescription'),
title: t('nodes.oNNXModelField'),
},
Scheduler: {
color: 'base.500',
description: t('nodes.schedulerDescription'),
title: t('nodes.scheduler'),
},
SDXLMainModelField: {
color: 'teal.500',
description: t('nodes.sDXLMainModelFieldDescription'),
title: t('nodes.sDXLMainModelField'),
},
SDXLRefinerModelField: {
color: 'teal.500',
description: t('nodes.sDXLRefinerModelFieldDescription'),
title: t('nodes.sDXLRefinerModelField'),
},
string: {
color: 'yellow.500',
description: t('nodes.stringDescription'),
title: t('nodes.string'),
},
StringCollection: {
color: 'yellow.500',
description: t('nodes.stringCollectionDescription'),
title: t('nodes.stringCollection'),
},
StringPolymorphic: {
color: 'yellow.500',
description: t('nodes.stringPolymorphicDescription'),
title: t('nodes.stringPolymorphic'),
},
T2IAdapterCollection: {
color: 'teal.500',
description: t('nodes.t2iAdapterCollectionDescription'),
title: t('nodes.t2iAdapterCollection'),
},
T2IAdapterField: {
color: 'teal.500',
description: t('nodes.t2iAdapterFieldDescription'),
title: t('nodes.t2iAdapterField'),
},
T2IAdapterModelField: {
color: 'teal.500',
description: 'TODO',
title: 'T2I-Adapter',
},
T2IAdapterPolymorphic: {
color: 'teal.500',
description: 'T2I-Adapter info passed between nodes.',
title: 'T2I-Adapter Polymorphic',
},
UNetField: {
color: 'red.500',
description: t('nodes.uNetFieldDescription'),
title: t('nodes.uNetField'),
},
VaeField: {
color: 'blue.500',
description: t('nodes.vaeFieldDescription'),
title: t('nodes.vaeField'),
},
VaeModelField: {
color: 'teal.500',
description: t('nodes.vaeModelFieldDescription'),
title: t('nodes.vaeModelField'),
},
/**
* Colors for each field type - applies to their handles and edges.
*/
export const FIELD_COLORS: { [key: string]: string } = {
BoardField: 'purple.500',
BooleanField: 'green.500',
ClipField: 'green.500',
ColorField: 'pink.300',
ConditioningField: 'cyan.500',
ControlField: 'teal.500',
ControlNetModelField: 'teal.500',
EnumField: 'blue.500',
FloatField: 'orange.500',
ImageField: 'purple.500',
IntegerField: 'red.500',
IPAdapterField: 'teal.500',
IPAdapterModelField: 'teal.500',
LatentsField: 'pink.500',
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
ONNXModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
StringField: 'yellow.500',
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
VaeField: 'blue.500',
VaeModelField: 'teal.500',
};

View File

@ -0,0 +1,59 @@
/**
* Invalid Workflow Version Error
* Raised when a workflow version is not recognized.
*/
export class WorkflowVersionError extends Error {
/**
* Create WorkflowVersionError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Unable to Update Node Error
* Raised when a node cannot be updated.
*/
export class NodeUpdateError extends Error {
/**
* Create NodeUpdateError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* FieldTypeParseError
* Raised when a field cannot be parsed from a field schema.
*/
export class FieldTypeParseError extends Error {
/**
* Create FieldTypeParseError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* UnsupportedFieldTypeError
* Raised when an unsupported field type is parsed.
*/
export class UnsupportedFieldTypeError extends Error {
/**
* Create UnsupportedFieldTypeError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,108 @@
import { Node } from 'reactflow';
import { z } from 'zod';
import { zProgressImage } from './common';
import {
zFieldInputInstance,
zFieldInputTemplate,
zFieldOutputInstance,
zFieldOutputTemplate,
} from './field';
import { zSemVer } from './semver';
// #region InvocationTemplate
export const zInvocationTemplate = z.object({
type: z.string(),
title: z.string(),
description: z.string(),
tags: z.array(z.string().min(1)),
inputs: z.record(zFieldInputTemplate),
outputs: z.record(zFieldOutputTemplate),
outputType: z.string().min(1),
withWorkflow: z.boolean(),
version: zSemVer,
useCache: z.boolean(),
});
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
// #endregion
// #region NodeData
export const zInvocationNodeData = z.object({
id: z.string().trim().min(1),
type: z.string().trim().min(1),
label: z.string(),
isOpen: z.boolean(),
notes: z.string(),
embedWorkflow: z.boolean(),
isIntermediate: z.boolean(),
useCache: z.boolean(),
version: zSemVer,
inputs: z.record(zFieldInputInstance),
outputs: z.record(zFieldOutputInstance),
});
export const zNotesNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('notes'),
label: z.string(),
isOpen: z.boolean(),
notes: z.string(),
});
export const zCurrentImageNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('current_image'),
label: z.string(),
isOpen: z.boolean(),
});
export const zAnyNodeData = z.union([
zInvocationNodeData,
zNotesNodeData,
zCurrentImageNodeData,
]);
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
export type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
export type AnyNodeData = z.infer<typeof zAnyNodeData>;
export const isInvocationNode = (
node?: Node<AnyNodeData>
): node is Node<InvocationNodeData> =>
Boolean(node && node.type === 'invocation');
export const isNotesNode = (
node?: Node<AnyNodeData>
): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
export const isProgressImageNode = (
node?: Node<AnyNodeData>
): node is Node<CurrentImageNodeData> =>
Boolean(node && node.type === 'current_image');
export const isInvocationNodeData = (
node?: AnyNodeData
): node is InvocationNodeData =>
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
// #endregion
// #region NodeExecutionState
export const zNodeStatus = z.enum([
'PENDING',
'IN_PROGRESS',
'COMPLETED',
'FAILED',
]);
export const zNodeExecutionState = z.object({
nodeId: z.string().trim().min(1),
status: zNodeStatus,
progress: z.number().nullable(),
progressImage: zProgressImage.nullable(),
error: z.string().nullable(),
outputs: z.array(z.any()),
});
export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
export type NodeStatus = z.infer<typeof zNodeStatus>;
// #endregion
// #region Edges
export const zInvocationEdgeExtra = z.object({
type: z.union([z.literal('default'), z.literal('collapsed')]),
});
export type InvocationEdgeExtra = z.infer<typeof zInvocationEdgeExtra>;
// #endregion

View File

@ -0,0 +1,81 @@
import { z } from 'zod';
import {
zControlField,
zIPAdapterField,
zLoRAModelField,
zMainModelField,
zONNXModelField,
zSDXLRefinerModelField,
zT2IAdapterField,
zVAEModelField,
} from './common';
// #region Metadata-optimized versions of schemas
// TODO: It's possible that `deepPartial` will be deprecated:
// - https://github.com/colinhacks/zod/issues/2106
// - https://github.com/colinhacks/zod/issues/2854
export const zLoRAMetadataItem = z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
});
const zControlNetMetadataItem = zControlField.deepPartial();
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
const zModelMetadataitem = z.union([
zMainModelField.deepPartial(),
zONNXModelField.deepPartial(),
]);
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
export type SDXLRefinerModelMetadataItem = z.infer<
typeof zSDXLRefinerModelMetadataItem
>;
export type ModelMetadataitem = z.infer<typeof zModelMetadataitem>;
export type VAEModelMetadataItem = z.infer<typeof zVAEModelMetadataItem>;
// #endregion
// #region CoreMetadata
export const zCoreMetadata = z
.object({
app_version: z.string().nullish().catch(null),
generation_mode: z.string().nullish().catch(null),
created_by: z.string().nullish().catch(null),
positive_prompt: z.string().nullish().catch(null),
negative_prompt: z.string().nullish().catch(null),
width: z.number().int().nullish().catch(null),
height: z.number().int().nullish().catch(null),
seed: z.number().int().nullish().catch(null),
rand_device: z.string().nullish().catch(null),
cfg_scale: z.number().nullish().catch(null),
steps: z.number().int().nullish().catch(null),
scheduler: z.string().nullish().catch(null),
clip_skip: z.number().int().nullish().catch(null),
model: zModelMetadataitem.nullish().catch(null),
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
vae: zVAEModelMetadataItem.nullish().catch(null),
strength: z.number().nullish().catch(null),
hrf_enabled: z.boolean().nullish().catch(null),
hrf_strength: z.number().nullish().catch(null),
hrf_method: z.string().nullish().catch(null),
init_image: z.string().nullish().catch(null),
positive_style_prompt: z.string().nullish().catch(null),
negative_style_prompt: z.string().nullish().catch(null),
refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null),
refiner_cfg_scale: z.number().nullish().catch(null),
refiner_steps: z.number().int().nullish().catch(null),
refiner_scheduler: z.string().nullish().catch(null),
refiner_positive_aesthetic_score: z.number().nullish().catch(null),
refiner_negative_aesthetic_score: z.number().nullish().catch(null),
refiner_start: z.number().nullish().catch(null),
})
.passthrough();
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
// #endregion

View File

@ -0,0 +1,69 @@
import { forEach, isString } from 'lodash-es';
import { z } from 'zod';
import { WorkflowVersionError } from '../error';
import { zSemVer } from '../semver';
import { WorkflowV2, zWorkflowV2 } from '../workflow';
import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from './v1/fieldTypeMap';
import { WorkflowV1, zWorkflowV1 } from './v1/workflowV1';
import { t } from 'i18next';
/**
* Helper schema to extract the version from a workflow.
*
* All properties except for the version are ignored in this schema.
*/
const zWorkflowMetaVersion = z.object({
meta: z.object({ version: zSemVer }),
});
/**
* Migrates a workflow from V1 to V2.
*/
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
workflowToMigrate.nodes.forEach((node) => {
if (node.type === 'invocation') {
forEach(node.data.inputs, (input) => {
if (!isString(input.type)) {
return;
}
(input.type as unknown) =
FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type];
});
forEach(node.data.outputs, (output) => {
if (!isString(output.type)) {
return;
}
(output.type as unknown) =
FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type];
});
}
});
(workflowToMigrate.meta.version as WorkflowV2['meta']['version']) = '2.0.0';
return zWorkflowV2.parse(workflowToMigrate);
};
/**
* Parses a workflow and migrates it to the latest version if necessary.
*/
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
if (!workflowVersionResult.success) {
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
}
const { version } = workflowVersionResult.data.meta;
if (version === '1.0.0') {
const v1 = zWorkflowV1.parse(data);
return migrateV1toV2(v1);
}
if (version === '2.0.0') {
return zWorkflowV2.parse(data);
}
throw new WorkflowVersionError(
t('nodes.unrecognizedWorkflowVersion', { version })
);
};

View File

@ -0,0 +1,270 @@
import { FieldType, StatefulFieldType } from '../../field';
import { FieldTypeV1 } from './workflowV1';
/**
* Mapping of V1 field type strings to their *stateful* V2 field type counterparts.
*/
const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
[key in FieldTypeV1]?: StatefulFieldType;
} = {
BoardField: { name: 'BoardField', isCollection: false, isPolymorphic: false },
boolean: { name: 'BooleanField', isCollection: false, isPolymorphic: false },
BooleanCollection: {
name: 'BooleanField',
isCollection: true,
isPolymorphic: false,
},
BooleanPolymorphic: {
name: 'BooleanField',
isCollection: false,
isPolymorphic: true,
},
ColorField: { name: 'ColorField', isCollection: false, isPolymorphic: false },
ColorCollection: {
name: 'ColorField',
isCollection: true,
isPolymorphic: false,
},
ColorPolymorphic: {
name: 'ColorField',
isCollection: false,
isPolymorphic: true,
},
ControlNetModelField: {
name: 'ControlNetModelField',
isCollection: false,
isPolymorphic: false,
},
enum: { name: 'EnumField', isCollection: false, isPolymorphic: false },
float: { name: 'FloatField', isCollection: false, isPolymorphic: false },
FloatCollection: {
name: 'FloatField',
isCollection: true,
isPolymorphic: false,
},
FloatPolymorphic: {
name: 'FloatField',
isCollection: false,
isPolymorphic: true,
},
ImageCollection: {
name: 'ImageField',
isCollection: true,
isPolymorphic: false,
},
ImageField: { name: 'ImageField', isCollection: false, isPolymorphic: false },
ImagePolymorphic: {
name: 'ImageField',
isCollection: false,
isPolymorphic: true,
},
integer: { name: 'IntegerField', isCollection: false, isPolymorphic: false },
IntegerCollection: {
name: 'IntegerField',
isCollection: true,
isPolymorphic: false,
},
IntegerPolymorphic: {
name: 'IntegerField',
isCollection: false,
isPolymorphic: true,
},
IPAdapterModelField: {
name: 'IPAdapterModelField',
isCollection: false,
isPolymorphic: false,
},
LoRAModelField: {
name: 'LoRAModelField',
isCollection: false,
isPolymorphic: false,
},
MainModelField: {
name: 'MainModelField',
isCollection: false,
isPolymorphic: false,
},
Scheduler: {
name: 'SchedulerField',
isCollection: false,
isPolymorphic: false,
},
SDXLMainModelField: {
name: 'SDXLMainModelField',
isCollection: false,
isPolymorphic: false,
},
SDXLRefinerModelField: {
name: 'SDXLRefinerModelField',
isCollection: false,
isPolymorphic: false,
},
string: { name: 'StringField', isCollection: false, isPolymorphic: false },
StringCollection: {
name: 'StringField',
isCollection: true,
isPolymorphic: false,
},
StringPolymorphic: {
name: 'StringField',
isCollection: false,
isPolymorphic: true,
},
T2IAdapterModelField: {
name: 'T2IAdapterModelField',
isCollection: false,
isPolymorphic: false,
},
VaeModelField: {
name: 'VAEModelField',
isCollection: false,
isPolymorphic: false,
},
};
/**
* Mapping of V1 field type strings to their *stateless* V2 field type counterparts.
*
* The type doesn't do what I want it to do.
*
* Ideally, the value of each propery would be a `FieldType` where `FieldType['name']` is not in
* `StatefulFieldType['name']`, but this is hard to represent. That's because `FieldType['name']` is
* actually widened to `string`, and TS's `Exclude<T,U>` doesn't work on `string`.
*
* There's probably some way to do it with conditionals and intersections but I can't figure it out.
*
* Thus, this object was manually edited to ensure it is correct.
*/
const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
[key in FieldTypeV1]?: FieldType;
} = {
Any: { name: 'AnyField', isCollection: false, isPolymorphic: false },
ClipField: { name: 'ClipField', isCollection: false, isPolymorphic: false },
Collection: {
name: 'CollectionField',
isCollection: true,
isPolymorphic: false,
},
CollectionItem: {
name: 'CollectionItemField',
isCollection: false,
isPolymorphic: false,
},
ConditioningCollection: {
name: 'ConditioningField',
isCollection: true,
isPolymorphic: false,
},
ConditioningField: {
name: 'ConditioningField',
isCollection: false,
isPolymorphic: false,
},
ConditioningPolymorphic: {
name: 'ConditioningField',
isCollection: false,
isPolymorphic: true,
},
ControlCollection: {
name: 'ControlField',
isCollection: true,
isPolymorphic: false,
},
ControlField: {
name: 'ControlField',
isCollection: false,
isPolymorphic: false,
},
ControlPolymorphic: {
name: 'ControlField',
isCollection: false,
isPolymorphic: true,
},
DenoiseMaskField: {
name: 'DenoiseMaskField',
isCollection: false,
isPolymorphic: false,
},
IPAdapterField: {
name: 'IPAdapterField',
isCollection: false,
isPolymorphic: false,
},
IPAdapterCollection: {
name: 'IPAdapterField',
isCollection: true,
isPolymorphic: false,
},
IPAdapterPolymorphic: {
name: 'IPAdapterField',
isCollection: false,
isPolymorphic: true,
},
LatentsField: {
name: 'LatentsField',
isCollection: false,
isPolymorphic: false,
},
LatentsCollection: {
name: 'LatentsField',
isCollection: true,
isPolymorphic: false,
},
LatentsPolymorphic: {
name: 'LatentsField',
isCollection: false,
isPolymorphic: true,
},
MetadataField: {
name: 'MetadataField',
isCollection: false,
isPolymorphic: false,
},
MetadataCollection: {
name: 'MetadataField',
isCollection: true,
isPolymorphic: false,
},
MetadataItemField: {
name: 'MetadataItemField',
isCollection: false,
isPolymorphic: false,
},
MetadataItemCollection: {
name: 'MetadataItemField',
isCollection: true,
isPolymorphic: false,
},
MetadataItemPolymorphic: {
name: 'MetadataItemField',
isCollection: false,
isPolymorphic: true,
},
ONNXModelField: {
name: 'ONNXModelField',
isCollection: false,
isPolymorphic: false,
},
T2IAdapterField: {
name: 'T2IAdapterField',
isCollection: false,
isPolymorphic: false,
},
T2IAdapterCollection: {
name: 'T2IAdapterField',
isCollection: true,
isPolymorphic: false,
},
T2IAdapterPolymorphic: {
name: 'T2IAdapterField',
isCollection: false,
isPolymorphic: true,
},
UNetField: { name: 'UNetField', isCollection: false, isPolymorphic: false },
VaeField: { name: 'VaeField', isCollection: false, isPolymorphic: false },
};
export const FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING = {
...FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2,
...FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2,
};

Some files were not shown because too many files have changed in this diff Show More