mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add support for custom field types
Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported. Two notes: 1. Your field type's class name must be unique. Suggest prefixing fields with something related to the node pack as a kind of namespace. 2. Custom field types function as connection-only fields. For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type. This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection. feat(ui): fix tooltips for custom types We need to hold onto the original type of the field so they don't all just show up as "Unknown". fix(ui): fix ts error with custom fields feat(ui): custom field types connection validation In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent. fix(ui): typo feat(ui): add CustomCollection and CustomPolymorphic field types feat(ui): add validation for CustomCollection & CustomPolymorphic types - Update connection validation for custom types - Use simple string parsing to determine if a field is a collection or polymorphic type. - No longer need to keep a list of collection and polymorphic types. - Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing chore(ui): remove errant console.log fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType' This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type. fix(ui): fix ts error feat(nodes): add runtime check for custom field names "Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names. chore(ui): add TODO for revising field type names wip refactor fieldtype structured wip refactor field types wip refactor types wip refactor types fix node layout refactor field types chore: mypy organisation organisation organisation fix(nodes): fix field orig_required, field_kind and input statuses feat(nodes): remove broken implementation of default_factory on InputField Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args. Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used. Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`. fix(nodes): fix InputField name validation workflow validation validation chore: ruff feat(nodes): fix up baseinvocation comments fix(ui): improve typing & logic of buildFieldInputTemplate improved error handling in parseFieldType fix: back compat for deprecated default_factory and UIType feat(nodes): do not show node packs loaded log if none loaded chore(ui): typegen
This commit is contained in:
@ -805,6 +805,8 @@
|
||||
"clipField": "Clip",
|
||||
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
|
||||
"collection": "Collection",
|
||||
"collectionFieldType": "{{name}} Collection",
|
||||
"polymorphicFieldType": "{{name}} Polymorphic",
|
||||
"collectionDescription": "TODO",
|
||||
"collectionItem": "Collection Item",
|
||||
"collectionItemDescription": "TODO",
|
||||
@ -891,10 +893,15 @@
|
||||
"mainModelField": "Model",
|
||||
"mainModelFieldDescription": "TODO",
|
||||
"maybeIncompatible": "May be Incompatible With Installed",
|
||||
"mismatchedVersion": "Has Mismatched Version",
|
||||
"mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)",
|
||||
"missingCanvaInitImage": "Missing canvas init image",
|
||||
"missingCanvaInitMaskImages": "Missing canvas init and mask images",
|
||||
"missingTemplate": "Missing Template",
|
||||
"missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)",
|
||||
"sourceNodeDoesNotExist": "Invalid edge: source/output node {{node}} does not exist",
|
||||
"targetNodeDoesNotExist": "Invalid edge: target/input node {{node}} does not exist",
|
||||
"sourceNodeFieldDoesNotExist": "Invalid edge: source/output field {{node}}.{{field}} does not exist",
|
||||
"targetNodeFieldDoesNotExist": "Invalid edge: target/input field {{node}}.{{field}} does not exist",
|
||||
"deletedInvalidEdge": "Deleted invalid edge {{source}} -> {{target}}",
|
||||
"noConnectionData": "No connection data",
|
||||
"noConnectionInProgress": "No connection in progress",
|
||||
"node": "Node",
|
||||
@ -954,10 +961,17 @@
|
||||
"stringDescription": "Strings are text.",
|
||||
"stringPolymorphic": "String Polymorphic",
|
||||
"stringPolymorphicDescription": "A collection of strings.",
|
||||
"unableToLoadWorkflow": "Unable to Validate Workflow",
|
||||
"unableToLoadWorkflow": "Unable to Load Workflow",
|
||||
"unableToParseEdge": "Unable to parse edge",
|
||||
"unableToParseNode": "Unable to parse node",
|
||||
"unableToUpdateNode": "Unable to update node",
|
||||
"unableToValidateWorkflow": "Unable to Validate Workflow",
|
||||
"unknownErrorValidatingWorkflow": "Unknown error validating workflow",
|
||||
"inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})",
|
||||
"outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})",
|
||||
"unableToExtractSchemaNameFromRef": "unable to extract schema name from ref",
|
||||
"unsupportedArrayItemType": "unsupported array item type \"{{type}}\"",
|
||||
"unableToParseFieldType": "unable to parse field type",
|
||||
"uNetField": "UNet",
|
||||
"uNetFieldDescription": "UNet submodel.",
|
||||
"unhandledInputProperty": "Unhandled input property",
|
||||
@ -971,8 +985,9 @@
|
||||
"unkownInvocation": "Unknown Invocation type",
|
||||
"unknownOutput": "Unknown output",
|
||||
"updateNode": "Update Node",
|
||||
"updateAllNodes": "Update All Nodes",
|
||||
"updateApp": "Update App",
|
||||
"updateAllNodes": "Update All Nodes",
|
||||
"allNodesUpdated": "All Nodes Updated",
|
||||
"unableToUpdateNodes_one": "Unable to update {{count}} node",
|
||||
"unableToUpdateNodes_other": "Unable to update {{count}} nodes",
|
||||
"vaeField": "Vae",
|
||||
@ -981,6 +996,8 @@
|
||||
"vaeModelFieldDescription": "TODO",
|
||||
"validateConnections": "Validate Connections and Graph",
|
||||
"validateConnectionsHelp": "Prevent invalid connections from being made, and invalid graphs from being invoked",
|
||||
"unableToGetWorkflowVersion": "Unable to get workflow schema version",
|
||||
"unrecognizedWorkflowVersion": "Unrecognized workflow schema version {{version}}",
|
||||
"version": "Version",
|
||||
"versionUnknown": " Version Unknown",
|
||||
"workflow": "Workflow",
|
||||
|
@ -71,7 +71,7 @@ import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } f
|
||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
import { addWorkflowLoadRequestedListener } from './listeners/workflowLoadRequested';
|
||||
import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -178,7 +178,7 @@ addBoardIdSelectedListener();
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
||||
// Workflows
|
||||
addWorkflowLoadedListener();
|
||||
addWorkflowLoadRequestedListener();
|
||||
addUpdateAllNodesRequestedListener();
|
||||
|
||||
// DND
|
||||
|
@ -12,10 +12,10 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
|
||||
export const addControlNetImageProcessedListener = () => {
|
||||
startAppListening({
|
||||
|
@ -5,19 +5,20 @@ import {
|
||||
controlAdapterProcessedImageChanged,
|
||||
selectControlAdapterAll,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { startAppListening } from '..';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
|
||||
export const addRequestedSingleImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
@ -121,7 +122,7 @@ export const addRequestedSingleImageDeletionListener = () => {
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (
|
||||
input.type === 'ImageField' &&
|
||||
isImageFieldInputInstance(input) &&
|
||||
input.value?.image_name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
@ -241,7 +242,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (
|
||||
input.type === 'ImageField' &&
|
||||
isImageFieldInputInstance(input) &&
|
||||
input.value?.image_name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
|
@ -12,12 +12,12 @@ import {
|
||||
setWidth,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { startAppListening } from '..';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const addModelSelectedListener = () => {
|
||||
startAppListening({
|
||||
@ -26,7 +26,7 @@ export const addModelSelectedListener = () => {
|
||||
const log = logger('models');
|
||||
|
||||
const state = getState();
|
||||
const result = zMainOrOnnxModel.safeParse(action.payload);
|
||||
const result = zParameterModel.safeParse(action.payload);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
|
@ -11,9 +11,9 @@ import {
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
zMainOrOnnxModel,
|
||||
zSDXLRefinerModel,
|
||||
zVaeModel,
|
||||
zParameterModel,
|
||||
zParameterSDXLRefinerModel,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
@ -67,7 +67,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zMainOrOnnxModel.safeParse(models[0]);
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
@ -119,7 +119,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zSDXLRefinerModel.safeParse(models[0]);
|
||||
const result = zParameterSDXLRefinerModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
@ -170,7 +170,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zVaeModel.safeParse(firstModel);
|
||||
const result = zParameterVAEModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
|
@ -15,6 +15,7 @@ export const addReceivedOpenAPISchemaListener = () => {
|
||||
|
||||
log.debug({ schemaJSON }, 'Received OpenAPI schema');
|
||||
const { nodesAllowlist, nodesDenylist } = getState().config;
|
||||
|
||||
const nodeTemplates = parseSchema(
|
||||
schemaJSON,
|
||||
nodesAllowlist,
|
||||
|
@ -13,13 +13,13 @@ import {
|
||||
} from 'features/nodes/util/graphBuilders/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import {
|
||||
appSocketInvocationComplete,
|
||||
socketInvocationComplete,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
|
||||
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
|
||||
const nodeTypeDenylist = ['load_image', 'image'];
|
||||
|
@ -1,14 +1,16 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
getNeedsUpdate,
|
||||
updateNode,
|
||||
} from 'features/nodes/hooks/useNodeVersion';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { logger } from 'app/logging/logger';
|
||||
} from 'features/nodes/store/util/nodeUpdate';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addUpdateAllNodesRequestedListener = () => {
|
||||
startAppListening({
|
||||
@ -20,22 +22,31 @@ export const addUpdateAllNodesRequestedListener = () => {
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
nodes.forEach((node) => {
|
||||
nodes.filter(isInvocationNode).forEach((node) => {
|
||||
const template = templates[node.data.type];
|
||||
const needsUpdate = getNeedsUpdate(node, template);
|
||||
const updatedNode = updateNode(node, template);
|
||||
if (!updatedNode) {
|
||||
if (needsUpdate) {
|
||||
unableToUpdateCount++;
|
||||
}
|
||||
if (!template) {
|
||||
unableToUpdateCount++;
|
||||
return;
|
||||
}
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
if (!getNeedsUpdate(node, template)) {
|
||||
// No need to increment the count here, since we're not actually updating
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const updatedNode = updateNode(node, template);
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
} catch (e) {
|
||||
if (e instanceof NodeUpdateError) {
|
||||
unableToUpdateCount++;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (unableToUpdateCount) {
|
||||
log.warn(
|
||||
`Unable to update ${unableToUpdateCount} nodes. Please report this issue.`
|
||||
t('nodes.unableToUpdateNodes', {
|
||||
count: unableToUpdateCount,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addToast(
|
||||
@ -46,6 +57,15 @@ export const addUpdateAllNodesRequestedListener = () => {
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.allNodesUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -0,0 +1,105 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { t } from 'i18next';
|
||||
import { z } from 'zod';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addWorkflowLoadRequestedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
if (!warnings.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.workflowLoaded'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.loadedWithWarnings'),
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
warnings.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof WorkflowVersionError) {
|
||||
// The workflow version was not recognized in the valid list of versions
|
||||
log.error({ error: parseify(e) }, e.message);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
description: e.message,
|
||||
})
|
||||
)
|
||||
);
|
||||
} else if (e instanceof z.ZodError) {
|
||||
// There was a problem validating the workflow itself
|
||||
const { message } = fromZodError(e, {
|
||||
prefix: t('nodes.workflowValidation'),
|
||||
});
|
||||
log.error({ error: parseify(e) }, message);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
description: message,
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// Some other error occurred
|
||||
console.log(e);
|
||||
log.error(
|
||||
{ error: parseify(e) },
|
||||
t('nodes.unknownErrorValidatingWorkflow')
|
||||
);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
description: t('nodes.unknownErrorValidatingWorkflow'),
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -1,56 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addWorkflowLoadedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
const { workflow: validatedWorkflow, errors } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
|
||||
if (!errors.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.workflowLoaded'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.loadedWithWarnings'),
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
errors.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import i18n from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
|
@ -6,9 +6,9 @@ import {
|
||||
isAnyOf,
|
||||
} from '@reduxjs/toolkit';
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
IPAdapterModelParam,
|
||||
T2IAdapterModelParam,
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
@ -243,9 +243,9 @@ export const controlAdaptersSlice = createSlice({
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
model:
|
||||
| ControlNetModelParam
|
||||
| T2IAdapterModelParam
|
||||
| IPAdapterModelParam;
|
||||
| ParameterControlNetModel
|
||||
| ParameterT2IAdapterModel
|
||||
| ParameterIPAdapterModel;
|
||||
}>
|
||||
) => {
|
||||
const { id, model } = action.payload;
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { EntityState } from '@reduxjs/toolkit';
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
IPAdapterModelParam,
|
||||
T2IAdapterModelParam,
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { isObject } from 'lodash-es';
|
||||
import { components } from 'services/api/schema';
|
||||
@ -378,7 +378,7 @@ export type ControlNetConfig = {
|
||||
type: 'controlnet';
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ControlNetModelParam | null;
|
||||
model: ParameterControlNetModel | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -395,7 +395,7 @@ export type T2IAdapterConfig = {
|
||||
type: 't2i_adapter';
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: T2IAdapterModelParam | null;
|
||||
model: ParameterT2IAdapterModel | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -412,7 +412,7 @@ export type IPAdapterConfig = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
controlImage: string | null;
|
||||
model: IPAdapterModelParam | null;
|
||||
model: ParameterIPAdapterModel | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { some } from 'lodash-es';
|
||||
import { ImageUsage } from './types';
|
||||
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
|
||||
export const getImageUsage = (state: RootState, image_name: string) => {
|
||||
const { generation, canvas, nodes, controlAdapters } = state;
|
||||
@ -19,7 +20,8 @@ export const getImageUsage = (state: RootState, image_name: string) => {
|
||||
return some(
|
||||
node.data.inputs,
|
||||
(input) =>
|
||||
input.type === 'ImageField' && input.value?.image_name === image_name
|
||||
isImageFieldInputInstance(input) &&
|
||||
input.value?.image_name === image_name
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -11,9 +11,9 @@ import {
|
||||
useDroppable as useOriginalDroppable,
|
||||
} from '@dnd-kit/core';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
FieldInputTemplate,
|
||||
FieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type BaseDropData = {
|
||||
@ -93,8 +93,8 @@ export type NodeFieldDraggableData = BaseDragData & {
|
||||
payloadType: 'NODE_FIELD';
|
||||
payload: {
|
||||
nodeId: string;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
field: FieldInputInstance;
|
||||
fieldTemplate: FieldInputTemplate;
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -4,14 +4,14 @@ import {
|
||||
LoRAMetadataItem,
|
||||
IPAdapterMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
} from 'features/nodes/types/metadata';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useMemo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
isValidControlNetModel,
|
||||
isValidLoRAModel,
|
||||
isValidT2IAdapterModel,
|
||||
isParameterControlNetModel,
|
||||
isParameterLoRAModel,
|
||||
isParameterT2IAdapterModel,
|
||||
} from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
@ -132,7 +132,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) =>
|
||||
isValidControlNetModel(controlnet.control_model)
|
||||
isParameterControlNetModel(controlnet.control_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.controlnets]);
|
||||
@ -140,7 +140,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.ipAdapters
|
||||
? metadata.ipAdapters.filter((ipAdapter) =>
|
||||
isValidControlNetModel(ipAdapter.ip_adapter_model)
|
||||
isParameterControlNetModel(ipAdapter.ip_adapter_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.ipAdapters]);
|
||||
@ -148,7 +148,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.t2iAdapters
|
||||
? metadata.t2iAdapters.filter((t2iAdapter) =>
|
||||
isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model)
|
||||
isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.t2iAdapters]);
|
||||
@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
return null;
|
||||
}
|
||||
|
||||
console.log(metadata);
|
||||
|
||||
return (
|
||||
<>
|
||||
{metadata.created_by && (
|
||||
@ -275,7 +273,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
)}
|
||||
{metadata.loras &&
|
||||
metadata.loras.map((lora, index) => {
|
||||
if (isValidLoRAModel(lora.lora)) {
|
||||
if (isParameterLoRAModel(lora.lora)) {
|
||||
return (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
|
||||
export type LoRA = LoRAModelParam & {
|
||||
export type LoRA = ParameterLoRAModel & {
|
||||
id: string;
|
||||
weight: number;
|
||||
};
|
||||
|
@ -24,7 +24,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import 'reactflow/dist/style.css';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
|
||||
|
||||
type NodeTemplate = {
|
||||
@ -57,7 +56,7 @@ const AddNodePopover = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const fieldFilter = useAppSelector(
|
||||
(state) => state.nodes.currentConnectionFieldType
|
||||
(state) => state.nodes.connectionStartFieldType
|
||||
);
|
||||
const handleFilter = useAppSelector(
|
||||
(state) => state.nodes.connectionStartParams?.handleType
|
||||
@ -111,7 +110,7 @@ const AddNodePopover = () => {
|
||||
|
||||
data.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return { data, t };
|
||||
return { data };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
@ -121,7 +120,7 @@ const AddNodePopover = () => {
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: AnyInvocationType) => {
|
||||
(nodeType: string) => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
if (!invocation) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
@ -145,7 +144,7 @@ const AddNodePopover = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
addNode(v as AnyInvocationType);
|
||||
addNode(v);
|
||||
},
|
||||
[addNode]
|
||||
);
|
||||
|
@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { memo } from 'react';
|
||||
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
|
||||
import { getFieldColor } from '../edges/util/getEdgeColor';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
|
||||
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
|
||||
nodes;
|
||||
|
||||
const stroke =
|
||||
currentConnectionFieldType && shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
|
||||
: colorTokenToCssVar('base.500');
|
||||
const stroke = shouldColorEdges
|
||||
? getFieldColor(connectionStartFieldType)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
let className = 'react-flow__custom_connection-path';
|
||||
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELD_COLORS } from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/field';
|
||||
|
||||
export const getFieldColor = (fieldType: FieldType | null): string => {
|
||||
if (!fieldType) {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
const color = FIELD_COLORS[fieldType.name];
|
||||
|
||||
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
|
||||
};
|
@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
source: string,
|
||||
@ -29,7 +29,7 @@ export const makeEdgeSelector = (
|
||||
|
||||
const stroke =
|
||||
sourceType && nodes.shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||
? getFieldColor(sourceType)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
return {
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { useColorModeValue } from '@chakra-ui/react';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { map } from 'lodash-es';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
|
@ -2,8 +2,8 @@ import { Flex, Icon, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
@ -13,7 +13,7 @@ interface Props {
|
||||
}
|
||||
|
||||
const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
|
||||
const { needsUpdate } = useNodeVersion(nodeId);
|
||||
const needsUpdate = useNodeNeedsUpdate(nodeId);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
@ -11,7 +11,10 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
|
||||
import {
|
||||
NodeExecutionState,
|
||||
zNodeStatus,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -74,10 +77,10 @@ type TooltipLabelProps = {
|
||||
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
const { status, progress, progressImage } = nodeExecutionState;
|
||||
const { t } = useTranslation();
|
||||
if (status === NodeStatus.PENDING) {
|
||||
if (status === zNodeStatus.enum.PENDING) {
|
||||
return <Text>{t('queue.pending')}</Text>;
|
||||
}
|
||||
if (status === NodeStatus.IN_PROGRESS) {
|
||||
if (status === zNodeStatus.enum.IN_PROGRESS) {
|
||||
if (progressImage) {
|
||||
return (
|
||||
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
|
||||
@ -108,11 +111,11 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
return <Text>{t('nodes.executionStateInProgress')}</Text>;
|
||||
}
|
||||
|
||||
if (status === NodeStatus.COMPLETED) {
|
||||
if (status === zNodeStatus.enum.COMPLETED) {
|
||||
return <Text>{t('nodes.executionStateCompleted')}</Text>;
|
||||
}
|
||||
|
||||
if (status === NodeStatus.FAILED) {
|
||||
if (status === zNodeStatus.enum.FAILED) {
|
||||
return <Text>{t('nodes.executionStateError')}</Text>;
|
||||
}
|
||||
|
||||
@ -127,7 +130,7 @@ type StatusIconProps = {
|
||||
|
||||
const StatusIcon = memo((props: StatusIconProps) => {
|
||||
const { progress, status } = props.nodeExecutionState;
|
||||
if (status === NodeStatus.PENDING) {
|
||||
if (status === zNodeStatus.enum.PENDING) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaEllipsisH}
|
||||
@ -139,7 +142,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.IN_PROGRESS) {
|
||||
if (status === zNodeStatus.enum.IN_PROGRESS) {
|
||||
return progress === null ? (
|
||||
<CircularProgress
|
||||
isIndeterminate
|
||||
@ -158,7 +161,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.COMPLETED) {
|
||||
if (status === zNodeStatus.enum.COMPLETED) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaCheck}
|
||||
@ -170,7 +173,7 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.FAILED) {
|
||||
if (status === zNodeStatus.enum.FAILED) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaExclamation}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InvocationNodeData } from 'features/nodes/types/types';
|
||||
import { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import InvocationNode from '../Invocation/InvocationNode';
|
||||
|
@ -3,7 +3,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
|
@ -56,7 +56,7 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
|
||||
);
|
||||
|
||||
const mayExpose = useMemo(
|
||||
() => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'),
|
||||
() => input && ['any', 'direct'].includes(input),
|
||||
[input]
|
||||
);
|
||||
|
||||
|
@ -1,18 +1,17 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import {
|
||||
COLLECTION_TYPES,
|
||||
FIELDS,
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
MODEL_TYPES,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, HandleType, Position } from 'reactflow';
|
||||
import { getFieldColor } from '../../../edges/util/getEdgeColor';
|
||||
|
||||
export const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
@ -32,11 +31,11 @@ export const outputHandleStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
type FieldHandleProps = {
|
||||
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
handleType: HandleType;
|
||||
isConnectionInProgress: boolean;
|
||||
isConnectionStartField: boolean;
|
||||
connectionError: string | null;
|
||||
connectionError?: string;
|
||||
};
|
||||
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
@ -47,23 +46,21 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
isConnectionStartField,
|
||||
connectionError,
|
||||
} = props;
|
||||
const { name, type } = fieldTemplate;
|
||||
const { color: typeColor, title } = FIELDS[type];
|
||||
|
||||
const { name } = fieldTemplate;
|
||||
const type = fieldTemplate.type;
|
||||
const fieldTypeName = useFieldTypeName(type);
|
||||
const styles: CSSProperties = useMemo(() => {
|
||||
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||
const isModelType = MODEL_TYPES.includes(type);
|
||||
const color = colorTokenToCssVar(typeColor);
|
||||
const isModelType = MODEL_TYPES.some((t) => t === type.name);
|
||||
const color = getFieldColor(type);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor:
|
||||
isCollectionType || isPolymorphicType
|
||||
? 'var(--invokeai-colors-base-900)'
|
||||
type.isCollection || type.isPolymorphic
|
||||
? colorTokenToCssVar('base.900')
|
||||
: color,
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||
borderWidth: type.isCollection || type.isPolymorphic ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
@ -97,18 +94,14 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
type,
|
||||
typeColor,
|
||||
]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && isConnectionStartField) {
|
||||
return title;
|
||||
}
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
return connectionError ?? title;
|
||||
return connectionError;
|
||||
}
|
||||
return title;
|
||||
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
|
||||
return fieldTypeName;
|
||||
}, [connectionError, fieldTypeName, isConnectionInProgress]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
@ -1,15 +1,14 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import {
|
||||
isInputFieldTemplate,
|
||||
isInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
isFieldInputInstance,
|
||||
isFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@ -17,12 +16,13 @@ interface Props {
|
||||
}
|
||||
|
||||
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const field = useFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
|
||||
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
|
||||
const { t } = useTranslation();
|
||||
const fieldTitle = useMemo(() => {
|
||||
if (isInputFieldValue(field)) {
|
||||
if (isFieldInputInstance(field)) {
|
||||
if (field.label && fieldTemplate?.title) {
|
||||
return `${field.label} (${fieldTemplate.title})`;
|
||||
}
|
||||
@ -49,9 +49,9 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
{fieldTemplate.description}
|
||||
</Text>
|
||||
)}
|
||||
{fieldTemplate && (
|
||||
{fieldTypeName && (
|
||||
<Text>
|
||||
{t('parameters.type')}: {FIELDS[fieldTemplate.type].title}
|
||||
{t('parameters.type')}: {fieldTypeName}
|
||||
</Text>
|
||||
)}
|
||||
{isInputTemplate && (
|
||||
|
@ -77,10 +77,10 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
sx={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
mb: 0,
|
||||
px: 1,
|
||||
gap: 2,
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<EditableFieldTitle
|
||||
|
@ -1,24 +1,60 @@
|
||||
import { Box, Text } from '@chakra-ui/react';
|
||||
import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
isLoRAModelFieldInputTemplate,
|
||||
isMainModelFieldInputInstance,
|
||||
isMainModelFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
isT2IAdapterModelFieldInputTemplate,
|
||||
isVAEModelFieldInputInstance,
|
||||
isVAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo } from 'react';
|
||||
import BooleanInputField from './inputs/BooleanInputField';
|
||||
import ColorInputField from './inputs/ColorInputField';
|
||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||
import EnumInputField from './inputs/EnumInputField';
|
||||
import ImageInputField from './inputs/ImageInputField';
|
||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||
import MainModelInputField from './inputs/MainModelInputField';
|
||||
import NumberInputField from './inputs/NumberInputField';
|
||||
import RefinerModelInputField from './inputs/RefinerModelInputField';
|
||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
||||
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
|
||||
import BoardInputField from './inputs/BoardInputField';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
|
||||
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
@ -27,220 +63,227 @@ type InputFieldProps = {
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const { t } = useTranslation();
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldInstance = useFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
|
||||
if (fieldTemplate?.fieldKind === 'output') {
|
||||
return (
|
||||
<Box p={2}>
|
||||
{t('nodes.outputFieldInInput')}: {field?.type}
|
||||
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
|
||||
(field?.type === 'StringPolymorphic' &&
|
||||
fieldTemplate?.type === 'StringPolymorphic')
|
||||
isStringFieldInputInstance(fieldInstance) &&
|
||||
isStringFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<StringInputField
|
||||
<StringFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
|
||||
(field?.type === 'BooleanPolymorphic' &&
|
||||
fieldTemplate?.type === 'BooleanPolymorphic')
|
||||
isBooleanFieldInputInstance(fieldInstance) &&
|
||||
isBooleanFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<BooleanInputField
|
||||
<BooleanFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
|
||||
(field?.type === 'float' && fieldTemplate?.type === 'float') ||
|
||||
(field?.type === 'FloatPolymorphic' &&
|
||||
fieldTemplate?.type === 'FloatPolymorphic') ||
|
||||
(field?.type === 'IntegerPolymorphic' &&
|
||||
fieldTemplate?.type === 'IntegerPolymorphic')
|
||||
(isIntegerFieldInputInstance(fieldInstance) &&
|
||||
isIntegerFieldInputTemplate(fieldTemplate)) ||
|
||||
(isFloatFieldInputInstance(fieldInstance) &&
|
||||
isFloatFieldInputTemplate(fieldTemplate))
|
||||
) {
|
||||
return (
|
||||
<NumberInputField
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
|
||||
return (
|
||||
<EnumInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
|
||||
(field?.type === 'ImagePolymorphic' &&
|
||||
fieldTemplate?.type === 'ImagePolymorphic')
|
||||
isEnumFieldInputInstance(fieldInstance) &&
|
||||
isEnumFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<ImageInputField
|
||||
<EnumFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'BoardField' && fieldTemplate?.type === 'BoardField') {
|
||||
return (
|
||||
<BoardInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'MainModelField' &&
|
||||
fieldTemplate?.type === 'MainModelField'
|
||||
isImageFieldInputInstance(fieldInstance) &&
|
||||
isImageFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<MainModelInputField
|
||||
<ImageFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLRefinerModelField' &&
|
||||
fieldTemplate?.type === 'SDXLRefinerModelField'
|
||||
isBoardFieldInputInstance(fieldInstance) &&
|
||||
isBoardFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<RefinerModelInputField
|
||||
<BoardFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'VaeModelField' &&
|
||||
fieldTemplate?.type === 'VaeModelField'
|
||||
isMainModelFieldInputInstance(fieldInstance) &&
|
||||
isMainModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<VaeModelInputField
|
||||
<MainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LoRAModelField' &&
|
||||
fieldTemplate?.type === 'LoRAModelField'
|
||||
isSDXLRefinerModelFieldInputInstance(fieldInstance) &&
|
||||
isSDXLRefinerModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<LoRAModelInputField
|
||||
<RefinerModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlNetModelField' &&
|
||||
fieldTemplate?.type === 'ControlNetModelField'
|
||||
isVAEModelFieldInputInstance(fieldInstance) &&
|
||||
isVAEModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<ControlNetModelInputField
|
||||
<VAEModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'IPAdapterModelField' &&
|
||||
fieldTemplate?.type === 'IPAdapterModelField'
|
||||
isLoRAModelFieldInputInstance(fieldInstance) &&
|
||||
isLoRAModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<IPAdapterModelInputField
|
||||
<LoRAModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'T2IAdapterModelField' &&
|
||||
fieldTemplate?.type === 'T2IAdapterModelField'
|
||||
isControlNetModelFieldInputInstance(fieldInstance) &&
|
||||
isControlNetModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<T2IAdapterModelInputField
|
||||
<ControlNetModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLMainModelField' &&
|
||||
fieldTemplate?.type === 'SDXLMainModelField'
|
||||
isIPAdapterModelFieldInputInstance(fieldInstance) &&
|
||||
isIPAdapterModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SDXLMainModelInputField
|
||||
<IPAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
|
||||
if (
|
||||
isT2IAdapterModelFieldInputInstance(fieldInstance) &&
|
||||
isT2IAdapterModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SchedulerInputField
|
||||
<T2IAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (
|
||||
isColorFieldInputInstance(fieldInstance) &&
|
||||
isColorFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<ColorFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field && fieldTemplate) {
|
||||
if (
|
||||
isSDXLMainModelFieldInputInstance(fieldInstance) &&
|
||||
isSDXLMainModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SDXLMainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
isSchedulerFieldInputInstance(fieldInstance) &&
|
||||
isSchedulerFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SchedulerFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldInstance && fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
@ -255,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
_dark: { color: 'error.300' },
|
||||
}}
|
||||
>
|
||||
{t('nodes.unknownFieldType')}: {field?.type}
|
||||
{t('nodes.unknownFieldType')}: {fieldInstance?.type.name}
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
|
@ -3,15 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BoardInputFieldTemplate,
|
||||
BoardInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
BoardFieldInputTemplate,
|
||||
BoardFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
const BoardInputFieldComponent = (
|
||||
props: FieldComponentProps<BoardInputFieldValue, BoardInputFieldTemplate>
|
||||
const BoardFieldInputComponent = (
|
||||
props: FieldComponentProps<BoardFieldInputInstance, BoardFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -61,4 +61,4 @@ const BoardInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BoardInputFieldComponent);
|
||||
export default memo(BoardFieldInputComponent);
|
@ -2,18 +2,16 @@ import { Switch } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
BooleanInputFieldValue,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
BooleanFieldInputInstance,
|
||||
BooleanFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const BooleanInputFieldComponent = (
|
||||
const BooleanFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
|
||||
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
|
||||
BooleanFieldInputInstance,
|
||||
BooleanFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -42,4 +40,4 @@ const BooleanInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BooleanInputFieldComponent);
|
||||
export default memo(BooleanFieldInputComponent);
|
@ -1,15 +1,15 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ColorInputFieldTemplate,
|
||||
ColorInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
ColorFieldInputTemplate,
|
||||
ColorFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
|
||||
|
||||
const ColorInputFieldComponent = (
|
||||
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
|
||||
const ColorFieldInputComponent = (
|
||||
props: FieldComponentProps<ColorFieldInputInstance, ColorFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
@ -37,4 +37,4 @@ const ColorInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ColorInputFieldComponent);
|
||||
export default memo(ColorFieldInputComponent);
|
@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlNetModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
ControlNetModelFieldInputTemplate,
|
||||
ControlNetModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const ControlNetModelInputFieldComponent = (
|
||||
const ControlNetModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
ControlNetModelInputFieldValue,
|
||||
ControlNetModelInputFieldTemplate
|
||||
ControlNetModelFieldInputInstance,
|
||||
ControlNetModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -97,4 +97,4 @@ const ControlNetModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelInputFieldComponent);
|
||||
export default memo(ControlNetModelFieldInputComponent);
|
@ -2,14 +2,14 @@ import { Select } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
EnumInputFieldTemplate,
|
||||
EnumInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
EnumFieldInputInstance,
|
||||
EnumFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const EnumInputFieldComponent = (
|
||||
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
||||
const EnumFieldInputComponent = (
|
||||
props: FieldComponentProps<EnumFieldInputInstance, EnumFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
|
||||
@ -45,4 +45,4 @@ const EnumInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(EnumInputFieldComponent);
|
||||
export default memo(EnumFieldInputComponent);
|
@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
IPAdapterModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
IPAdapterModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const IPAdapterModelInputFieldComponent = (
|
||||
const IPAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
IPAdapterModelInputFieldValue,
|
||||
IPAdapterModelInputFieldTemplate
|
||||
IPAdapterModelFieldInputInstance,
|
||||
IPAdapterModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -97,4 +97,4 @@ const IPAdapterModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IPAdapterModelInputFieldComponent);
|
||||
export default memo(IPAdapterModelFieldInputComponent);
|
@ -9,23 +9,18 @@ import {
|
||||
} from 'features/dnd/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
ImageFieldInputInstance,
|
||||
ImageFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
|
||||
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
|
||||
>
|
||||
const ImageFieldInputComponent = (
|
||||
props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -102,7 +97,7 @@ const ImageInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageInputFieldComponent);
|
||||
export default memo(ImageFieldInputComponent);
|
||||
|
||||
const UploadElement = memo(() => {
|
||||
const { t } = useTranslation();
|
@ -5,10 +5,10 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
LoRAModelInputFieldTemplate,
|
||||
LoRAModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
LoRAModelFieldInputTemplate,
|
||||
LoRAModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
@ -16,10 +16,10 @@ import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const LoRAModelInputFieldComponent = (
|
||||
const LoRAModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
LoRAModelInputFieldValue,
|
||||
LoRAModelInputFieldTemplate
|
||||
LoRAModelFieldInputInstance,
|
||||
LoRAModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -121,4 +121,4 @@ const LoRAModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoRAModelInputFieldComponent);
|
||||
export default memo(LoRAModelFieldInputComponent);
|
@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
MainModelInputFieldTemplate,
|
||||
MainModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
MainModelFieldInputTemplate,
|
||||
MainModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -21,10 +21,10 @@ import {
|
||||
} from 'services/api/endpoints/models';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const MainModelInputFieldComponent = (
|
||||
const MainModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
MainModelInputFieldValue,
|
||||
MainModelInputFieldTemplate
|
||||
MainModelFieldInputInstance,
|
||||
MainModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -149,4 +149,4 @@ const MainModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(MainModelInputFieldComponent);
|
||||
export default memo(MainModelFieldInputComponent);
|
@ -9,28 +9,18 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
FloatFieldInputInstance,
|
||||
FloatFieldInputTemplate,
|
||||
IntegerFieldInputInstance,
|
||||
IntegerFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
const NumberInputFieldComponent = (
|
||||
const NumberFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
| IntegerInputFieldValue
|
||||
| IntegerPolymorphicInputFieldValue
|
||||
| FloatInputFieldValue
|
||||
| FloatPolymorphicInputFieldValue,
|
||||
| IntegerInputFieldTemplate
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| FloatPolymorphicInputFieldTemplate
|
||||
IntegerFieldInputInstance | FloatFieldInputInstance,
|
||||
IntegerFieldInputTemplate | FloatFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
@ -39,7 +29,7 @@ const NumberInputFieldComponent = (
|
||||
String(field.value)
|
||||
);
|
||||
const isIntegerField = useMemo(
|
||||
() => fieldTemplate.type === 'integer',
|
||||
() => fieldTemplate.type.name === 'IntegerField',
|
||||
[fieldTemplate.type]
|
||||
);
|
||||
|
||||
@ -86,4 +76,4 @@ const NumberInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NumberInputFieldComponent);
|
||||
export default memo(NumberFieldInputComponent);
|
@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -18,10 +18,10 @@ import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const RefinerModelInputFieldComponent = (
|
||||
const RefinerModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
SDXLRefinerModelInputFieldTemplate
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
SDXLRefinerModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -120,4 +120,4 @@ const RefinerModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelInputFieldComponent);
|
||||
export default memo(RefinerModelFieldInputComponent);
|
@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLMainModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLMainModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -21,10 +21,10 @@ import {
|
||||
useGetOnnxModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
const SDXLMainModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
SDXLMainModelInputFieldValue,
|
||||
SDXLMainModelInputFieldTemplate
|
||||
SDXLMainModelFieldInputInstance,
|
||||
SDXLMainModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -147,4 +147,4 @@ const ModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelInputFieldComponent);
|
||||
export default memo(SDXLMainModelFieldInputComponent);
|
@ -5,14 +5,12 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
SchedulerInputFieldTemplate,
|
||||
SchedulerInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
SchedulerFieldInputTemplate,
|
||||
SchedulerFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
@ -24,7 +22,7 @@ const selector = createSelector(
|
||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||
value: name,
|
||||
label: label,
|
||||
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||
group: enabledSchedulers.includes(name as ParameterScheduler)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
@ -36,10 +34,10 @@ const selector = createSelector(
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const SchedulerInputField = (
|
||||
const SchedulerFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
SchedulerInputFieldValue,
|
||||
SchedulerInputFieldTemplate
|
||||
SchedulerFieldInputInstance,
|
||||
SchedulerFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -55,7 +53,7 @@ const SchedulerInputField = (
|
||||
fieldSchedulerValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: value as SchedulerParam,
|
||||
value: value as ParameterScheduler,
|
||||
})
|
||||
);
|
||||
},
|
||||
@ -72,4 +70,4 @@ const SchedulerInputField = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SchedulerInputField);
|
||||
export default memo(SchedulerFieldInputComponent);
|
@ -3,19 +3,14 @@ import IAIInput from 'common/components/IAIInput';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
StringInputFieldTemplate,
|
||||
StringInputFieldValue,
|
||||
FieldComponentProps,
|
||||
StringPolymorphicInputFieldValue,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
StringFieldInputInstance,
|
||||
StringFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const StringInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
StringInputFieldValue | StringPolymorphicInputFieldValue,
|
||||
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
|
||||
>
|
||||
const StringFieldInputComponent = (
|
||||
props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -48,4 +43,4 @@ const StringInputFieldComponent = (
|
||||
return <IAIInput onChange={handleValueChanged} value={field.value} />;
|
||||
};
|
||||
|
||||
export default memo(StringInputFieldComponent);
|
||||
export default memo(StringFieldInputComponent);
|
@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
T2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
T2IAdapterModelFieldInputInstance,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const T2IAdapterModelInputFieldComponent = (
|
||||
const T2IAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
T2IAdapterModelInputFieldValue,
|
||||
T2IAdapterModelInputFieldTemplate
|
||||
T2IAdapterModelFieldInputInstance,
|
||||
T2IAdapterModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -97,4 +97,4 @@ const T2IAdapterModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T2IAdapterModelInputFieldComponent);
|
||||
export default memo(T2IAdapterModelFieldInputComponent);
|
@ -4,20 +4,20 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
VAEModelFieldInputTemplate,
|
||||
VAEModelFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const VaeModelInputFieldComponent = (
|
||||
const VAEModelFieldInputComponent = (
|
||||
props: FieldComponentProps<
|
||||
VaeModelInputFieldValue,
|
||||
VaeModelInputFieldTemplate
|
||||
VAEModelFieldInputInstance,
|
||||
VAEModelFieldInputTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -105,4 +105,4 @@ const VaeModelInputFieldComponent = (
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VaeModelInputFieldComponent);
|
||||
export default memo(VAEModelFieldInputComponent);
|
@ -0,0 +1,13 @@
|
||||
import {
|
||||
FieldInputInstance,
|
||||
FieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
|
||||
export type FieldComponentProps<
|
||||
V extends FieldInputInstance,
|
||||
T extends FieldInputTemplate,
|
||||
> = {
|
||||
nodeId: string;
|
||||
field: V;
|
||||
fieldTemplate: T;
|
||||
};
|
@ -2,7 +2,7 @@ import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NotesNodeData } from 'features/nodes/types/types';
|
||||
import { NotesNodeData } from 'features/nodes/types/invocation';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import NodeWrapper from '../common/NodeWrapper';
|
||||
|
@ -14,7 +14,7 @@ import {
|
||||
DRAG_HANDLE_CLASSNAME,
|
||||
NODE_WIDTH,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { NodeStatus } from 'features/nodes/types/types';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import {
|
||||
MouseEvent,
|
||||
@ -40,7 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) =>
|
||||
nodes.nodeExecutionStates[nodeId]?.status === NodeStatus.IN_PROGRESS
|
||||
nodes.nodeExecutionStates[nodeId]?.status ===
|
||||
zNodeStatus.enum.IN_PROGRESS
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -8,7 +8,7 @@ import { FaUpload } from 'react-icons/fa';
|
||||
const LoadWorkflowButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const resetRef = useRef<() => void>(null);
|
||||
const loadWorkflowFromFile = useLoadWorkflowFromFile();
|
||||
const loadWorkflowFromFile = useLoadWorkflowFromFile(resetRef);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={resetRef}
|
||||
|
@ -1,31 +0,0 @@
|
||||
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import 'reactflow/dist/style.css';
|
||||
|
||||
const FieldTypeLegend = () => {
|
||||
return (
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
<Tooltip key={key} label={description}>
|
||||
<Badge
|
||||
sx={{
|
||||
userSelect: 'none',
|
||||
color:
|
||||
parseInt(color.split('.')[1] ?? '0', 10) < 500
|
||||
? 'base.800'
|
||||
: 'base.50',
|
||||
bg: color,
|
||||
}}
|
||||
textAlign="center"
|
||||
>
|
||||
{title}
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FieldTypeLegend);
|
@ -1,18 +1,11 @@
|
||||
import { Flex } from '@chakra-ui/layout';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import FieldTypeLegend from './FieldTypeLegend';
|
||||
import WorkflowEditorSettings from './WorkflowEditorSettings';
|
||||
|
||||
const TopRightPanel = () => {
|
||||
const shouldShowFieldTypeLegend = useAppSelector(
|
||||
(state) => state.nodes.shouldShowFieldTypeLegend
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex sx={{ gap: 2, position: 'absolute', top: 2, insetInlineEnd: 2 }}>
|
||||
<WorkflowEditorSettings />
|
||||
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -10,17 +10,15 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
isInvocationNode,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSync } from 'react-icons/fa';
|
||||
import { Node } from 'reactflow';
|
||||
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
|
||||
import ScrollableContent from '../ScrollableContent';
|
||||
@ -63,12 +61,17 @@ const InspectorDetailsTab = () => {
|
||||
|
||||
export default memo(InspectorDetailsTab);
|
||||
|
||||
const Content = (props: {
|
||||
type ContentProps = {
|
||||
node: Node<InvocationNodeData>;
|
||||
template: InvocationTemplate;
|
||||
}) => {
|
||||
};
|
||||
|
||||
const Content = memo(({ node, template }: ContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { needsUpdate, updateNode } = useNodeVersion(props.node.id);
|
||||
const needsUpdate = useMemo(
|
||||
() => getNeedsUpdate(node, template),
|
||||
[node, template]
|
||||
);
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@ -87,12 +90,12 @@ const Content = (props: {
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<EditableNodeTitle nodeId={props.node.data.id} />
|
||||
<EditableNodeTitle nodeId={node.data.id} />
|
||||
<HStack>
|
||||
<FormControl>
|
||||
<FormLabel>{t('nodes.nodeType')}</FormLabel>
|
||||
<Text fontSize="sm" fontWeight={600}>
|
||||
{props.template.title}
|
||||
{template.title}
|
||||
</Text>
|
||||
</FormControl>
|
||||
<Flex
|
||||
@ -104,22 +107,16 @@ const Content = (props: {
|
||||
<FormControl isInvalid={needsUpdate}>
|
||||
<FormLabel>{t('nodes.nodeVersion')}</FormLabel>
|
||||
<Text fontSize="sm" fontWeight={600}>
|
||||
{props.node.data.version}
|
||||
{node.data.version}
|
||||
</Text>
|
||||
</FormControl>
|
||||
{needsUpdate && (
|
||||
<IAIIconButton
|
||||
aria-label={t('nodes.updateNode')}
|
||||
tooltip={t('nodes.updateNode')}
|
||||
icon={<FaSync />}
|
||||
onClick={updateNode}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</HStack>
|
||||
<NotesTextarea nodeId={props.node.data.id} />
|
||||
<NotesTextarea nodeId={node.data.id} />
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
Content.displayName = 'Content';
|
||||
|
@ -5,7 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { ImageOutput } from 'services/api/types';
|
||||
import { AnyResult } from 'services/events/types';
|
||||
|
@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -28,8 +25,8 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) ||
|
||||
POLYMORPHIC_TYPES.includes(field.type)) &&
|
||||
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
field.type.isPolymorphic) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
},
|
||||
|
@ -3,10 +3,13 @@ import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { Node, useReactFlow } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { buildNodeData } from '../store/util/buildNodeData';
|
||||
import {
|
||||
buildCurrentImageNode,
|
||||
buildInvocationNode,
|
||||
buildNotesNode,
|
||||
} from '../store/util/buildNodeData';
|
||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants';
|
||||
|
||||
import { AnyNodeData, InvocationTemplate } from '../types/invocation';
|
||||
const templatesSelector = createSelector(
|
||||
[(state: RootState) => state.nodes],
|
||||
(nodes) => nodes.nodeTemplates
|
||||
@ -22,7 +25,8 @@ export const useBuildNodeData = () => {
|
||||
const flow = useReactFlow();
|
||||
|
||||
return useCallback(
|
||||
(type: AnyInvocationType | 'current_image' | 'notes') => {
|
||||
// string here is "any invocation type"
|
||||
(type: string | 'current_image' | 'notes'): Node<AnyNodeData> => {
|
||||
let _x = window.innerWidth / 2;
|
||||
let _y = window.innerHeight / 2;
|
||||
|
||||
@ -41,9 +45,19 @@ export const useBuildNodeData = () => {
|
||||
y: _y,
|
||||
});
|
||||
|
||||
const template = nodeTemplates[type];
|
||||
if (type === 'current_image') {
|
||||
return buildCurrentImageNode(position);
|
||||
}
|
||||
|
||||
return buildNodeData(type, position, template);
|
||||
if (type === 'notes') {
|
||||
return buildNotesNode(position);
|
||||
}
|
||||
|
||||
// TODO: Keep track of invocation types so we do not need to cast this
|
||||
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
|
||||
const template = nodeTemplates[type] as InvocationTemplate;
|
||||
|
||||
return buildInvocationNode(position, template);
|
||||
},
|
||||
[nodeTemplates, flow]
|
||||
);
|
||||
|
@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -29,9 +26,8 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
// get the visible fields
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' &&
|
||||
!POLYMORPHIC_TYPES.includes(field.type)) ||
|
||||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
(field.input === 'connection' && !field.type.isPolymorphic) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
|
@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
|
||||
const selectIsConnectionInProgress = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) =>
|
||||
nodes.currentConnectionFieldType !== null &&
|
||||
nodes.connectionStartFieldType !== null &&
|
||||
nodes.connectionStartParams !== null
|
||||
);
|
||||
|
||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useEmbedWorkflow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldData = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldInstance = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldLabel = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { KIND_MAP } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldTemplate = (
|
||||
nodeId: string,
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { KIND_MAP } from '../types/constants';
|
||||
|
||||
export const useFieldTemplateTitle = (
|
||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { KIND_MAP } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useFieldType = (
|
||||
nodeId: string,
|
||||
@ -20,7 +20,8 @@ export const useFieldType = (
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data[KIND_MAP[kind]][fieldName]?.type;
|
||||
const field = node.data[KIND_MAP[kind]][fieldName];
|
||||
return field?.type;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
@ -2,7 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { getNeedsUpdate } from './useNodeVersion';
|
||||
import { getNeedsUpdate } from '../store/util/nodeUpdate';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@ -10,8 +11,11 @@ const selector = createSelector(
|
||||
const nodes = state.nodes.nodes;
|
||||
const templates = state.nodes.nodeTemplates;
|
||||
|
||||
const needsUpdate = nodes.some((node) => {
|
||||
const needsUpdate = nodes.filter(isInvocationNode).some((node) => {
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return false;
|
||||
}
|
||||
return getNeedsUpdate(node, template);
|
||||
});
|
||||
return needsUpdate;
|
||||
|
@ -4,8 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { IMAGE_FIELDS } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useHasImageOutput = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -20,8 +19,8 @@ export const useHasImageOutput = (nodeId: string) => {
|
||||
return some(
|
||||
node.data.outputs,
|
||||
(output) =>
|
||||
IMAGE_FIELDS.includes(output.type) &&
|
||||
// the image primitive node does not actually save the image, do not show the image-saving checkboxes
|
||||
output.type.name === 'ImageField' &&
|
||||
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
|
||||
node.data.type !== 'image'
|
||||
);
|
||||
},
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useIsIntermediate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -4,7 +4,7 @@ import { useCallback } from 'react';
|
||||
import { Connection, Node, useReactFlow } from 'reactflow';
|
||||
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
|
||||
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
|
||||
import { InvocationNodeData } from '../types/types';
|
||||
import { InvocationNodeData } from '../types/invocation';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
|
||||
const targetType = targetNode.data.inputs[targetHandle]?.type;
|
||||
const sourceField = sourceNode.data.outputs[sourceHandle];
|
||||
const targetField = targetNode.data.inputs[targetHandle];
|
||||
|
||||
if (!sourceType || !targetType) {
|
||||
if (!sourceField || !targetField) {
|
||||
// something has gone terribly awry
|
||||
return false;
|
||||
}
|
||||
@ -70,12 +70,13 @@ export const useIsValidConnection = () => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
targetField.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
||||
// Must use the originalType here if it exists
|
||||
if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1,17 +1,15 @@
|
||||
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { zWorkflow } from 'features/nodes/types/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
import { RefObject, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
|
||||
export const useLoadWorkflowFromFile = () => {
|
||||
export const useLoadWorkflowFromFile = (resetRef: RefObject<() => void>) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const logger = useLogger('nodes');
|
||||
const { t } = useTranslation();
|
||||
@ -26,33 +24,10 @@ export const useLoadWorkflowFromFile = () => {
|
||||
|
||||
try {
|
||||
const parsedJSON = JSON.parse(String(rawJSON));
|
||||
const result = zWorkflow.safeParse(parsedJSON);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: t('nodes.workflowValidation'),
|
||||
});
|
||||
|
||||
logger.error({ error: parseify(result.error) }, message);
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('nodes.unableToValidateWorkflow'),
|
||||
status: 'error',
|
||||
duration: 5000,
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(workflowLoadRequested(result.data));
|
||||
|
||||
reader.abort();
|
||||
} catch {
|
||||
// file reader error
|
||||
dispatch(workflowLoadRequested(parsedJSON));
|
||||
} catch (e) {
|
||||
// There was a problem reading the file
|
||||
logger.error(t('nodes.unableToLoadWorkflow'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
@ -61,12 +36,15 @@ export const useLoadWorkflowFromFile = () => {
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
}
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
// Reset the file picker internal state so that the same file can be loaded again
|
||||
resetRef.current?.();
|
||||
},
|
||||
[dispatch, logger, t]
|
||||
[dispatch, logger, resetRef, t]
|
||||
);
|
||||
|
||||
return loadWorkflowFromFile;
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useNodeLabel = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -0,0 +1,35 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getNeedsUpdate } from '../store/util/nodeUpdate';
|
||||
|
||||
export const useNodeNeedsUpdate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const template = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
return { node, template };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const { node, template } = useAppSelector(selector);
|
||||
|
||||
const needsUpdate = useMemo(
|
||||
() =>
|
||||
isInvocationNode(node) && template
|
||||
? getNeedsUpdate(node, template)
|
||||
: false,
|
||||
[node, template]
|
||||
);
|
||||
|
||||
return needsUpdate;
|
||||
};
|
@ -3,16 +3,14 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { InvocationTemplate } from '../types/invocation';
|
||||
|
||||
export const useNodeTemplateByType = (
|
||||
type: AnyInvocationType | 'current_image' | 'notes'
|
||||
) => {
|
||||
export const useNodeTemplateByType = (type: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
({ nodes }): InvocationTemplate | undefined => {
|
||||
const nodeTemplate = nodes.nodeTemplates[type];
|
||||
return nodeTemplate;
|
||||
},
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -1,119 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { satisfies } from 'compare-versions';
|
||||
import { cloneDeep, defaultsDeep } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { Node } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { nodeReplaced } from '../store/nodesSlice';
|
||||
import { buildNodeData } from '../store/util/buildNodeData';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
NodeData,
|
||||
isInvocationNode,
|
||||
zParsedSemver,
|
||||
} from '../types/types';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const getNeedsUpdate = (
|
||||
node?: Node<NodeData>,
|
||||
template?: InvocationTemplate
|
||||
) => {
|
||||
if (!isInvocationNode(node) || !template) {
|
||||
return false;
|
||||
}
|
||||
return node.data.version !== template.version;
|
||||
};
|
||||
|
||||
export const getMayUpdateNode = (
|
||||
node?: Node<NodeData>,
|
||||
template?: InvocationTemplate
|
||||
) => {
|
||||
const needsUpdate = getNeedsUpdate(node, template);
|
||||
if (
|
||||
!needsUpdate ||
|
||||
!isInvocationNode(node) ||
|
||||
!template ||
|
||||
!node.data.version
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
const templateMajor = zParsedSemver.parse(template.version).major;
|
||||
|
||||
return satisfies(node.data.version, `^${templateMajor}`);
|
||||
};
|
||||
|
||||
export const updateNode = (
|
||||
node?: Node<NodeData>,
|
||||
template?: InvocationTemplate
|
||||
) => {
|
||||
const mayUpdate = getMayUpdateNode(node, template);
|
||||
if (
|
||||
!mayUpdate ||
|
||||
!isInvocationNode(node) ||
|
||||
!template ||
|
||||
!node.data.version
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const defaults = buildNodeData(
|
||||
node.data.type as AnyInvocationType,
|
||||
node.position,
|
||||
template
|
||||
) as Node<InvocationNodeData>;
|
||||
|
||||
const clone = cloneDeep(node);
|
||||
clone.data.version = template.version;
|
||||
defaultsDeep(clone, defaults);
|
||||
return clone;
|
||||
};
|
||||
|
||||
export const useNodeVersion = (nodeId: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toast = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
return { node, nodeTemplate };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const { node, nodeTemplate } = useAppSelector(selector);
|
||||
|
||||
const needsUpdate = useMemo(
|
||||
() => getNeedsUpdate(node, nodeTemplate),
|
||||
[node, nodeTemplate]
|
||||
);
|
||||
|
||||
const mayUpdate = useMemo(
|
||||
() => getMayUpdateNode(node, nodeTemplate),
|
||||
[node, nodeTemplate]
|
||||
);
|
||||
|
||||
const _updateNode = useCallback(() => {
|
||||
const needsUpdate = getNeedsUpdate(node, nodeTemplate);
|
||||
const updatedNode = updateNode(node, nodeTemplate);
|
||||
if (!updatedNode) {
|
||||
if (needsUpdate) {
|
||||
toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) });
|
||||
}
|
||||
return;
|
||||
}
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
}, [dispatch, node, nodeTemplate, t, toast]);
|
||||
|
||||
return { needsUpdate, mayUpdate, updateNode: _updateNode };
|
||||
};
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';
|
||||
|
||||
export const useOutputFieldNames = (nodeId: string) => {
|
||||
|
@ -0,0 +1,23 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FieldType } from '../types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTypeName = (fieldType?: FieldType): string => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const name = useMemo(() => {
|
||||
if (!fieldType) {
|
||||
return '';
|
||||
}
|
||||
const { name } = fieldType;
|
||||
if (fieldType.isCollection) {
|
||||
return t('nodes.collectionFieldType', { name });
|
||||
}
|
||||
if (fieldType.isPolymorphic) {
|
||||
return t('nodes.polymorphicFieldType', { name });
|
||||
}
|
||||
return name;
|
||||
}, [fieldType, t]);
|
||||
|
||||
return name;
|
||||
};
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useUseCache = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import { isInvocationNode } from '../types/invocation';
|
||||
|
||||
export const useWithWorkflow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { Workflow } from '../types/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>(
|
||||
'nodes/textToImageGraphBuilt'
|
||||
@ -18,7 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
|
||||
nodesGraphBuilt
|
||||
);
|
||||
|
||||
export const workflowLoadRequested = createAction<Workflow>(
|
||||
export const workflowLoadRequested = createAction<unknown>(
|
||||
'nodes/workflowLoadRequested'
|
||||
);
|
||||
|
||||
|
@ -6,7 +6,7 @@ import { NodesState } from './types';
|
||||
export const nodesPersistDenylist: (keyof NodesState)[] = [
|
||||
'nodeTemplates',
|
||||
'connectionStartParams',
|
||||
'currentConnectionFieldType',
|
||||
'connectionStartFieldType',
|
||||
'selectedNodes',
|
||||
'selectedEdges',
|
||||
'isReady',
|
||||
|
@ -20,7 +20,6 @@ import {
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { ImageField } from 'services/api/types';
|
||||
import {
|
||||
appSocketGeneratorProgress,
|
||||
appSocketInvocationComplete,
|
||||
@ -31,60 +30,58 @@ import {
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { DRAG_HANDLE_CLASSNAME } from '../types/constants';
|
||||
import {
|
||||
BoardInputFieldValue,
|
||||
BooleanInputFieldValue,
|
||||
ColorInputFieldValue,
|
||||
ControlNetModelInputFieldValue,
|
||||
CurrentImageNodeData,
|
||||
EnumInputFieldValue,
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
FieldIdentifier,
|
||||
FloatInputFieldValue,
|
||||
ImageInputFieldValue,
|
||||
InputFieldValue,
|
||||
IntegerInputFieldValue,
|
||||
InvocationNodeData,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
SchedulerFieldValue,
|
||||
SDXLRefinerModelFieldValue,
|
||||
StringFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
VAEModelFieldValue,
|
||||
} from '../types/field';
|
||||
import {
|
||||
AnyNodeData,
|
||||
InvocationTemplate,
|
||||
IPAdapterModelInputFieldValue,
|
||||
isInvocationNode,
|
||||
isNotesNode,
|
||||
LoRAModelInputFieldValue,
|
||||
MainModelInputFieldValue,
|
||||
NodeExecutionState,
|
||||
NodeStatus,
|
||||
NotesNodeData,
|
||||
SchedulerInputFieldValue,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
StringInputFieldValue,
|
||||
T2IAdapterModelInputFieldValue,
|
||||
VaeModelInputFieldValue,
|
||||
Workflow,
|
||||
} from '../types/types';
|
||||
zNodeStatus,
|
||||
} from '../types/invocation';
|
||||
import { WorkflowV2 } from '../types/workflow';
|
||||
import { NodesState } from './types';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||
|
||||
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
status: NodeStatus.PENDING,
|
||||
status: zNodeStatus.enum.PENDING,
|
||||
error: null,
|
||||
progress: null,
|
||||
progressImage: null,
|
||||
outputs: [],
|
||||
};
|
||||
|
||||
export const initialWorkflow = {
|
||||
meta: {
|
||||
version: WORKFLOW_FORMAT_VERSION,
|
||||
},
|
||||
const INITIAL_WORKFLOW: WorkflowV2 = {
|
||||
name: '',
|
||||
author: '',
|
||||
description: '',
|
||||
notes: '',
|
||||
tags: '',
|
||||
contact: '',
|
||||
version: '',
|
||||
contact: '',
|
||||
tags: '',
|
||||
notes: '',
|
||||
nodes: [],
|
||||
edges: [],
|
||||
exposedFields: [],
|
||||
meta: { version: '2.0.0' },
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
@ -93,11 +90,10 @@ export const initialNodesState: NodesState = {
|
||||
nodeTemplates: {},
|
||||
isReady: false,
|
||||
connectionStartParams: null,
|
||||
currentConnectionFieldType: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
modifyingEdge: false,
|
||||
addNewNodePosition: null,
|
||||
shouldShowFieldTypeLegend: false,
|
||||
shouldShowMinimapPanel: true,
|
||||
shouldValidateGraph: true,
|
||||
shouldAnimateEdges: true,
|
||||
@ -107,7 +103,7 @@ export const initialNodesState: NodesState = {
|
||||
nodeOpacity: 1,
|
||||
selectedNodes: [],
|
||||
selectedEdges: [],
|
||||
workflow: initialWorkflow,
|
||||
workflow: INITIAL_WORKFLOW,
|
||||
nodeExecutionStates: {},
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
mouseOverField: null,
|
||||
@ -117,13 +113,13 @@ export const initialNodesState: NodesState = {
|
||||
selectionMode: SelectionMode.Partial,
|
||||
};
|
||||
|
||||
type FieldValueAction<T extends InputFieldValue> = PayloadAction<{
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: T['value'];
|
||||
value: T;
|
||||
}>;
|
||||
|
||||
const fieldValueReducer = <T extends InputFieldValue>(
|
||||
const fieldValueReducer = <T extends FieldValue>(
|
||||
state: NodesState,
|
||||
action: FieldValueAction<T>
|
||||
) => {
|
||||
@ -161,12 +157,7 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
state.nodes[nodeIndex] = action.payload.node;
|
||||
},
|
||||
nodeAdded: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
Node<InvocationNodeData | CurrentImageNodeData | NotesNodeData>
|
||||
>
|
||||
) => {
|
||||
nodeAdded: (state, action: PayloadAction<Node<AnyNodeData>>) => {
|
||||
const node = action.payload;
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
@ -203,7 +194,7 @@ const nodesSlice = createSlice({
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
node,
|
||||
@ -212,7 +203,7 @@ const nodesSlice = createSlice({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
@ -224,7 +215,7 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
edgeChangeStarted: (state) => {
|
||||
state.modifyingEdge = true;
|
||||
@ -258,10 +249,10 @@ const nodesSlice = createSlice({
|
||||
handleType === 'source'
|
||||
? node.data.outputs[handleId]
|
||||
: node.data.inputs[handleId];
|
||||
state.currentConnectionFieldType = field?.type ?? null;
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
const fieldType = state.currentConnectionFieldType;
|
||||
const fieldType = state.connectionStartFieldType;
|
||||
if (!fieldType) {
|
||||
return;
|
||||
}
|
||||
@ -286,7 +277,7 @@ const nodesSlice = createSlice({
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
mouseOverNode,
|
||||
@ -295,7 +286,7 @@ const nodesSlice = createSlice({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
@ -306,14 +297,14 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
}
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
} else {
|
||||
state.addNewNodePosition = action.payload.cursorPosition;
|
||||
state.isAddNodePopoverOpen = true;
|
||||
}
|
||||
} else {
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
}
|
||||
state.modifyingEdge = false;
|
||||
},
|
||||
@ -529,12 +520,7 @@ const nodesSlice = createSlice({
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
}
|
||||
},
|
||||
nodesDeleted: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
Node<InvocationNodeData | NotesNodeData | CurrentImageNodeData>[]
|
||||
>
|
||||
) => {
|
||||
nodesDeleted: (state, action: PayloadAction<Node<AnyNodeData>[]>) => {
|
||||
action.payload.forEach((node) => {
|
||||
state.workflow.exposedFields = state.workflow.exposedFields.filter(
|
||||
(f) => f.nodeId !== node.id
|
||||
@ -588,132 +574,94 @@ const nodesSlice = createSlice({
|
||||
},
|
||||
fieldStringValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<StringInputFieldValue>
|
||||
action: FieldValueAction<StringFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldNumberValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<IntegerInputFieldValue | FloatInputFieldValue>
|
||||
action: FieldValueAction<IntegerFieldValue | FloatFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldBooleanValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<BooleanInputFieldValue>
|
||||
action: FieldValueAction<BooleanFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldBoardValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<BoardInputFieldValue>
|
||||
action: FieldValueAction<BoardFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldImageValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<ImageInputFieldValue>
|
||||
action: FieldValueAction<ImageFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldColorValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<ColorInputFieldValue>
|
||||
action: FieldValueAction<ColorFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldMainModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<MainModelInputFieldValue>
|
||||
action: FieldValueAction<MainModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldRefinerModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SDXLRefinerModelInputFieldValue>
|
||||
action: FieldValueAction<SDXLRefinerModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldVaeModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<VaeModelInputFieldValue>
|
||||
action: FieldValueAction<VAEModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldLoRAModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<LoRAModelInputFieldValue>
|
||||
action: FieldValueAction<LoRAModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldControlNetModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<ControlNetModelInputFieldValue>
|
||||
action: FieldValueAction<ControlNetModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldIPAdapterModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<IPAdapterModelInputFieldValue>
|
||||
action: FieldValueAction<IPAdapterModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldT2IAdapterModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<T2IAdapterModelInputFieldValue>
|
||||
action: FieldValueAction<T2IAdapterModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldEnumModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<EnumInputFieldValue>
|
||||
action: FieldValueAction<EnumFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldSchedulerValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SchedulerInputFieldValue>
|
||||
action: FieldValueAction<SchedulerFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
imageCollectionFieldValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: ImageField[];
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
if (nodeIndex === -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const input = node.data?.inputs[fieldName];
|
||||
if (!input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentValue = cloneDeep(input.value);
|
||||
|
||||
if (!currentValue) {
|
||||
input.value = value;
|
||||
return;
|
||||
}
|
||||
|
||||
input.value = uniqBy(
|
||||
(currentValue as ImageField[]).concat(value),
|
||||
'image_name'
|
||||
);
|
||||
},
|
||||
notesNodeValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; value: string }>
|
||||
@ -726,12 +674,6 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = value;
|
||||
},
|
||||
shouldShowFieldTypeLegendChanged: (
|
||||
state,
|
||||
action: PayloadAction<boolean>
|
||||
) => {
|
||||
state.shouldShowFieldTypeLegend = action.payload;
|
||||
},
|
||||
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowMinimapPanel = action.payload;
|
||||
},
|
||||
@ -745,7 +687,7 @@ const nodesSlice = createSlice({
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
state.workflow = cloneDeep(initialWorkflow);
|
||||
state.workflow = cloneDeep(INITIAL_WORKFLOW);
|
||||
},
|
||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldValidateGraph = action.payload;
|
||||
@ -783,7 +725,7 @@ const nodesSlice = createSlice({
|
||||
workflowContactChanged: (state, action: PayloadAction<string>) => {
|
||||
state.workflow.contact = action.payload;
|
||||
},
|
||||
workflowLoaded: (state, action: PayloadAction<Workflow>) => {
|
||||
workflowLoaded: (state, action: PayloadAction<WorkflowV2>) => {
|
||||
const { nodes, edges, ...workflow } = action.payload;
|
||||
state.workflow = workflow;
|
||||
|
||||
@ -810,7 +752,7 @@ const nodesSlice = createSlice({
|
||||
}, {});
|
||||
},
|
||||
workflowReset: (state) => {
|
||||
state.workflow = cloneDeep(initialWorkflow);
|
||||
state.workflow = cloneDeep(INITIAL_WORKFLOW);
|
||||
},
|
||||
viewportChanged: (state, action: PayloadAction<Viewport>) => {
|
||||
state.viewport = action.payload;
|
||||
@ -942,7 +884,7 @@ const nodesSlice = createSlice({
|
||||
|
||||
//Make sure these get reset if we close the popover and haven't selected a node
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
addNodePopoverToggled: (state) => {
|
||||
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
|
||||
@ -961,14 +903,14 @@ const nodesSlice = createSlice({
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = NodeStatus.IN_PROGRESS;
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
}
|
||||
});
|
||||
builder.addCase(appSocketInvocationComplete, (state, action) => {
|
||||
const { source_node_id, result } = action.payload.data;
|
||||
const nes = state.nodeExecutionStates[source_node_id];
|
||||
if (nes) {
|
||||
nes.status = NodeStatus.COMPLETED;
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
@ -979,7 +921,7 @@ const nodesSlice = createSlice({
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = NodeStatus.FAILED;
|
||||
node.status = zNodeStatus.enum.FAILED;
|
||||
node.error = action.payload.data.error;
|
||||
node.progress = null;
|
||||
node.progressImage = null;
|
||||
@ -990,7 +932,7 @@ const nodesSlice = createSlice({
|
||||
action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = NodeStatus.IN_PROGRESS;
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
node.progress = (step + 1) / total_steps;
|
||||
node.progressImage = progress_image ?? null;
|
||||
}
|
||||
@ -998,7 +940,7 @@ const nodesSlice = createSlice({
|
||||
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach(state.nodeExecutionStates, (nes) => {
|
||||
nes.status = NodeStatus.PENDING;
|
||||
nes.status = zNodeStatus.enum.PENDING;
|
||||
nes.error = null;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
@ -1037,7 +979,6 @@ export const {
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
imageCollectionFieldValueChanged,
|
||||
mouseOverFieldChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeAdded,
|
||||
@ -1063,7 +1004,6 @@ export const {
|
||||
selectionPasted,
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
|
@ -6,25 +6,23 @@ import {
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { FieldIdentifier, FieldType } from '../types/field';
|
||||
import {
|
||||
FieldIdentifier,
|
||||
FieldType,
|
||||
AnyNodeData,
|
||||
InvocationEdgeExtra,
|
||||
InvocationTemplate,
|
||||
NodeData,
|
||||
NodeExecutionState,
|
||||
Workflow,
|
||||
} from '../types/types';
|
||||
} from '../types/invocation';
|
||||
import { WorkflowV2 } from '../types/workflow';
|
||||
|
||||
export type NodesState = {
|
||||
nodes: Node<NodeData>[];
|
||||
nodes: Node<AnyNodeData>[];
|
||||
edges: Edge<InvocationEdgeExtra>[];
|
||||
nodeTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
currentConnectionFieldType: FieldType | null;
|
||||
connectionStartFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
modifyingEdge: boolean;
|
||||
shouldShowFieldTypeLegend: boolean;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
shouldValidateGraph: boolean;
|
||||
shouldAnimateEdges: boolean;
|
||||
@ -33,13 +31,13 @@ export type NodesState = {
|
||||
shouldColorEdges: boolean;
|
||||
selectedNodes: string[];
|
||||
selectedEdges: string[];
|
||||
workflow: Omit<Workflow, 'nodes' | 'edges'>;
|
||||
workflow: Omit<WorkflowV2, 'nodes' | 'edges'>;
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
viewport: Viewport;
|
||||
isReady: boolean;
|
||||
mouseOverField: FieldIdentifier | null;
|
||||
mouseOverNode: string | null;
|
||||
nodesToCopy: Node<NodeData>[];
|
||||
nodesToCopy: Node<AnyNodeData>[];
|
||||
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
||||
isAddNodePopoverOpen: boolean;
|
||||
addNewNodePosition: XYPosition | null;
|
||||
|
@ -1,78 +1,73 @@
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import {
|
||||
FieldInputInstance,
|
||||
FieldOutputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
CurrentImageNodeData,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
NotesNodeData,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders';
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance';
|
||||
import { reduce } from 'lodash-es';
|
||||
import { Node, XYPosition } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||
};
|
||||
export const buildNodeData = (
|
||||
type: AnyInvocationType | 'current_image' | 'notes',
|
||||
position: XYPosition,
|
||||
template?: InvocationTemplate
|
||||
):
|
||||
| Node<CurrentImageNodeData>
|
||||
| Node<NotesNodeData>
|
||||
| Node<InvocationNodeData>
|
||||
| undefined => {
|
||||
const nodeId = uuidv4();
|
||||
|
||||
if (type === 'current_image') {
|
||||
const node: Node<CurrentImageNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
export const buildNotesNode = (position: XYPosition): Node<NotesNodeData> => {
|
||||
const nodeId = uuidv4();
|
||||
const node: Node<NotesNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'notes',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
isOpen: true,
|
||||
label: 'Notes',
|
||||
notes: '',
|
||||
type: 'notes',
|
||||
},
|
||||
};
|
||||
return node;
|
||||
};
|
||||
|
||||
export const buildCurrentImageNode = (
|
||||
position: XYPosition
|
||||
): Node<CurrentImageNodeData> => {
|
||||
const nodeId = uuidv4();
|
||||
const node: Node<CurrentImageNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
type: 'current_image',
|
||||
isOpen: true,
|
||||
label: 'Current Image',
|
||||
},
|
||||
};
|
||||
isOpen: true,
|
||||
label: 'Current Image',
|
||||
},
|
||||
};
|
||||
return node;
|
||||
};
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
if (type === 'notes') {
|
||||
const node: Node<NotesNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'notes',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
isOpen: true,
|
||||
label: 'Notes',
|
||||
notes: '',
|
||||
type: 'notes',
|
||||
},
|
||||
};
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
if (template === undefined) {
|
||||
console.error(`Unable to find template ${type}.`);
|
||||
return;
|
||||
}
|
||||
export const buildInvocationNode = (
|
||||
position: XYPosition,
|
||||
template: InvocationTemplate
|
||||
): Node<InvocationNodeData> => {
|
||||
const nodeId = uuidv4();
|
||||
const { type } = template;
|
||||
|
||||
const inputs = reduce(
|
||||
template.inputs,
|
||||
(inputsAccumulator, inputTemplate, inputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const inputFieldValue: InputFieldValue = buildInputFieldValue(
|
||||
const inputFieldValue: FieldInputInstance = buildFieldInputInstance(
|
||||
fieldId,
|
||||
inputTemplate
|
||||
);
|
||||
@ -81,7 +76,7 @@ export const buildNodeData = (
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldValue>
|
||||
{} as Record<string, FieldInputInstance>
|
||||
);
|
||||
|
||||
const outputs = reduce(
|
||||
@ -89,7 +84,7 @@ export const buildNodeData = (
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const outputFieldValue: OutputFieldValue = {
|
||||
const outputFieldValue: FieldOutputInstance = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
@ -100,10 +95,10 @@ export const buildNodeData = (
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldValue>
|
||||
{} as Record<string, FieldOutputInstance>
|
||||
);
|
||||
|
||||
const invocation: Node<InvocationNodeData> = {
|
||||
const node: Node<InvocationNodeData> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'invocation',
|
||||
@ -117,11 +112,11 @@ export const buildNodeData = (
|
||||
isOpen: true,
|
||||
embedWorkflow: false,
|
||||
isIntermediate: type === 'save_image' ? false : true,
|
||||
useCache: template.useCache,
|
||||
inputs,
|
||||
outputs,
|
||||
useCache: template.useCache,
|
||||
},
|
||||
};
|
||||
|
||||
return invocation;
|
||||
return node;
|
||||
};
|
||||
|
@ -1,20 +1,19 @@
|
||||
import { Connection, HandleType } from 'reactflow';
|
||||
import { Node, Edge } from 'reactflow';
|
||||
import {
|
||||
FieldType,
|
||||
InputFieldValue,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
import {
|
||||
FieldInputInstance,
|
||||
FieldOutputInstance,
|
||||
FieldType,
|
||||
} from 'features/nodes/types/field';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
const isValidConnection = (
|
||||
edges: Edge[],
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
node: Node,
|
||||
handle: InputFieldValue | OutputFieldValue
|
||||
handle: FieldInputInstance | FieldOutputInstance
|
||||
) => {
|
||||
let isValidConnection = true;
|
||||
if (handleCurrentType === 'source') {
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import { FieldType } from 'features/nodes/types/field';
|
||||
import i18n from 'i18next';
|
||||
import { HandleType } from 'reactflow';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
/**
|
||||
@ -17,15 +17,15 @@ export const makeConnectionErrorSelector = (
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType
|
||||
) => {
|
||||
return createSelector(stateSelector, (state) => {
|
||||
return createSelector(stateSelector, (state): string | undefined => {
|
||||
if (!fieldType) {
|
||||
return i18n.t('nodes.noFieldType');
|
||||
}
|
||||
|
||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||
const { connectionStartFieldType, connectionStartParams, nodes, edges } =
|
||||
state.nodes;
|
||||
|
||||
if (!connectionStartParams || !currentConnectionFieldType) {
|
||||
if (!connectionStartParams || !connectionStartFieldType) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
const targetType =
|
||||
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
||||
handleType === 'target' ? fieldType : connectionStartFieldType;
|
||||
const sourceType =
|
||||
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
||||
handleType === 'source' ? fieldType : connectionStartFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
return i18n.t('nodes.cannotConnectToSelf');
|
||||
@ -80,7 +80,7 @@ export const makeConnectionErrorSelector = (
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
targetType.name !== 'CollectionItemField'
|
||||
) {
|
||||
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
@ -100,6 +100,6 @@ export const makeConnectionErrorSelector = (
|
||||
return i18n.t('nodes.connectionWouldCreateCycle');
|
||||
}
|
||||
|
||||
return null;
|
||||
return;
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,68 @@
|
||||
import { satisfies } from 'compare-versions';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import { zParsedSemver } from 'features/nodes/types/semver';
|
||||
import { cloneDeep, defaultsDeep } from 'lodash-es';
|
||||
import { Node } from 'reactflow';
|
||||
import { buildInvocationNode } from './buildNodeData';
|
||||
|
||||
export const getNeedsUpdate = (
|
||||
node: Node<InvocationNodeData>,
|
||||
template: InvocationTemplate
|
||||
): boolean => {
|
||||
if (node.data.type !== template.type) {
|
||||
return true;
|
||||
}
|
||||
return node.data.version !== template.version;
|
||||
}; /**
|
||||
* Checks if a node may be updated by comparing its major version with the template's major version.
|
||||
* @param node The node to check.
|
||||
* @param template The invocation template to check against.
|
||||
*/
|
||||
|
||||
export const getMayUpdateNode = (
|
||||
node: Node<InvocationNodeData>,
|
||||
template: InvocationTemplate
|
||||
): boolean => {
|
||||
const needsUpdate = getNeedsUpdate(node, template);
|
||||
if (!needsUpdate || node.data.type !== template.type) {
|
||||
return false;
|
||||
}
|
||||
const templateMajor = zParsedSemver.parse(template.version).major;
|
||||
|
||||
return satisfies(node.data.version, `^${templateMajor}`);
|
||||
}; /**
|
||||
* Updates a node to the latest version of its template:
|
||||
* - Create a new node data object with the latest version of the template.
|
||||
* - Recursively merge new node data object into the node to be updated.
|
||||
*
|
||||
* @param node The node to updated.
|
||||
* @param template The invocation template to update to.
|
||||
* @throws {NodeUpdateError} If the node is not an invocation node.
|
||||
*/
|
||||
|
||||
export const updateNode = (
|
||||
node: Node<InvocationNodeData>,
|
||||
template: InvocationTemplate
|
||||
): Node<InvocationNodeData> => {
|
||||
const mayUpdate = getMayUpdateNode(node, template);
|
||||
|
||||
if (!mayUpdate || node.data.type !== template.type) {
|
||||
throw new NodeUpdateError(`Unable to update node ${node.id}`);
|
||||
}
|
||||
|
||||
// Start with a "fresh" node - just as if the user created a new node of this type
|
||||
const defaults = buildInvocationNode(node.position, template);
|
||||
|
||||
// The updateability of a node, via semver comparison, relies on the this kind of recursive merge
|
||||
// being valid. We rely on the template's major version to be majorly incremented if this kind of
|
||||
// merge would result in an invalid node.
|
||||
const clone = cloneDeep(node);
|
||||
clone.data.version = template.version;
|
||||
defaultsDeep(clone, defaults); // mutates!
|
||||
|
||||
return clone;
|
||||
};
|
@ -1,11 +1,12 @@
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import { FieldType } from 'features/nodes/types/field';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
/**
|
||||
* Validates that the source and target types are compatible for a connection.
|
||||
* @param sourceType The type of the source field.
|
||||
* @param targetType The type of the target field.
|
||||
* @returns True if the connection is valid, false otherwise.
|
||||
*/
|
||||
export const validateSourceAndTargetTypes = (
|
||||
sourceType: FieldType,
|
||||
targetType: FieldType
|
||||
@ -13,11 +14,14 @@ export const validateSourceAndTargetTypes = (
|
||||
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||
// Once this is resolved, we can remove this check.
|
||||
if (sourceType === 'Collection' && targetType === 'Collection') {
|
||||
if (
|
||||
sourceType.name === 'CollectionField' &&
|
||||
targetType.name === 'CollectionField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sourceType === targetType) {
|
||||
if (isEqual(sourceType, targetType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -31,46 +35,42 @@ export const validateSourceAndTargetTypes = (
|
||||
*/
|
||||
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);
|
||||
sourceType.name === 'CollectionItemField' && !targetType.isCollection;
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
targetType.name === 'CollectionItemField' &&
|
||||
!sourceType.isCollection &&
|
||||
!sourceType.isPolymorphic;
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
targetType.isPolymorphic && sourceType.name === targetType.name;
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
sourceType.name === 'CollectionField' &&
|
||||
(targetType.isCollection || targetType.isPolymorphic);
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
targetType.name === 'CollectionField' && sourceType.isCollection;
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
const areBothTypesSingle =
|
||||
!sourceType.isCollection &&
|
||||
!sourceType.isPolymorphic &&
|
||||
!targetType.isCollection &&
|
||||
!targetType.isPolymorphic;
|
||||
|
||||
const isIntToFloat =
|
||||
areBothTypesSingle &&
|
||||
sourceType.name === 'IntegerField' &&
|
||||
targetType.name === 'FloatField';
|
||||
|
||||
const isIntOrFloatToString =
|
||||
(sourceType === 'integer' || sourceType === 'float') &&
|
||||
targetType === 'string';
|
||||
areBothTypesSingle &&
|
||||
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
|
||||
targetType.name === 'StringField';
|
||||
|
||||
const isTargetAnyType = targetType === 'Any';
|
||||
const isTargetAnyType = targetType.name === 'AnyField';
|
||||
|
||||
// One of these must be true for the connection to be valid
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
|
216
invokeai/frontend/web/src/features/nodes/types/common.ts
Normal file
216
invokeai/frontend/web/src/features/nodes/types/common.ts
Normal file
@ -0,0 +1,216 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
// #region Field data schemas
|
||||
export const zImageField = z.object({
|
||||
image_name: z.string().trim().min(1),
|
||||
});
|
||||
export type ImageField = z.infer<typeof zImageField>;
|
||||
|
||||
export const zBoardField = z.object({
|
||||
board_id: z.string().trim().min(1),
|
||||
});
|
||||
export type BoardField = z.infer<typeof zBoardField>;
|
||||
|
||||
export const zColorField = z.object({
|
||||
r: z.number().int().min(0).max(255),
|
||||
g: z.number().int().min(0).max(255),
|
||||
b: z.number().int().min(0).max(255),
|
||||
a: z.number().int().min(0).max(255),
|
||||
});
|
||||
export type ColorField = z.infer<typeof zColorField>;
|
||||
|
||||
export const zSchedulerField = z.enum([
|
||||
'euler',
|
||||
'deis',
|
||||
'ddim',
|
||||
'ddpm',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_sde',
|
||||
'dpmpp_sde',
|
||||
'heun',
|
||||
'kdpm_2',
|
||||
'lms',
|
||||
'pndm',
|
||||
'unipc',
|
||||
'euler_k',
|
||||
'dpmpp_2s_k',
|
||||
'dpmpp_2m_k',
|
||||
'dpmpp_2m_sde_k',
|
||||
'dpmpp_sde_k',
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
]);
|
||||
export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
export const zBaseModel = z.enum([
|
||||
'any',
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
'sdxl-refiner',
|
||||
]);
|
||||
export const zModelType = z.enum([
|
||||
'onnx',
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
'controlnet',
|
||||
'embedding',
|
||||
]);
|
||||
export const zModelName = z.string().trim().min(1);
|
||||
export const zModelIdentifier = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
export type ModelType = z.infer<typeof zModelType>;
|
||||
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
|
||||
|
||||
export const zMainModelField = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export const zONNXModelField = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('onnx'),
|
||||
});
|
||||
export const zMainOrONNXModelField = z.union([
|
||||
zMainModelField,
|
||||
zONNXModelField,
|
||||
]);
|
||||
export const zSDXLRefinerModelField = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: z.literal('sdxl-refiner'),
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export type MainModelField = z.infer<typeof zMainModelField>;
|
||||
export type ONNXModelField = z.infer<typeof zONNXModelField>;
|
||||
export type MainOrONNXModelField = z.infer<typeof zMainOrONNXModelField>;
|
||||
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
|
||||
|
||||
export const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
'scheduler',
|
||||
'safety_checker',
|
||||
]);
|
||||
export type SubModelType = z.infer<typeof zSubModelType>;
|
||||
|
||||
export const zVAEModelField = zModelIdentifier;
|
||||
|
||||
export const zModelInfo = zModelIdentifier.extend({
|
||||
model_type: zModelType,
|
||||
submodel: zSubModelType.optional(),
|
||||
});
|
||||
export type ModelInfo = z.infer<typeof zModelInfo>;
|
||||
|
||||
export const zLoRAModelField = zModelIdentifier;
|
||||
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
|
||||
|
||||
export const zControlNetModelField = zModelIdentifier;
|
||||
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
|
||||
|
||||
export const zIPAdapterModelField = zModelIdentifier;
|
||||
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
|
||||
|
||||
export const zT2IAdapterModelField = zModelIdentifier;
|
||||
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
|
||||
|
||||
export const zLoraInfo = zModelInfo.extend({
|
||||
weight: z.number().optional(),
|
||||
});
|
||||
export type LoraInfo = z.infer<typeof zLoraInfo>;
|
||||
|
||||
export const zUNetField = z.object({
|
||||
unet: zModelInfo,
|
||||
scheduler: zModelInfo,
|
||||
loras: z.array(zLoraInfo),
|
||||
});
|
||||
export type UNetField = z.infer<typeof zUNetField>;
|
||||
|
||||
export const zCLIPField = z.object({
|
||||
tokenizer: zModelInfo,
|
||||
text_encoder: zModelInfo,
|
||||
skipped_layers: z.number(),
|
||||
loras: z.array(zLoraInfo),
|
||||
});
|
||||
export type CLIPField = z.infer<typeof zCLIPField>;
|
||||
|
||||
export const zVAEField = z.object({
|
||||
vae: zModelInfo,
|
||||
});
|
||||
export type VAEField = z.infer<typeof zVAEField>;
|
||||
// #endregion
|
||||
|
||||
// #region Control Adapters
|
||||
export const zControlField = z.object({
|
||||
image: zImageField,
|
||||
control_model: zControlNetModelField,
|
||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
control_mode: z
|
||||
.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced'])
|
||||
.optional(),
|
||||
resize_mode: z
|
||||
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
|
||||
.optional(),
|
||||
});
|
||||
export type ControlField = z.infer<typeof zControlField>;
|
||||
|
||||
export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: zIPAdapterModelField,
|
||||
weight: z.number(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
});
|
||||
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
||||
|
||||
export const zT2IAdapterField = z.object({
|
||||
image: zImageField,
|
||||
t2i_adapter_model: zT2IAdapterModelField,
|
||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
resize_mode: z
|
||||
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
|
||||
.optional(),
|
||||
});
|
||||
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
|
||||
// #endregion
|
||||
|
||||
// #region ProgressImage
|
||||
export const zProgressImage = z.object({
|
||||
dataURL: z.string(),
|
||||
width: z.number().int(),
|
||||
height: z.number().int(),
|
||||
});
|
||||
export type ProgressImage = z.infer<typeof zProgressImage>;
|
||||
// #endregion
|
||||
|
||||
// #region ImageOutput
|
||||
export const zImageOutput = z.object({
|
||||
image: zImageField,
|
||||
width: z.number().int(),
|
||||
height: z.number().int(),
|
||||
type: z.literal('image_output'),
|
||||
});
|
||||
export type ImageOutput = z.infer<typeof zImageOutput>;
|
||||
export const isImageOutput = (output: unknown): output is ImageOutput =>
|
||||
zImageOutput.safeParse(output).success;
|
||||
// #endregion
|
@ -1,58 +1,31 @@
|
||||
import {
|
||||
FieldType,
|
||||
FieldTypeMap,
|
||||
FieldTypeMapWithNumber,
|
||||
FieldUIConfig,
|
||||
} from './types';
|
||||
import { t } from 'i18next';
|
||||
|
||||
/**
|
||||
* How long to wait before showing a tooltip when hovering a field handle.
|
||||
*/
|
||||
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
|
||||
export const COLOR_TOKEN_VALUE = 500;
|
||||
|
||||
/**
|
||||
* The width of a node in the UI in pixels.
|
||||
*/
|
||||
export const NODE_WIDTH = 320;
|
||||
export const NODE_MIN_WIDTH = 320;
|
||||
|
||||
/**
|
||||
* This class name is special - reactflow uses it to identify the drag handle of a node,
|
||||
* applying the appropriate listeners to it.
|
||||
*/
|
||||
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
|
||||
|
||||
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
|
||||
export const FOOTER_FIELDS = IMAGE_FIELDS;
|
||||
|
||||
/**
|
||||
* Helper for getting the kind of a field.
|
||||
*/
|
||||
export const KIND_MAP = {
|
||||
input: 'inputs' as const,
|
||||
output: 'outputs' as const,
|
||||
};
|
||||
|
||||
export const COLLECTION_TYPES: FieldType[] = [
|
||||
'Collection',
|
||||
'IntegerCollection',
|
||||
'BooleanCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'ImageCollection',
|
||||
'LatentsCollection',
|
||||
'ConditioningCollection',
|
||||
'ControlCollection',
|
||||
'ColorCollection',
|
||||
'T2IAdapterCollection',
|
||||
'IPAdapterCollection',
|
||||
'MetadataItemCollection',
|
||||
'MetadataCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
'IntegerPolymorphic',
|
||||
'BooleanPolymorphic',
|
||||
'FloatPolymorphic',
|
||||
'StringPolymorphic',
|
||||
'ImagePolymorphic',
|
||||
'LatentsPolymorphic',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
'T2IAdapterPolymorphic',
|
||||
'IPAdapterPolymorphic',
|
||||
'MetadataItemPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES: FieldType[] = [
|
||||
/**
|
||||
* Model types' handles are rendered as squares in the UI.
|
||||
*/
|
||||
export const MODEL_TYPES = [
|
||||
'IPAdapterModelField',
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
@ -68,373 +41,33 @@ export const MODEL_TYPES: FieldType[] = [
|
||||
'IPAdapterModelField',
|
||||
];
|
||||
|
||||
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||
integer: 'IntegerCollection',
|
||||
boolean: 'BooleanCollection',
|
||||
number: 'FloatCollection',
|
||||
float: 'FloatCollection',
|
||||
string: 'StringCollection',
|
||||
ImageField: 'ImageCollection',
|
||||
LatentsField: 'LatentsCollection',
|
||||
ConditioningField: 'ConditioningCollection',
|
||||
ControlField: 'ControlCollection',
|
||||
ColorField: 'ColorCollection',
|
||||
T2IAdapterField: 'T2IAdapterCollection',
|
||||
IPAdapterField: 'IPAdapterCollection',
|
||||
MetadataItemField: 'MetadataItemCollection',
|
||||
MetadataField: 'MetadataCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof COLLECTION_MAP =>
|
||||
Boolean(itemType && itemType in COLLECTION_MAP);
|
||||
|
||||
export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
||||
integer: 'IntegerPolymorphic',
|
||||
boolean: 'BooleanPolymorphic',
|
||||
number: 'FloatPolymorphic',
|
||||
float: 'FloatPolymorphic',
|
||||
string: 'StringPolymorphic',
|
||||
ImageField: 'ImagePolymorphic',
|
||||
LatentsField: 'LatentsPolymorphic',
|
||||
ConditioningField: 'ConditioningPolymorphic',
|
||||
ControlField: 'ControlPolymorphic',
|
||||
ColorField: 'ColorPolymorphic',
|
||||
T2IAdapterField: 'T2IAdapterPolymorphic',
|
||||
IPAdapterField: 'IPAdapterPolymorphic',
|
||||
MetadataItemField: 'MetadataItemPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
IntegerPolymorphic: 'integer',
|
||||
BooleanPolymorphic: 'boolean',
|
||||
FloatPolymorphic: 'float',
|
||||
StringPolymorphic: 'string',
|
||||
ImagePolymorphic: 'ImageField',
|
||||
LatentsPolymorphic: 'LatentsField',
|
||||
ConditioningPolymorphic: 'ConditioningField',
|
||||
ControlPolymorphic: 'ControlField',
|
||||
ColorPolymorphic: 'ColorField',
|
||||
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||
IPAdapterPolymorphic: 'IPAdapterField',
|
||||
MetadataItemPolymorphic: 'MetadataItemField',
|
||||
};
|
||||
|
||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||
'string',
|
||||
'StringPolymorphic',
|
||||
'boolean',
|
||||
'BooleanPolymorphic',
|
||||
'integer',
|
||||
'float',
|
||||
'FloatPolymorphic',
|
||||
'IntegerPolymorphic',
|
||||
'enum',
|
||||
'ImageField',
|
||||
'ImagePolymorphic',
|
||||
'MainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'ColorField',
|
||||
'SDXLMainModelField',
|
||||
'Scheduler',
|
||||
'IPAdapterModelField',
|
||||
'BoardField',
|
||||
'T2IAdapterModelField',
|
||||
];
|
||||
|
||||
export const isPolymorphicItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
Any: {
|
||||
color: 'gray.500',
|
||||
description: 'Any field type is accepted.',
|
||||
title: 'Any',
|
||||
},
|
||||
MetadataField: {
|
||||
color: 'gray.500',
|
||||
description: 'A metadata dict.',
|
||||
title: 'Metadata Dict',
|
||||
},
|
||||
MetadataCollection: {
|
||||
color: 'gray.500',
|
||||
description: 'A collection of metadata dicts.',
|
||||
title: 'Metadata Dict Collection',
|
||||
},
|
||||
MetadataItemField: {
|
||||
color: 'gray.500',
|
||||
description: 'A metadata item.',
|
||||
title: 'Metadata Item',
|
||||
},
|
||||
MetadataItemCollection: {
|
||||
color: 'gray.500',
|
||||
description: 'Any field type is accepted.',
|
||||
title: 'Metadata Item Collection',
|
||||
},
|
||||
MetadataItemPolymorphic: {
|
||||
color: 'gray.500',
|
||||
description:
|
||||
'MetadataItem or MetadataItemCollection field types are accepted.',
|
||||
title: 'Metadata Item Polymorphic',
|
||||
},
|
||||
boolean: {
|
||||
color: 'green.500',
|
||||
description: t('nodes.booleanDescription'),
|
||||
title: t('nodes.boolean'),
|
||||
},
|
||||
BooleanCollection: {
|
||||
color: 'green.500',
|
||||
description: t('nodes.booleanCollectionDescription'),
|
||||
title: t('nodes.booleanCollection'),
|
||||
},
|
||||
BooleanPolymorphic: {
|
||||
color: 'green.500',
|
||||
description: t('nodes.booleanPolymorphicDescription'),
|
||||
title: t('nodes.booleanPolymorphic'),
|
||||
},
|
||||
ClipField: {
|
||||
color: 'green.500',
|
||||
description: t('nodes.clipFieldDescription'),
|
||||
title: t('nodes.clipField'),
|
||||
},
|
||||
Collection: {
|
||||
color: 'base.500',
|
||||
description: t('nodes.collectionDescription'),
|
||||
title: t('nodes.collection'),
|
||||
},
|
||||
CollectionItem: {
|
||||
color: 'base.500',
|
||||
description: t('nodes.collectionItemDescription'),
|
||||
title: t('nodes.collectionItem'),
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'pink.300',
|
||||
description: t('nodes.colorCollectionDescription'),
|
||||
title: t('nodes.colorCollection'),
|
||||
},
|
||||
ColorField: {
|
||||
color: 'pink.300',
|
||||
description: t('nodes.colorFieldDescription'),
|
||||
title: t('nodes.colorField'),
|
||||
},
|
||||
ColorPolymorphic: {
|
||||
color: 'pink.300',
|
||||
description: t('nodes.colorPolymorphicDescription'),
|
||||
title: t('nodes.colorPolymorphic'),
|
||||
},
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
description: t('nodes.conditioningCollectionDescription'),
|
||||
title: t('nodes.conditioningCollection'),
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
description: t('nodes.conditioningFieldDescription'),
|
||||
title: t('nodes.conditioningField'),
|
||||
},
|
||||
ConditioningPolymorphic: {
|
||||
color: 'cyan.500',
|
||||
description: t('nodes.conditioningPolymorphicDescription'),
|
||||
title: t('nodes.conditioningPolymorphic'),
|
||||
},
|
||||
ControlCollection: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.controlCollectionDescription'),
|
||||
title: t('nodes.controlCollection'),
|
||||
},
|
||||
ControlField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.controlFieldDescription'),
|
||||
title: t('nodes.controlField'),
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'ControlNet',
|
||||
},
|
||||
ControlPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Polymorphic',
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
color: 'blue.300',
|
||||
description: t('nodes.denoiseMaskFieldDescription'),
|
||||
title: t('nodes.denoiseMaskField'),
|
||||
},
|
||||
enum: {
|
||||
color: 'blue.500',
|
||||
description: t('nodes.enumDescription'),
|
||||
title: t('nodes.enum'),
|
||||
},
|
||||
float: {
|
||||
color: 'orange.500',
|
||||
description: t('nodes.floatDescription'),
|
||||
title: t('nodes.float'),
|
||||
},
|
||||
FloatCollection: {
|
||||
color: 'orange.500',
|
||||
description: t('nodes.floatCollectionDescription'),
|
||||
title: t('nodes.floatCollection'),
|
||||
},
|
||||
FloatPolymorphic: {
|
||||
color: 'orange.500',
|
||||
description: t('nodes.floatPolymorphicDescription'),
|
||||
title: t('nodes.floatPolymorphic'),
|
||||
},
|
||||
ImageCollection: {
|
||||
color: 'purple.500',
|
||||
description: t('nodes.imageCollectionDescription'),
|
||||
title: t('nodes.imageCollection'),
|
||||
},
|
||||
ImageField: {
|
||||
color: 'purple.500',
|
||||
description: t('nodes.imageFieldDescription'),
|
||||
title: t('nodes.imageField'),
|
||||
},
|
||||
BoardField: {
|
||||
color: 'purple.500',
|
||||
description: t('nodes.imageFieldDescription'),
|
||||
title: t('nodes.imageField'),
|
||||
},
|
||||
ImagePolymorphic: {
|
||||
color: 'purple.500',
|
||||
description: t('nodes.imagePolymorphicDescription'),
|
||||
title: t('nodes.imagePolymorphic'),
|
||||
},
|
||||
integer: {
|
||||
color: 'red.500',
|
||||
description: t('nodes.integerDescription'),
|
||||
title: t('nodes.integer'),
|
||||
},
|
||||
IntegerCollection: {
|
||||
color: 'red.500',
|
||||
description: t('nodes.integerCollectionDescription'),
|
||||
title: t('nodes.integerCollection'),
|
||||
},
|
||||
IntegerPolymorphic: {
|
||||
color: 'red.500',
|
||||
description: t('nodes.integerPolymorphicDescription'),
|
||||
title: t('nodes.integerPolymorphic'),
|
||||
},
|
||||
IPAdapterCollection: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.ipAdapterCollectionDescription'),
|
||||
title: t('nodes.ipAdapterCollection'),
|
||||
},
|
||||
IPAdapterField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.ipAdapterDescription'),
|
||||
title: t('nodes.ipAdapter'),
|
||||
},
|
||||
IPAdapterModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.ipAdapterModelDescription'),
|
||||
title: t('nodes.ipAdapterModel'),
|
||||
},
|
||||
IPAdapterPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.ipAdapterPolymorphicDescription'),
|
||||
title: t('nodes.ipAdapterPolymorphic'),
|
||||
},
|
||||
LatentsCollection: {
|
||||
color: 'pink.500',
|
||||
description: t('nodes.latentsCollectionDescription'),
|
||||
title: t('nodes.latentsCollection'),
|
||||
},
|
||||
LatentsField: {
|
||||
color: 'pink.500',
|
||||
description: t('nodes.latentsFieldDescription'),
|
||||
title: t('nodes.latentsField'),
|
||||
},
|
||||
LatentsPolymorphic: {
|
||||
color: 'pink.500',
|
||||
description: t('nodes.latentsPolymorphicDescription'),
|
||||
title: t('nodes.latentsPolymorphic'),
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.loRAModelFieldDescription'),
|
||||
title: t('nodes.loRAModelField'),
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.mainModelFieldDescription'),
|
||||
title: t('nodes.mainModelField'),
|
||||
},
|
||||
ONNXModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.oNNXModelFieldDescription'),
|
||||
title: t('nodes.oNNXModelField'),
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
description: t('nodes.schedulerDescription'),
|
||||
title: t('nodes.scheduler'),
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.sDXLMainModelFieldDescription'),
|
||||
title: t('nodes.sDXLMainModelField'),
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.sDXLRefinerModelFieldDescription'),
|
||||
title: t('nodes.sDXLRefinerModelField'),
|
||||
},
|
||||
string: {
|
||||
color: 'yellow.500',
|
||||
description: t('nodes.stringDescription'),
|
||||
title: t('nodes.string'),
|
||||
},
|
||||
StringCollection: {
|
||||
color: 'yellow.500',
|
||||
description: t('nodes.stringCollectionDescription'),
|
||||
title: t('nodes.stringCollection'),
|
||||
},
|
||||
StringPolymorphic: {
|
||||
color: 'yellow.500',
|
||||
description: t('nodes.stringPolymorphicDescription'),
|
||||
title: t('nodes.stringPolymorphic'),
|
||||
},
|
||||
T2IAdapterCollection: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.t2iAdapterCollectionDescription'),
|
||||
title: t('nodes.t2iAdapterCollection'),
|
||||
},
|
||||
T2IAdapterField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.t2iAdapterFieldDescription'),
|
||||
title: t('nodes.t2iAdapterField'),
|
||||
},
|
||||
T2IAdapterModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'T2I-Adapter',
|
||||
},
|
||||
T2IAdapterPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: 'T2I-Adapter info passed between nodes.',
|
||||
title: 'T2I-Adapter Polymorphic',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
description: t('nodes.uNetFieldDescription'),
|
||||
title: t('nodes.uNetField'),
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
description: t('nodes.vaeFieldDescription'),
|
||||
title: t('nodes.vaeField'),
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.vaeModelFieldDescription'),
|
||||
title: t('nodes.vaeModelField'),
|
||||
},
|
||||
/**
|
||||
* Colors for each field type - applies to their handles and edges.
|
||||
*/
|
||||
export const FIELD_COLORS: { [key: string]: string } = {
|
||||
BoardField: 'purple.500',
|
||||
BooleanField: 'green.500',
|
||||
ClipField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
ControlField: 'teal.500',
|
||||
ControlNetModelField: 'teal.500',
|
||||
EnumField: 'blue.500',
|
||||
FloatField: 'orange.500',
|
||||
ImageField: 'purple.500',
|
||||
IntegerField: 'red.500',
|
||||
IPAdapterField: 'teal.500',
|
||||
IPAdapterModelField: 'teal.500',
|
||||
LatentsField: 'pink.500',
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
ONNXModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
UNetField: 'red.500',
|
||||
VaeField: 'blue.500',
|
||||
VaeModelField: 'teal.500',
|
||||
};
|
||||
|
59
invokeai/frontend/web/src/features/nodes/types/error.ts
Normal file
59
invokeai/frontend/web/src/features/nodes/types/error.ts
Normal file
@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Invalid Workflow Version Error
|
||||
* Raised when a workflow version is not recognized.
|
||||
*/
|
||||
export class WorkflowVersionError extends Error {
|
||||
/**
|
||||
* Create WorkflowVersionError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unable to Update Node Error
|
||||
* Raised when a node cannot be updated.
|
||||
*/
|
||||
export class NodeUpdateError extends Error {
|
||||
/**
|
||||
* Create NodeUpdateError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldTypeParseError
|
||||
* Raised when a field cannot be parsed from a field schema.
|
||||
*/
|
||||
export class FieldTypeParseError extends Error {
|
||||
/**
|
||||
* Create FieldTypeParseError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* UnsupportedFieldTypeError
|
||||
* Raised when an unsupported field type is parsed.
|
||||
*/
|
||||
export class UnsupportedFieldTypeError extends Error {
|
||||
/**
|
||||
* Create UnsupportedFieldTypeError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
1114
invokeai/frontend/web/src/features/nodes/types/field.ts
Normal file
1114
invokeai/frontend/web/src/features/nodes/types/field.ts
Normal file
File diff suppressed because it is too large
Load Diff
108
invokeai/frontend/web/src/features/nodes/types/invocation.ts
Normal file
108
invokeai/frontend/web/src/features/nodes/types/invocation.ts
Normal file
@ -0,0 +1,108 @@
|
||||
import { Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
import { zProgressImage } from './common';
|
||||
import {
|
||||
zFieldInputInstance,
|
||||
zFieldInputTemplate,
|
||||
zFieldOutputInstance,
|
||||
zFieldOutputTemplate,
|
||||
} from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
// #region InvocationTemplate
|
||||
export const zInvocationTemplate = z.object({
|
||||
type: z.string(),
|
||||
title: z.string(),
|
||||
description: z.string(),
|
||||
tags: z.array(z.string().min(1)),
|
||||
inputs: z.record(zFieldInputTemplate),
|
||||
outputs: z.record(zFieldOutputTemplate),
|
||||
outputType: z.string().min(1),
|
||||
withWorkflow: z.boolean(),
|
||||
version: zSemVer,
|
||||
useCache: z.boolean(),
|
||||
});
|
||||
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
||||
// #endregion
|
||||
|
||||
// #region NodeData
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.string().trim().min(1),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
embedWorkflow: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
version: zSemVer,
|
||||
inputs: z.record(zFieldInputInstance),
|
||||
outputs: z.record(zFieldOutputInstance),
|
||||
});
|
||||
|
||||
export const zNotesNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
});
|
||||
export const zCurrentImageNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('current_image'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
});
|
||||
export const zAnyNodeData = z.union([
|
||||
zInvocationNodeData,
|
||||
zNotesNodeData,
|
||||
zCurrentImageNodeData,
|
||||
]);
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
export type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
export type AnyNodeData = z.infer<typeof zAnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (
|
||||
node?: Node<AnyNodeData>
|
||||
): node is Node<InvocationNodeData> =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (
|
||||
node?: Node<AnyNodeData>
|
||||
): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
|
||||
export const isProgressImageNode = (
|
||||
node?: Node<AnyNodeData>
|
||||
): node is Node<CurrentImageNodeData> =>
|
||||
Boolean(node && node.type === 'current_image');
|
||||
export const isInvocationNodeData = (
|
||||
node?: AnyNodeData
|
||||
): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
export const zNodeStatus = z.enum([
|
||||
'PENDING',
|
||||
'IN_PROGRESS',
|
||||
'COMPLETED',
|
||||
'FAILED',
|
||||
]);
|
||||
export const zNodeExecutionState = z.object({
|
||||
nodeId: z.string().trim().min(1),
|
||||
status: zNodeStatus,
|
||||
progress: z.number().nullable(),
|
||||
progressImage: zProgressImage.nullable(),
|
||||
error: z.string().nullable(),
|
||||
outputs: z.array(z.any()),
|
||||
});
|
||||
export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
|
||||
export type NodeStatus = z.infer<typeof zNodeStatus>;
|
||||
// #endregion
|
||||
|
||||
// #region Edges
|
||||
export const zInvocationEdgeExtra = z.object({
|
||||
type: z.union([z.literal('default'), z.literal('collapsed')]),
|
||||
});
|
||||
export type InvocationEdgeExtra = z.infer<typeof zInvocationEdgeExtra>;
|
||||
// #endregion
|
81
invokeai/frontend/web/src/features/nodes/types/metadata.ts
Normal file
81
invokeai/frontend/web/src/features/nodes/types/metadata.ts
Normal file
@ -0,0 +1,81 @@
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zONNXModelField,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterField,
|
||||
zVAEModelField,
|
||||
} from './common';
|
||||
|
||||
// #region Metadata-optimized versions of schemas
|
||||
// TODO: It's possible that `deepPartial` will be deprecated:
|
||||
// - https://github.com/colinhacks/zod/issues/2106
|
||||
// - https://github.com/colinhacks/zod/issues/2854
|
||||
export const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
||||
const zModelMetadataitem = z.union([
|
||||
zMainModelField.deepPartial(),
|
||||
zONNXModelField.deepPartial(),
|
||||
]);
|
||||
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
|
||||
export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
|
||||
export type SDXLRefinerModelMetadataItem = z.infer<
|
||||
typeof zSDXLRefinerModelMetadataItem
|
||||
>;
|
||||
export type ModelMetadataitem = z.infer<typeof zModelMetadataitem>;
|
||||
export type VAEModelMetadataItem = z.infer<typeof zVAEModelMetadataItem>;
|
||||
// #endregion
|
||||
|
||||
// #region CoreMetadata
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish().catch(null),
|
||||
generation_mode: z.string().nullish().catch(null),
|
||||
created_by: z.string().nullish().catch(null),
|
||||
positive_prompt: z.string().nullish().catch(null),
|
||||
negative_prompt: z.string().nullish().catch(null),
|
||||
width: z.number().int().nullish().catch(null),
|
||||
height: z.number().int().nullish().catch(null),
|
||||
seed: z.number().int().nullish().catch(null),
|
||||
rand_device: z.string().nullish().catch(null),
|
||||
cfg_scale: z.number().nullish().catch(null),
|
||||
steps: z.number().int().nullish().catch(null),
|
||||
scheduler: z.string().nullish().catch(null),
|
||||
clip_skip: z.number().int().nullish().catch(null),
|
||||
model: zModelMetadataitem.nullish().catch(null),
|
||||
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
|
||||
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
|
||||
t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
|
||||
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
|
||||
vae: zVAEModelMetadataItem.nullish().catch(null),
|
||||
strength: z.number().nullish().catch(null),
|
||||
hrf_enabled: z.boolean().nullish().catch(null),
|
||||
hrf_strength: z.number().nullish().catch(null),
|
||||
hrf_method: z.string().nullish().catch(null),
|
||||
init_image: z.string().nullish().catch(null),
|
||||
positive_style_prompt: z.string().nullish().catch(null),
|
||||
negative_style_prompt: z.string().nullish().catch(null),
|
||||
refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null),
|
||||
refiner_cfg_scale: z.number().nullish().catch(null),
|
||||
refiner_steps: z.number().int().nullish().catch(null),
|
||||
refiner_scheduler: z.string().nullish().catch(null),
|
||||
refiner_positive_aesthetic_score: z.number().nullish().catch(null),
|
||||
refiner_negative_aesthetic_score: z.number().nullish().catch(null),
|
||||
refiner_start: z.number().nullish().catch(null),
|
||||
})
|
||||
.passthrough();
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
// #endregion
|
@ -0,0 +1,69 @@
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
import { WorkflowVersionError } from '../error';
|
||||
import { zSemVer } from '../semver';
|
||||
import { WorkflowV2, zWorkflowV2 } from '../workflow';
|
||||
import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from './v1/fieldTypeMap';
|
||||
import { WorkflowV1, zWorkflowV1 } from './v1/workflowV1';
|
||||
import { t } from 'i18next';
|
||||
|
||||
/**
|
||||
* Helper schema to extract the version from a workflow.
|
||||
*
|
||||
* All properties except for the version are ignored in this schema.
|
||||
*/
|
||||
const zWorkflowMetaVersion = z.object({
|
||||
meta: z.object({ version: zSemVer }),
|
||||
});
|
||||
|
||||
/**
|
||||
* Migrates a workflow from V1 to V2.
|
||||
*/
|
||||
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
workflowToMigrate.nodes.forEach((node) => {
|
||||
if (node.type === 'invocation') {
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (!isString(input.type)) {
|
||||
return;
|
||||
}
|
||||
(input.type as unknown) =
|
||||
FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type];
|
||||
});
|
||||
forEach(node.data.outputs, (output) => {
|
||||
if (!isString(output.type)) {
|
||||
return;
|
||||
}
|
||||
(output.type as unknown) =
|
||||
FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type];
|
||||
});
|
||||
}
|
||||
});
|
||||
(workflowToMigrate.meta.version as WorkflowV2['meta']['version']) = '2.0.0';
|
||||
return zWorkflowV2.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses a workflow and migrates it to the latest version if necessary.
|
||||
*/
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
|
||||
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
|
||||
|
||||
if (!workflowVersionResult.success) {
|
||||
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
|
||||
}
|
||||
|
||||
const { version } = workflowVersionResult.data.meta;
|
||||
|
||||
if (version === '1.0.0') {
|
||||
const v1 = zWorkflowV1.parse(data);
|
||||
return migrateV1toV2(v1);
|
||||
}
|
||||
|
||||
if (version === '2.0.0') {
|
||||
return zWorkflowV2.parse(data);
|
||||
}
|
||||
|
||||
throw new WorkflowVersionError(
|
||||
t('nodes.unrecognizedWorkflowVersion', { version })
|
||||
);
|
||||
};
|
@ -0,0 +1,270 @@
|
||||
import { FieldType, StatefulFieldType } from '../../field';
|
||||
import { FieldTypeV1 } from './workflowV1';
|
||||
|
||||
/**
|
||||
* Mapping of V1 field type strings to their *stateful* V2 field type counterparts.
|
||||
*/
|
||||
const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
|
||||
[key in FieldTypeV1]?: StatefulFieldType;
|
||||
} = {
|
||||
BoardField: { name: 'BoardField', isCollection: false, isPolymorphic: false },
|
||||
boolean: { name: 'BooleanField', isCollection: false, isPolymorphic: false },
|
||||
BooleanCollection: {
|
||||
name: 'BooleanField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
BooleanPolymorphic: {
|
||||
name: 'BooleanField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
ColorField: { name: 'ColorField', isCollection: false, isPolymorphic: false },
|
||||
ColorCollection: {
|
||||
name: 'ColorField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ColorPolymorphic: {
|
||||
name: 'ColorField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
ControlNetModelField: {
|
||||
name: 'ControlNetModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
enum: { name: 'EnumField', isCollection: false, isPolymorphic: false },
|
||||
float: { name: 'FloatField', isCollection: false, isPolymorphic: false },
|
||||
FloatCollection: {
|
||||
name: 'FloatField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
FloatPolymorphic: {
|
||||
name: 'FloatField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
ImageCollection: {
|
||||
name: 'ImageField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ImageField: { name: 'ImageField', isCollection: false, isPolymorphic: false },
|
||||
ImagePolymorphic: {
|
||||
name: 'ImageField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
integer: { name: 'IntegerField', isCollection: false, isPolymorphic: false },
|
||||
IntegerCollection: {
|
||||
name: 'IntegerField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
IntegerPolymorphic: {
|
||||
name: 'IntegerField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
IPAdapterModelField: {
|
||||
name: 'IPAdapterModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
LoRAModelField: {
|
||||
name: 'LoRAModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
MainModelField: {
|
||||
name: 'MainModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
Scheduler: {
|
||||
name: 'SchedulerField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
name: 'SDXLMainModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
name: 'SDXLRefinerModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
string: { name: 'StringField', isCollection: false, isPolymorphic: false },
|
||||
StringCollection: {
|
||||
name: 'StringField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
StringPolymorphic: {
|
||||
name: 'StringField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
T2IAdapterModelField: {
|
||||
name: 'T2IAdapterModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
VaeModelField: {
|
||||
name: 'VAEModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of V1 field type strings to their *stateless* V2 field type counterparts.
|
||||
*
|
||||
* The type doesn't do what I want it to do.
|
||||
*
|
||||
* Ideally, the value of each propery would be a `FieldType` where `FieldType['name']` is not in
|
||||
* `StatefulFieldType['name']`, but this is hard to represent. That's because `FieldType['name']` is
|
||||
* actually widened to `string`, and TS's `Exclude<T,U>` doesn't work on `string`.
|
||||
*
|
||||
* There's probably some way to do it with conditionals and intersections but I can't figure it out.
|
||||
*
|
||||
* Thus, this object was manually edited to ensure it is correct.
|
||||
*/
|
||||
const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
|
||||
[key in FieldTypeV1]?: FieldType;
|
||||
} = {
|
||||
Any: { name: 'AnyField', isCollection: false, isPolymorphic: false },
|
||||
ClipField: { name: 'ClipField', isCollection: false, isPolymorphic: false },
|
||||
Collection: {
|
||||
name: 'CollectionField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
CollectionItem: {
|
||||
name: 'CollectionItemField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ConditioningCollection: {
|
||||
name: 'ConditioningField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ConditioningField: {
|
||||
name: 'ConditioningField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ConditioningPolymorphic: {
|
||||
name: 'ConditioningField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
ControlCollection: {
|
||||
name: 'ControlField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ControlField: {
|
||||
name: 'ControlField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
ControlPolymorphic: {
|
||||
name: 'ControlField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
name: 'DenoiseMaskField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
IPAdapterField: {
|
||||
name: 'IPAdapterField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
IPAdapterCollection: {
|
||||
name: 'IPAdapterField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
IPAdapterPolymorphic: {
|
||||
name: 'IPAdapterField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
LatentsField: {
|
||||
name: 'LatentsField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
LatentsCollection: {
|
||||
name: 'LatentsField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
LatentsPolymorphic: {
|
||||
name: 'LatentsField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
MetadataField: {
|
||||
name: 'MetadataField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
MetadataCollection: {
|
||||
name: 'MetadataField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
MetadataItemField: {
|
||||
name: 'MetadataItemField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
MetadataItemCollection: {
|
||||
name: 'MetadataItemField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
MetadataItemPolymorphic: {
|
||||
name: 'MetadataItemField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
ONNXModelField: {
|
||||
name: 'ONNXModelField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
T2IAdapterField: {
|
||||
name: 'T2IAdapterField',
|
||||
isCollection: false,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
T2IAdapterCollection: {
|
||||
name: 'T2IAdapterField',
|
||||
isCollection: true,
|
||||
isPolymorphic: false,
|
||||
},
|
||||
T2IAdapterPolymorphic: {
|
||||
name: 'T2IAdapterField',
|
||||
isCollection: false,
|
||||
isPolymorphic: true,
|
||||
},
|
||||
UNetField: { name: 'UNetField', isCollection: false, isPolymorphic: false },
|
||||
VaeField: { name: 'VaeField', isCollection: false, isPolymorphic: false },
|
||||
};
|
||||
|
||||
export const FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING = {
|
||||
...FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2,
|
||||
...FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2,
|
||||
};
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user