mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): workflow schema v3 (WIP)
The changes aim to deduplicate data between workflows and node templates, decoupling workflows from internal implementation details. A good amount of data that was needlessly duplicated from the node template to the workflow is removed. These changes substantially reduce the file size of workflows (and therefore the images with embedded workflows): - Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to 10.9kb (407 lines). - Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines) to 51.9kb (1774 lines). The trade-off is that we need to reference node templates to get things like the field type and other things. In practice, this is a non-issue, because we need a node template to do anything with a node anyways. - Field types are not included in the workflow. They are always pulled from the node templates. The field type is now properly an internal implementation detail and we can change it as needed. Previously this would require a migration for the workflow itself. With the v3 schema, the structure of a field type is an internal implementation detail that we are free to change as we see fit. - Workflow nodes no long have an `outputs` property and there is no longer such a thing as a `FieldOutputInstance`. These are only on the templates. These were never referenced at a time when we didn't also have the templates available, and there'd be no reason to do so. - Node width and height are no longer stored in the node. These weren't used. Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be programmatically changing these properties. A future enhancement can properly add node resizing. - `nodeTemplates` slice is merged back into `nodesSlice` as `nodes.templates`. Turns out it's just a hassle having these separate in separate slices. - Workflow migration logic updated to support the new schema. V1 workflows migrate all the way to v3 now. - Changes throughout the nodes code to accommodate the above changes.
This commit is contained in:
parent
5fbfed30ac
commit
f8525837b2
@ -1,6 +1,6 @@
|
||||
import type { UnknownAction } from '@reduxjs/toolkit';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { Graph } from 'services/api/types';
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||
import { size } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
|
@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => {
|
||||
actionCreator: updateAllNodesRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const nodes = getState().nodes.nodes;
|
||||
const templates = getState().nodeTemplates.templates;
|
||||
const { nodes, templates } = getState().nodes;
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
|
@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const { workflow, asCopy } = action.payload;
|
||||
const nodeTemplates = getState().nodeTemplates.templates;
|
||||
const nodeTemplates = getState().nodes.templates;
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||
|
@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
@ -46,7 +45,6 @@ const allReducers = {
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: nodesSlice.reducer,
|
||||
[nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer,
|
||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
|
@ -1,7 +1,8 @@
|
||||
import type { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import type { TypedUseSelectorHook } from 'react-redux';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||
export const useAppStore = () => useStore<RootState>();
|
||||
|
2
invokeai/frontend/web/src/app/store/util.ts
Normal file
2
invokeai/frontend/web/src/app/store/util.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export const EMPTY_ARRAY = [];
|
||||
export const EMPTY_OBJECT = {};
|
@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
@ -23,11 +22,10 @@ const selector = createMemoizedSelector(
|
||||
selectGenerationSlice,
|
||||
selectSystemSlice,
|
||||
selectNodesSlice,
|
||||
selectNodeTemplatesSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => {
|
||||
(controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => {
|
||||
const { initialImage, model, positivePrompt } = generation;
|
||||
|
||||
const { isConnected } = system;
|
||||
@ -54,7 +52,7 @@ const selector = createMemoizedSelector(
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
const nodeTemplate = nodes.templates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
|
@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import {
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
nodeAdded,
|
||||
selectNodesSlice,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import { filter, map, memoize, some } from 'lodash-es';
|
||||
import type { KeyboardEventHandler } from 'react';
|
||||
@ -54,10 +58,10 @@ const AddNodePopover = () => {
|
||||
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
||||
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
||||
|
||||
const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const filteredNodeTemplates = fieldFilter
|
||||
? filter(nodeTemplates.templates, (template) => {
|
||||
? filter(nodes.templates, (template) => {
|
||||
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
return some(handles, (handle) => {
|
||||
@ -67,7 +71,7 @@ const AddNodePopover = () => {
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
})
|
||||
: map(nodeTemplates.templates);
|
||||
: map(nodes.templates);
|
||||
|
||||
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
||||
return {
|
||||
|
@ -1,10 +1,17 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
const defaultReturnValue = {
|
||||
isSelected: false,
|
||||
shouldAnimate: false,
|
||||
stroke: colorTokenToCssVar('base.500'),
|
||||
};
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
source: string,
|
||||
sourceHandleId: string | null | undefined,
|
||||
@ -12,14 +19,19 @@ export const makeEdgeSelector = (
|
||||
targetHandleId: string | null | undefined,
|
||||
selected?: boolean
|
||||
) =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
|
||||
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
|
||||
const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined;
|
||||
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||
if (!sourceNode || !sourceHandleId) {
|
||||
return defaultReturnValue;
|
||||
}
|
||||
|
||||
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
|
||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { map } from 'lodash-es';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
@ -13,7 +12,7 @@ interface Props {
|
||||
const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' };
|
||||
|
||||
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const { base600 } = useChakraThemeTokens();
|
||||
|
||||
const dummyHandleStyles: CSSProperties = useMemo(
|
||||
@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
[dummyHandleStyles]
|
||||
);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
<>
|
||||
<Handle
|
||||
type="target"
|
||||
id={`${data.id}-collapsed-target`}
|
||||
id={`${nodeId}-collapsed-target`}
|
||||
isConnectable={false}
|
||||
position={Position.Left}
|
||||
style={collapsedTargetStyles}
|
||||
/>
|
||||
{map(data.inputs, (input) => (
|
||||
{map(template.inputs, (input) => (
|
||||
<Handle
|
||||
key={`${data.id}-${input.name}-collapsed-input-handle`}
|
||||
key={`${nodeId}-${input.name}-collapsed-input-handle`}
|
||||
type="target"
|
||||
id={input.name}
|
||||
isConnectable={false}
|
||||
@ -62,14 +61,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
))}
|
||||
<Handle
|
||||
type="source"
|
||||
id={`${data.id}-collapsed-source`}
|
||||
id={`${nodeId}-collapsed-source`}
|
||||
isConnectable={false}
|
||||
position={Position.Right}
|
||||
style={collapsedSourceStyles}
|
||||
/>
|
||||
{map(data.outputs, (output) => (
|
||||
{map(template.outputs, (output) => (
|
||||
<Handle
|
||||
key={`${data.id}-${output.name}-collapsed-output-handle`}
|
||||
key={`${nodeId}-${output.name}-collapsed-output-handle`}
|
||||
type="source"
|
||||
id={output.name}
|
||||
isConnectable={false}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
@ -13,7 +13,7 @@ const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
|
||||
const { id: nodeId, type, isOpen, label } = data;
|
||||
|
||||
const hasTemplateSelector = useMemo(
|
||||
() => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])),
|
||||
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
|
||||
[type]
|
||||
);
|
||||
|
||||
|
@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
isMissingInput?: boolean;
|
||||
withTooltip?: boolean;
|
||||
}
|
||||
@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="input" /> : undefined}
|
||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Editable
|
||||
|
@ -6,7 +6,7 @@ import { memo } from 'react';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
isMissingInput?: boolean;
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { isFieldInputInstance, isFieldInputTemplate } from 'features/nodes/types/field';
|
||||
@ -9,11 +9,11 @@ import { useTranslation } from 'react-i18next';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
}
|
||||
|
||||
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
const field = useFieldInstance(nodeId, fieldName);
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
|
||||
|
@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'input' });
|
||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
if (!fieldTemplate) {
|
||||
@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
<EditableFieldTitle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
kind="inputs"
|
||||
isMissingInput={isMissingInput}
|
||||
withTooltip
|
||||
/>
|
||||
@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
<EditableFieldTitle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
kind="inputs"
|
||||
isMissingInput={isMissingInput}
|
||||
withTooltip
|
||||
/>
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { Box, Text } from '@invoke-ai/ui-library';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
@ -38,7 +37,6 @@ import {
|
||||
isVAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
@ -63,17 +61,8 @@ type InputFieldProps = {
|
||||
};
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldInstance = useFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
|
||||
if (fieldTemplate?.fieldKind === 'output') {
|
||||
return (
|
||||
<Box p={2}>
|
||||
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (fieldInstance && fieldTemplate) {
|
||||
if (fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={1}>
|
||||
<Text fontSize="sm" fontWeight="semibold" color="error.300">
|
||||
{t('nodes.unknownFieldType', { type: fieldInstance?.type.name })}
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InputFieldRenderer);
|
||||
|
@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
/>
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex alignItems="center">
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" />
|
||||
<Spacer />
|
||||
{isValueChanged && (
|
||||
<IconButton
|
||||
@ -75,7 +75,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
/>
|
||||
)}
|
||||
<Tooltip
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="input" />}
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" />}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
placement="top"
|
||||
>
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance';
|
||||
import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
@ -18,18 +17,17 @@ interface Props {
|
||||
const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
||||
const fieldInstance = useFieldOutputInstance(nodeId, fieldName);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'output' });
|
||||
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
||||
|
||||
if (!fieldTemplate || !fieldInstance) {
|
||||
if (!fieldTemplate) {
|
||||
return (
|
||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
|
||||
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
|
||||
{t('nodes.unknownOutput', {
|
||||
name: fieldTemplate?.title ?? fieldName,
|
||||
name: fieldName,
|
||||
})}
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
return (
|
||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||
<Tooltip
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="output" />}
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="outputs" />}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
placement="top"
|
||||
shouldWrapChildren
|
||||
|
@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import EditableNodeTitle from './details/EditableNodeTitle';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
|
@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types';
|
||||
|
||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
||||
|
||||
|
@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
return {
|
||||
template: lastSelectedNodeTemplate,
|
||||
|
@ -1,26 +1,22 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return [];
|
||||
}
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
|
@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
};
|
||||
|
||||
export const useBuildNode = () => {
|
||||
const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates);
|
||||
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
|
||||
|
||||
const flow = useReactFlow();
|
||||
|
||||
|
@ -1,28 +1,24 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return [];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
// get the visible fields
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
|
@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector(
|
||||
export type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
};
|
||||
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
@ -26,8 +26,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
Boolean(
|
||||
nodes.edges.filter((edge) => {
|
||||
return (
|
||||
(kind === 'input' ? edge.target : edge.source) === nodeId &&
|
||||
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName
|
||||
(kind === 'inputs' ? edge.target : edge.source) === nodeId &&
|
||||
(kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName
|
||||
);
|
||||
}).length
|
||||
)
|
||||
@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType),
|
||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||
[nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
Boolean(
|
||||
nodes.connectionStartParams?.nodeId === nodeId &&
|
||||
nodes.connectionStartParams?.handleId === fieldName &&
|
||||
nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind]
|
||||
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
|
||||
)
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
|
@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
export const useDoNodeVersionsMatch = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const data = selectNodeData(nodes, nodeId);
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template?.version || !data?.version) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
if (!nodeTemplate?.version || !node.data?.version) {
|
||||
return false;
|
||||
}
|
||||
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
||||
return compareVersions(template.version, data.version) === 0;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,18 +1,18 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
const data = selectNodeData(nodes, nodeId);
|
||||
if (!data) {
|
||||
return false;
|
||||
}
|
||||
return node?.data.inputs[fieldName]?.value !== undefined;
|
||||
return data.inputs[fieldName]?.value !== undefined;
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
@ -1,23 +0,0 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInstance = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data.inputs[fieldName];
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldData = useAppSelector(selector);
|
||||
|
||||
return fieldData;
|
||||
};
|
@ -1,23 +1,20 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputInstance } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputInstance = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldInputInstance = (nodeId: string, fieldName: string): FieldInputInstance | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data.inputs[fieldName];
|
||||
return selectFieldInputInstance(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
const fieldData = useAppSelector(selector);
|
||||
|
||||
return fieldTemplate;
|
||||
return fieldData;
|
||||
};
|
||||
|
@ -1,21 +1,16 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldInput } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
const fieldTemplate = nodeTemplate?.inputs[fieldName];
|
||||
return fieldTemplate?.input;
|
||||
createSelector(selectNodesSlice, (nodes): FieldInput | null => {
|
||||
const template = selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
return template?.input ?? null;
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
@ -1,20 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputTemplate = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.inputs[fieldName];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputInstance } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldLabel = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldLabel = (nodeId: string, fieldName: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data.inputs[fieldName]?.label;
|
||||
return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null;
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
@ -1,23 +0,0 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldOutputInstance = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data.outputs[fieldName];
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
|
||||
return fieldTemplate;
|
||||
};
|
@ -1,20 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.outputs[fieldName];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
@ -1,21 +1,22 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { KIND_MAP } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
|
||||
export const useFieldTemplate = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
kind: 'inputs' | 'outputs'
|
||||
): FieldInputTemplate | FieldOutputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
@ -1,21 +1,17 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { KIND_MAP } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
|
||||
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title;
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
@ -1,20 +1,18 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { KIND_MAP } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
|
||||
}
|
||||
const field = node.data[KIND_MAP[kind]][fieldName];
|
||||
return field?.type;
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
@ -1,13 +1,12 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
|
||||
const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) =>
|
||||
const selector = createSelector(selectNodesSlice, (nodes) =>
|
||||
nodes.nodes.filter(isInvocationNode).some((node) => {
|
||||
const template = nodeTemplates.templates[node.data.type];
|
||||
const template = nodes.templates[node.data.type];
|
||||
if (!template) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1,24 +1,21 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useHasImageOutput = (nodeId: string) => {
|
||||
export const useHasImageOutput = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
return some(
|
||||
node.data.outputs,
|
||||
template?.outputs,
|
||||
(output) =>
|
||||
output.type.name === 'ImageField' &&
|
||||
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
|
||||
node.data.type !== 'image'
|
||||
template?.type !== 'image'
|
||||
);
|
||||
}),
|
||||
[nodeId]
|
||||
|
@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useIsIntermediate = (nodeId: string) => {
|
||||
export const useIsIntermediate = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.isIntermediate;
|
||||
return selectNodeData(nodes, nodeId)?.isIntermediate ?? false;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,11 +1,10 @@
|
||||
// TODO: enable this at some point
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
@ -13,39 +12,34 @@ import { useReactFlow } from 'reactflow';
|
||||
*/
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
const store = useAppStore();
|
||||
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
|
||||
const isValidConnection = useCallback(
|
||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||
const edges = flow.getEdges();
|
||||
const nodes = flow.getNodes();
|
||||
// Connection must have valid targets
|
||||
if (!(source && sourceHandle && target && targetHandle)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = flow.getNode(source) as Node<InvocationNodeData>;
|
||||
const targetNode = flow.getNode(target) as Node<InvocationNodeData>;
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const sourceField = sourceNode.data.outputs[sourceHandle];
|
||||
const targetField = targetNode.data.inputs[targetHandle];
|
||||
|
||||
if (!sourceField || !targetField) {
|
||||
// something has gone terribly awry
|
||||
return false;
|
||||
}
|
||||
|
||||
if (source === target) {
|
||||
// Don't allow nodes to connect to themselves, even if validation is disabled
|
||||
return false;
|
||||
}
|
||||
|
||||
const state = store.getState();
|
||||
const { nodes, edges, templates } = state.nodes;
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
|
||||
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
|
||||
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceFieldTemplate && targetFieldTemplate)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!shouldValidateGraph) {
|
||||
// manual override!
|
||||
return true;
|
||||
@ -69,20 +63,20 @@ export const useIsValidConnection = () => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetField.type.name !== 'CollectionItemField'
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must use the originalType here if it exists
|
||||
if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
|
||||
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
[flow, shouldValidateGraph]
|
||||
[shouldValidateGraph, store]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
|
@ -1,20 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import type { Classification } from 'features/nodes/types/common';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeClassification = (nodeId: string) => {
|
||||
export const useNodeClassification = (nodeId: string): Classification | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.classification;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,14 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeData = (nodeId: string) => {
|
||||
export const useNodeData = (nodeId: string): InvocationNodeData | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
return node?.data;
|
||||
return selectNodeData(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,19 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeLabel = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return node.data.label;
|
||||
return selectNodeData(nodes, nodeId)?.label ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,21 +1,20 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeNeedsUpdate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const template = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
if (isInvocationNode(node) && template) {
|
||||
return getNeedsUpdate(node, template);
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = selectInvocationNode(nodes, nodeId);
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!node || !template) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
return getNeedsUpdate(node, template);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodePack = (nodeId: string) => {
|
||||
export const useNodePack = (nodeId: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.nodePack;
|
||||
return selectNodeData(nodes, nodeId)?.nodePack ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,16 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplate = (nodeId: string) => {
|
||||
export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,14 +1,14 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateByType = (type: string) => {
|
||||
export const useNodeTemplateByType = (type: string): InvocationTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => {
|
||||
return nodeTemplates.templates[type];
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return nodes.templates[type] ?? null;
|
||||
}),
|
||||
[type]
|
||||
);
|
||||
|
@ -1,21 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string) => {
|
||||
export const useNodeTemplateTitle = (nodeId: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined;
|
||||
|
||||
return nodeTemplate?.title;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId)?.title ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
@ -10,17 +10,13 @@ import { useMemo } from 'react';
|
||||
export const useOutputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return [];
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return getSortedFilteredFieldNames(map(nodeTemplate.outputs));
|
||||
return getSortedFilteredFieldNames(map(template.outputs));
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useUseCache = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.useCache;
|
||||
return selectNodeData(nodes, nodeId)?.useCache ?? false;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import { useEffect } from 'react';
|
||||
|
||||
export const $builtWorkflow = atom<WorkflowV2 | null>(null);
|
||||
export const $builtWorkflow = atom<WorkflowV3 | null>(null);
|
||||
|
||||
const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => {
|
||||
$builtWorkflow.set(buildWorkflowFast(arg));
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { Graph } from 'services/api/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
|
||||
@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{
|
||||
|
||||
export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested');
|
||||
|
||||
export const workflowLoaded = createAction<WorkflowV2>('workflow/workflowLoaded');
|
||||
export const workflowLoaded = createAction<WorkflowV3>('workflow/workflowLoaded');
|
||||
|
@ -1,24 +0,0 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
|
||||
import type { NodeTemplatesState } from './types';
|
||||
|
||||
export const initialNodeTemplatesState: NodeTemplatesState = {
|
||||
templates: {},
|
||||
};
|
||||
|
||||
export const nodesTemplatesSlice = createSlice({
|
||||
name: 'nodeTemplates',
|
||||
initialState: initialNodeTemplatesState,
|
||||
reducers: {
|
||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||
state.templates = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions;
|
||||
|
||||
export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates;
|
@ -42,7 +42,7 @@ import {
|
||||
zT2IAdapterModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import type {
|
||||
@ -92,6 +92,7 @@ export const initialNodesState: NodesState = {
|
||||
_version: 1,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
templates: {},
|
||||
connectionStartParams: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
@ -190,6 +191,7 @@ export const nodesSlice = createSlice({
|
||||
node,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
@ -224,12 +226,12 @@ export const nodesSlice = createSlice({
|
||||
if (!nodeId || !handleId) {
|
||||
return;
|
||||
}
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId];
|
||||
const template = state.templates[node.data.type];
|
||||
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
@ -260,6 +262,7 @@ export const nodesSlice = createSlice({
|
||||
mouseOverNode,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
@ -677,6 +680,9 @@ export const nodesSlice = createSlice({
|
||||
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
|
||||
},
|
||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||
state.templates = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
@ -808,6 +814,7 @@ export const {
|
||||
shouldValidateGraphChanged,
|
||||
viewportChanged,
|
||||
edgeAdded,
|
||||
nodeTemplatesBuilt,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
// This is used for tracking `state.workflow.isTouched`
|
||||
|
51
invokeai/frontend/web/src/features/nodes/store/selectors.ts
Normal file
51
invokeai/frontend/web/src/features/nodes/store/selectors.ts
Normal file
@ -0,0 +1,51 @@
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
return node;
|
||||
};
|
||||
|
||||
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => {
|
||||
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
|
||||
};
|
||||
|
||||
export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
|
||||
const node = selectInvocationNode(nodesSlice, nodeId);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
return nodesSlice.templates[node.data.type] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldInputInstance = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldInputInstance | null => {
|
||||
const data = selectNodeData(nodesSlice, nodeId);
|
||||
return data?.inputs[fieldName] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldInputTemplate = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldInputTemplate | null => {
|
||||
const template = selectNodeTemplate(nodesSlice, nodeId);
|
||||
return template?.inputs[fieldName] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldOutputTemplate = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldOutputTemplate | null => {
|
||||
const template = selectNodeTemplate(nodesSlice, nodeId);
|
||||
return template?.outputs[fieldName] ?? null;
|
||||
};
|
@ -5,13 +5,14 @@ import type {
|
||||
InvocationTemplate,
|
||||
NodeExecutionState,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow';
|
||||
|
||||
export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
edges: InvocationNodeEdge[];
|
||||
templates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
connectionStartFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & {
|
||||
value: StatefulFieldValue;
|
||||
};
|
||||
|
||||
export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & {
|
||||
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
|
||||
_version: 1;
|
||||
isTouched: boolean;
|
||||
mode: WorkflowMode;
|
||||
|
@ -1,4 +1,6 @@
|
||||
import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
@ -9,7 +11,7 @@ const isValidConnection = (
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
node: Node,
|
||||
handle: FieldInputInstance | FieldOutputInstance
|
||||
handle: FieldInputTemplate | FieldOutputTemplate
|
||||
) => {
|
||||
let isValidConnection = true;
|
||||
if (handleCurrentType === 'source') {
|
||||
@ -38,24 +40,31 @@ const isValidConnection = (
|
||||
};
|
||||
|
||||
export const findConnectionToValidHandle = (
|
||||
node: Node,
|
||||
nodes: Node[],
|
||||
edges: Edge[],
|
||||
node: AnyNode,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
templates: Record<string, InvocationTemplate>,
|
||||
handleCurrentNodeId: string,
|
||||
handleCurrentName: string,
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType
|
||||
): Connection | null => {
|
||||
if (node.id === handleCurrentNodeId) {
|
||||
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs;
|
||||
const template = templates[node.data.type];
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
//Prioritize handles whos name matches the node we're coming from
|
||||
if (handles[handleCurrentName]) {
|
||||
const handle = handles[handleCurrentName];
|
||||
const handle = handles[handleCurrentName];
|
||||
|
||||
if (handle) {
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
||||
@ -77,6 +86,9 @@ export const findConnectionToValidHandle = (
|
||||
|
||||
for (const handleName in handles) {
|
||||
const handle = handles[handleName];
|
||||
if (!handle) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
|
@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType
|
||||
fieldType?: FieldType | null
|
||||
) => {
|
||||
return createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
if (!fieldType) {
|
||||
|
@ -10,10 +10,10 @@ import type {
|
||||
} from 'features/nodes/store/types';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es';
|
||||
|
||||
export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = {
|
||||
export const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
|
||||
name: '',
|
||||
author: '',
|
||||
description: '',
|
||||
@ -22,7 +22,7 @@ export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = {
|
||||
tags: '',
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
meta: { version: '2.0.0', category: 'user' },
|
||||
meta: { version: '3.0.0', category: 'user' },
|
||||
id: undefined,
|
||||
};
|
||||
|
||||
|
@ -46,20 +46,11 @@ export type FieldInput = z.infer<typeof zFieldInput>;
|
||||
export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']);
|
||||
export type FieldUIComponent = z.infer<typeof zFieldUIComponent>;
|
||||
|
||||
export const zFieldInstanceBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
export const zFieldInputInstanceBase = z.object({
|
||||
name: z.string().trim().min(1),
|
||||
});
|
||||
export const zFieldInputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('input'),
|
||||
label: z.string().nullish(),
|
||||
});
|
||||
export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
export type FieldInstanceBase = z.infer<typeof zFieldInstanceBase>;
|
||||
export type FieldInputInstanceBase = z.infer<typeof zFieldInputInstanceBase>;
|
||||
export type FieldOutputInstanceBase = z.infer<typeof zFieldOutputInstanceBase>;
|
||||
|
||||
export const zFieldTemplateBase = z.object({
|
||||
name: z.string().min(1),
|
||||
@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zIntegerFieldValue = z.number().int();
|
||||
export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
value: zIntegerFieldValue,
|
||||
});
|
||||
export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
});
|
||||
export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
default: zIntegerFieldValue,
|
||||
@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zFloatFieldValue = z.number();
|
||||
export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
value: zFloatFieldValue,
|
||||
});
|
||||
export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
});
|
||||
export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
default: zFloatFieldValue,
|
||||
@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type FloatFieldType = z.infer<typeof zFloatFieldType>;
|
||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
export type FloatFieldOutputInstance = z.infer<typeof zFloatFieldOutputInstance>;
|
||||
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
|
||||
export type FloatFieldOutputTemplate = z.infer<typeof zFloatFieldOutputTemplate>;
|
||||
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
|
||||
@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zStringFieldValue = z.string();
|
||||
export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
value: zStringFieldValue,
|
||||
});
|
||||
export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
default: zStringFieldValue,
|
||||
@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type StringFieldType = z.infer<typeof zStringFieldType>;
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldOutputInstance = z.infer<typeof zStringFieldOutputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
export type StringFieldOutputTemplate = z.infer<typeof zStringFieldOutputTemplate>;
|
||||
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
|
||||
@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zBooleanFieldValue = z.boolean();
|
||||
export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
value: zBooleanFieldValue,
|
||||
});
|
||||
export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
});
|
||||
export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
default: zBooleanFieldValue,
|
||||
@ -222,7 +195,6 @@ export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type BooleanFieldType = z.infer<typeof zBooleanFieldType>;
|
||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||
export type BooleanFieldOutputInstance = z.infer<typeof zBooleanFieldOutputInstance>;
|
||||
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
|
||||
export type BooleanFieldOutputTemplate = z.infer<typeof zBooleanFieldOutputTemplate>;
|
||||
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
|
||||
@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zEnumFieldValue = z.string();
|
||||
export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
value: zEnumFieldValue,
|
||||
});
|
||||
export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
});
|
||||
export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
default: zEnumFieldValue,
|
||||
@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type EnumFieldType = z.infer<typeof zEnumFieldType>;
|
||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||
export type EnumFieldOutputInstance = z.infer<typeof zEnumFieldOutputInstance>;
|
||||
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
|
||||
export type EnumFieldOutputTemplate = z.infer<typeof zEnumFieldOutputTemplate>;
|
||||
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
|
||||
@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
value: zImageFieldValue,
|
||||
});
|
||||
export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
});
|
||||
export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
default: zImageFieldValue,
|
||||
@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ImageFieldType = z.infer<typeof zImageFieldType>;
|
||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
export type ImageFieldOutputInstance = z.infer<typeof zImageFieldOutputInstance>;
|
||||
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
|
||||
export type ImageFieldOutputTemplate = z.infer<typeof zImageFieldOutputTemplate>;
|
||||
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
|
||||
@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
value: zBoardFieldValue,
|
||||
});
|
||||
export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
});
|
||||
export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
default: zBoardFieldValue,
|
||||
@ -317,7 +275,6 @@ export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type BoardFieldType = z.infer<typeof zBoardFieldType>;
|
||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||
export type BoardFieldOutputInstance = z.infer<typeof zBoardFieldOutputInstance>;
|
||||
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
|
||||
export type BoardFieldOutputTemplate = z.infer<typeof zBoardFieldOutputTemplate>;
|
||||
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
|
||||
@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
value: zColorFieldValue,
|
||||
});
|
||||
export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
});
|
||||
export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
default: zColorFieldValue,
|
||||
@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ColorFieldType = z.infer<typeof zColorFieldType>;
|
||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||
export type ColorFieldOutputInstance = z.infer<typeof zColorFieldOutputInstance>;
|
||||
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
|
||||
export type ColorFieldOutputTemplate = z.infer<typeof zColorFieldOutputTemplate>;
|
||||
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
|
||||
@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zMainModelFieldValue = zMainModelField.optional();
|
||||
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
default: zMainModelFieldValue,
|
||||
@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type MainModelFieldType = z.infer<typeof zMainModelFieldType>;
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
export type MainModelFieldOutputInstance = z.infer<typeof zMainModelFieldOutputInstance>;
|
||||
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
|
||||
export type MainModelFieldOutputTemplate = z.infer<typeof zMainModelFieldOutputTemplate>;
|
||||
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
|
||||
@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
default: zSDXLMainModelFieldValue,
|
||||
@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend
|
||||
export type SDXLMainModelFieldType = z.infer<typeof zSDXLMainModelFieldType>;
|
||||
export type SDXLMainModelFieldValue = z.infer<typeof zSDXLMainModelFieldValue>;
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldOutputInstance = z.infer<typeof zSDXLMainModelFieldOutputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
export type SDXLMainModelFieldOutputTemplate = z.infer<typeof zSDXLMainModelFieldOutputTemplate>;
|
||||
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
|
||||
@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
value: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext
|
||||
export type SDXLRefinerModelFieldType = z.infer<typeof zSDXLRefinerModelFieldType>;
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
export type SDXLRefinerModelFieldOutputInstance = z.infer<typeof zSDXLRefinerModelFieldOutputInstance>;
|
||||
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
|
||||
export type SDXLRefinerModelFieldOutputTemplate = z.infer<typeof zSDXLRefinerModelFieldOutputTemplate>;
|
||||
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
|
||||
@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zVAEModelFieldValue = zVAEModelField.optional();
|
||||
export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
default: zVAEModelFieldValue,
|
||||
@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type VAEModelFieldType = z.infer<typeof zVAEModelFieldType>;
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
export type VAEModelFieldOutputInstance = z.infer<typeof zVAEModelFieldOutputInstance>;
|
||||
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
|
||||
export type VAEModelFieldOutputTemplate = z.infer<typeof zVAEModelFieldOutputTemplate>;
|
||||
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
|
||||
@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zLoRAModelFieldValue = zLoRAModelField.optional();
|
||||
export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
default: zLoRAModelFieldValue,
|
||||
@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type LoRAModelFieldType = z.infer<typeof zLoRAModelFieldType>;
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
export type LoRAModelFieldOutputInstance = z.infer<typeof zLoRAModelFieldOutputInstance>;
|
||||
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
|
||||
export type LoRAModelFieldOutputTemplate = z.infer<typeof zLoRAModelFieldOutputTemplate>;
|
||||
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
|
||||
@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zControlNetModelFieldValue = zControlNetModelField.optional();
|
||||
export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
default: zControlNetModelFieldValue,
|
||||
@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte
|
||||
export type ControlNetModelFieldType = z.infer<typeof zControlNetModelFieldType>;
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
export type ControlNetModelFieldOutputInstance = z.infer<typeof zControlNetModelFieldOutputInstance>;
|
||||
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
|
||||
export type ControlNetModelFieldOutputTemplate = z.infer<typeof zControlNetModelFieldOutputTemplate>;
|
||||
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
|
||||
@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional();
|
||||
export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
default: zIPAdapterModelFieldValue,
|
||||
@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten
|
||||
export type IPAdapterModelFieldType = z.infer<typeof zIPAdapterModelFieldType>;
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
export type IPAdapterModelFieldOutputInstance = z.infer<typeof zIPAdapterModelFieldOutputInstance>;
|
||||
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
|
||||
export type IPAdapterModelFieldOutputTemplate = z.infer<typeof zIPAdapterModelFieldOutputTemplate>;
|
||||
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
|
||||
@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional();
|
||||
export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exte
|
||||
export type T2IAdapterModelFieldType = z.infer<typeof zT2IAdapterModelFieldType>;
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
export type T2IAdapterModelFieldOutputInstance = z.infer<typeof zT2IAdapterModelFieldOutputInstance>;
|
||||
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
|
||||
export type T2IAdapterModelFieldOutputTemplate = z.infer<typeof zT2IAdapterModelFieldOutputTemplate>;
|
||||
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
|
||||
@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
value: zSchedulerFieldValue,
|
||||
});
|
||||
export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
});
|
||||
export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
default: zSchedulerFieldValue,
|
||||
@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
|
||||
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
|
||||
export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
|
||||
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
|
||||
@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
value: zStatelessFieldValue,
|
||||
});
|
||||
export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
});
|
||||
export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
default: zStatelessFieldValue,
|
||||
@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
|
||||
export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
|
||||
export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
|
||||
export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
|
||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||
export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
|
||||
// #endregion
|
||||
@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
|
||||
zFieldInputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputInstance & FieldOutputInstance
|
||||
export const zStatefulFieldOutputInstance = z.union([
|
||||
zIntegerFieldOutputInstance,
|
||||
zFloatFieldOutputInstance,
|
||||
zStringFieldOutputInstance,
|
||||
zBooleanFieldOutputInstance,
|
||||
zEnumFieldOutputInstance,
|
||||
zImageFieldOutputInstance,
|
||||
zBoardFieldOutputInstance,
|
||||
zMainModelFieldOutputInstance,
|
||||
zSDXLMainModelFieldOutputInstance,
|
||||
zSDXLRefinerModelFieldOutputInstance,
|
||||
zVAEModelFieldOutputInstance,
|
||||
zLoRAModelFieldOutputInstance,
|
||||
zControlNetModelFieldOutputInstance,
|
||||
zIPAdapterModelFieldOutputInstance,
|
||||
zT2IAdapterModelFieldOutputInstance,
|
||||
zColorFieldOutputInstance,
|
||||
zSchedulerFieldOutputInstance,
|
||||
]);
|
||||
export type StatefulFieldOutputInstance = z.infer<typeof zStatefulFieldOutputInstance>;
|
||||
export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance =>
|
||||
zStatefulFieldOutputInstance.safeParse(val).success;
|
||||
|
||||
export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]);
|
||||
export type FieldOutputInstance = z.infer<typeof zFieldOutputInstance>;
|
||||
export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance =>
|
||||
zFieldOutputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputTemplate & FieldInputTemplate
|
||||
export const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
|
@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zClassification, zProgressImage } from './common';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
// #region InvocationTemplate
|
||||
@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
||||
// #region NodeData
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.string().trim().min(1),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
version: zSemVer,
|
||||
nodePack: z.string().min(1).nullish(),
|
||||
label: z.string(),
|
||||
notes: z.string(),
|
||||
type: z.string().trim().min(1),
|
||||
inputs: z.record(zFieldInputInstance),
|
||||
outputs: z.record(zFieldOutputInstance),
|
||||
isOpen: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
});
|
||||
|
||||
export const zNotesNodeData = z.object({
|
||||
@ -62,11 +61,12 @@ export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type AnyNode = Node<AnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode =>
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode =>
|
||||
Boolean(node && node.type === 'current_image');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData =>
|
||||
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
// #endregion
|
||||
|
||||
|
188
invokeai/frontend/web/src/features/nodes/types/v2/common.ts
Normal file
188
invokeai/frontend/web/src/features/nodes/types/v2/common.ts
Normal file
@ -0,0 +1,188 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
// #region Field data schemas
|
||||
export const zImageField = z.object({
|
||||
image_name: z.string().trim().min(1),
|
||||
});
|
||||
export type ImageField = z.infer<typeof zImageField>;
|
||||
|
||||
export const zBoardField = z.object({
|
||||
board_id: z.string().trim().min(1),
|
||||
});
|
||||
export type BoardField = z.infer<typeof zBoardField>;
|
||||
|
||||
export const zColorField = z.object({
|
||||
r: z.number().int().min(0).max(255),
|
||||
g: z.number().int().min(0).max(255),
|
||||
b: z.number().int().min(0).max(255),
|
||||
a: z.number().int().min(0).max(255),
|
||||
});
|
||||
export type ColorField = z.infer<typeof zColorField>;
|
||||
|
||||
export const zClassification = z.enum(['stable', 'beta', 'prototype']);
|
||||
export type Classification = z.infer<typeof zClassification>;
|
||||
|
||||
export const zSchedulerField = z.enum([
|
||||
'euler',
|
||||
'deis',
|
||||
'ddim',
|
||||
'ddpm',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_sde',
|
||||
'dpmpp_sde',
|
||||
'heun',
|
||||
'kdpm_2',
|
||||
'lms',
|
||||
'pndm',
|
||||
'unipc',
|
||||
'euler_k',
|
||||
'dpmpp_2s_k',
|
||||
'dpmpp_2m_k',
|
||||
'dpmpp_2m_sde_k',
|
||||
'dpmpp_sde_k',
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
]);
|
||||
export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']);
|
||||
export const zModelName = z.string().min(3);
|
||||
export const zModelIdentifier = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
export type ModelType = z.infer<typeof zModelType>;
|
||||
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
|
||||
|
||||
export const zMainModelField = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export const zSDXLRefinerModelField = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: z.literal('sdxl-refiner'),
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export type MainModelField = z.infer<typeof zMainModelField>;
|
||||
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
|
||||
|
||||
export const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
'scheduler',
|
||||
'safety_checker',
|
||||
]);
|
||||
export type SubModelType = z.infer<typeof zSubModelType>;
|
||||
|
||||
export const zVAEModelField = zModelIdentifier;
|
||||
|
||||
export const zModelInfo = zModelIdentifier.extend({
|
||||
model_type: zModelType,
|
||||
submodel: zSubModelType.optional(),
|
||||
});
|
||||
export type ModelInfo = z.infer<typeof zModelInfo>;
|
||||
|
||||
export const zLoRAModelField = zModelIdentifier;
|
||||
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
|
||||
|
||||
export const zControlNetModelField = zModelIdentifier;
|
||||
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
|
||||
|
||||
export const zIPAdapterModelField = zModelIdentifier;
|
||||
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
|
||||
|
||||
export const zT2IAdapterModelField = zModelIdentifier;
|
||||
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
|
||||
|
||||
export const zLoraInfo = zModelInfo.extend({
|
||||
weight: z.number().optional(),
|
||||
});
|
||||
export type LoraInfo = z.infer<typeof zLoraInfo>;
|
||||
|
||||
export const zUNetField = z.object({
|
||||
unet: zModelInfo,
|
||||
scheduler: zModelInfo,
|
||||
loras: z.array(zLoraInfo),
|
||||
});
|
||||
export type UNetField = z.infer<typeof zUNetField>;
|
||||
|
||||
export const zCLIPField = z.object({
|
||||
tokenizer: zModelInfo,
|
||||
text_encoder: zModelInfo,
|
||||
skipped_layers: z.number(),
|
||||
loras: z.array(zLoraInfo),
|
||||
});
|
||||
export type CLIPField = z.infer<typeof zCLIPField>;
|
||||
|
||||
export const zVAEField = z.object({
|
||||
vae: zModelInfo,
|
||||
});
|
||||
export type VAEField = z.infer<typeof zVAEField>;
|
||||
// #endregion
|
||||
|
||||
// #region Control Adapters
|
||||
export const zControlField = z.object({
|
||||
image: zImageField,
|
||||
control_model: zControlNetModelField,
|
||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(),
|
||||
resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(),
|
||||
});
|
||||
export type ControlField = z.infer<typeof zControlField>;
|
||||
|
||||
export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: zIPAdapterModelField,
|
||||
weight: z.number(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
});
|
||||
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
||||
|
||||
export const zT2IAdapterField = z.object({
|
||||
image: zImageField,
|
||||
t2i_adapter_model: zT2IAdapterModelField,
|
||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(),
|
||||
});
|
||||
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
|
||||
// #endregion
|
||||
|
||||
// #region ProgressImage
|
||||
export const zProgressImage = z.object({
|
||||
dataURL: z.string(),
|
||||
width: z.number().int(),
|
||||
height: z.number().int(),
|
||||
});
|
||||
export type ProgressImage = z.infer<typeof zProgressImage>;
|
||||
// #endregion
|
||||
|
||||
// #region ImageOutput
|
||||
export const zImageOutput = z.object({
|
||||
image: zImageField,
|
||||
width: z.number().int().gt(0),
|
||||
height: z.number().int().gt(0),
|
||||
type: z.literal('image_output'),
|
||||
});
|
||||
export type ImageOutput = z.infer<typeof zImageOutput>;
|
||||
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
|
||||
// #endregion
|
@ -0,0 +1,80 @@
|
||||
import type { Node } from 'reactflow';
|
||||
|
||||
/**
|
||||
* How long to wait before showing a tooltip when hovering a field handle.
|
||||
*/
|
||||
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
|
||||
|
||||
/**
|
||||
* The width of a node in the UI in pixels.
|
||||
*/
|
||||
export const NODE_WIDTH = 320;
|
||||
|
||||
/**
|
||||
* This class name is special - reactflow uses it to identify the drag handle of a node,
|
||||
* applying the appropriate listeners to it.
|
||||
*/
|
||||
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
|
||||
|
||||
/**
|
||||
* reactflow-specifc properties shared between all node types.
|
||||
*/
|
||||
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper for getting the kind of a field.
|
||||
*/
|
||||
export const KIND_MAP = {
|
||||
input: 'inputs' as const,
|
||||
output: 'outputs' as const,
|
||||
};
|
||||
|
||||
/**
|
||||
* Model types' handles are rendered as squares in the UI.
|
||||
*/
|
||||
export const MODEL_TYPES = [
|
||||
'IPAdapterModelField',
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'ClipField',
|
||||
'T2IAdapterModelField',
|
||||
'IPAdapterModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
* Colors for each field type - applies to their handles and edges.
|
||||
*/
|
||||
export const FIELD_COLORS: { [key: string]: string } = {
|
||||
BoardField: 'purple.500',
|
||||
BooleanField: 'green.500',
|
||||
ClipField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
ControlField: 'teal.500',
|
||||
ControlNetModelField: 'teal.500',
|
||||
EnumField: 'blue.500',
|
||||
FloatField: 'orange.500',
|
||||
ImageField: 'purple.500',
|
||||
IntegerField: 'red.500',
|
||||
IPAdapterField: 'teal.500',
|
||||
IPAdapterModelField: 'teal.500',
|
||||
LatentsField: 'pink.500',
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
UNetField: 'red.500',
|
||||
VaeField: 'blue.500',
|
||||
VaeModelField: 'teal.500',
|
||||
};
|
58
invokeai/frontend/web/src/features/nodes/types/v2/error.ts
Normal file
58
invokeai/frontend/web/src/features/nodes/types/v2/error.ts
Normal file
@ -0,0 +1,58 @@
|
||||
/**
|
||||
* Invalid Workflow Version Error
|
||||
* Raised when a workflow version is not recognized.
|
||||
*/
|
||||
export class WorkflowVersionError extends Error {
|
||||
/**
|
||||
* Create WorkflowVersionError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Workflow Migration Error
|
||||
* Raised when a workflow migration fails.
|
||||
*/
|
||||
export class WorkflowMigrationError extends Error {
|
||||
/**
|
||||
* Create WorkflowMigrationError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unable to Update Node Error
|
||||
* Raised when a node cannot be updated.
|
||||
*/
|
||||
export class NodeUpdateError extends Error {
|
||||
/**
|
||||
* Create NodeUpdateError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldParseError
|
||||
* Raised when a field cannot be parsed from a field schema.
|
||||
*/
|
||||
export class FieldParseError extends Error {
|
||||
/**
|
||||
* Create FieldTypeParseError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
875
invokeai/frontend/web/src/features/nodes/types/v2/field.ts
Normal file
875
invokeai/frontend/web/src/features/nodes/types/v2/field.ts
Normal file
@ -0,0 +1,875 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
zBoardField,
|
||||
zColorField,
|
||||
zControlNetModelField,
|
||||
zImageField,
|
||||
zIPAdapterModelField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zSchedulerField,
|
||||
zT2IAdapterModelField,
|
||||
zVAEModelField,
|
||||
} from './common';
|
||||
|
||||
/**
|
||||
* zod schemas & inferred types for fields.
|
||||
*
|
||||
* These schemas and types are only required for stateful field - fields that have UI components
|
||||
* and allow the user to directly provide values.
|
||||
*
|
||||
* This includes primitive values (numbers, strings, booleans), models, scheduler, etc.
|
||||
*
|
||||
* If a field type does not have a UI component, then it does not need to be included here, because
|
||||
* we never store its value. Such field types will be handled via the "StatelessField" logic.
|
||||
*
|
||||
* Fields require:
|
||||
* - z<TypeName>FieldType - zod schema for the field type
|
||||
* - z<TypeName>FieldValue - zod schema for the field value
|
||||
* - z<TypeName>FieldInputInstance - zod schema for the field's input instance
|
||||
* - z<TypeName>FieldOutputInstance - zod schema for the field's output instance
|
||||
* - z<TypeName>FieldInputTemplate - zod schema for the field's input template
|
||||
* - z<TypeName>FieldOutputTemplate - zod schema for the field's output template
|
||||
* - inferred types for each schema
|
||||
* - type guards for InputInstance and InputTemplate
|
||||
*
|
||||
* These then must be added to the unions at the bottom of this file.
|
||||
*/
|
||||
|
||||
/** */
|
||||
|
||||
// #region Base schemas & misc
|
||||
export const zFieldInput = z.enum(['connection', 'direct', 'any']);
|
||||
export type FieldInput = z.infer<typeof zFieldInput>;
|
||||
|
||||
export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']);
|
||||
export type FieldUIComponent = z.infer<typeof zFieldUIComponent>;
|
||||
|
||||
export const zFieldInstanceBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
name: z.string().trim().min(1),
|
||||
});
|
||||
export const zFieldInputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('input'),
|
||||
label: z.string().nullish(),
|
||||
});
|
||||
export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
export type FieldInstanceBase = z.infer<typeof zFieldInstanceBase>;
|
||||
export type FieldInputInstanceBase = z.infer<typeof zFieldInputInstanceBase>;
|
||||
export type FieldOutputInstanceBase = z.infer<typeof zFieldOutputInstanceBase>;
|
||||
|
||||
export const zFieldTemplateBase = z.object({
|
||||
name: z.string().min(1),
|
||||
title: z.string().min(1),
|
||||
description: z.string().nullish(),
|
||||
ui_hidden: z.boolean(),
|
||||
ui_type: z.string().nullish(),
|
||||
ui_order: z.number().int().nullish(),
|
||||
});
|
||||
export const zFieldInputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('input'),
|
||||
input: zFieldInput,
|
||||
required: z.boolean(),
|
||||
ui_component: zFieldUIComponent.nullish(),
|
||||
ui_choice_labels: z.record(z.string()).nullish(),
|
||||
});
|
||||
export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
export type FieldTemplateBase = z.infer<typeof zFieldTemplateBase>;
|
||||
export type FieldInputTemplateBase = z.infer<typeof zFieldInputTemplateBase>;
|
||||
export type FieldOutputTemplateBase = z.infer<typeof zFieldOutputTemplateBase>;
|
||||
|
||||
export const zFieldTypeBase = z.object({
|
||||
isCollection: z.boolean(),
|
||||
isCollectionOrScalar: z.boolean(),
|
||||
});
|
||||
|
||||
export const zFieldIdentifier = z.object({
|
||||
nodeId: z.string().trim().min(1),
|
||||
fieldName: z.string().trim().min(1),
|
||||
});
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField
|
||||
export const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IntegerField'),
|
||||
});
|
||||
export const zIntegerFieldValue = z.number().int();
|
||||
export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
value: zIntegerFieldValue,
|
||||
});
|
||||
export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
});
|
||||
export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
default: zIntegerFieldValue,
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().int().optional(),
|
||||
exclusiveMaximum: z.number().int().optional(),
|
||||
minimum: z.number().int().optional(),
|
||||
exclusiveMinimum: z.number().int().optional(),
|
||||
});
|
||||
export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
});
|
||||
export type IntegerFieldType = z.infer<typeof zIntegerFieldType>;
|
||||
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
||||
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
||||
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
|
||||
export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance =>
|
||||
zIntegerFieldInputInstance.safeParse(val).success;
|
||||
export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate =>
|
||||
zIntegerFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region FloatField
|
||||
export const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
});
|
||||
export const zFloatFieldValue = z.number();
|
||||
export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
value: zFloatFieldValue,
|
||||
});
|
||||
export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
});
|
||||
export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
default: zFloatFieldValue,
|
||||
multipleOf: z.number().optional(),
|
||||
maximum: z.number().optional(),
|
||||
exclusiveMaximum: z.number().optional(),
|
||||
minimum: z.number().optional(),
|
||||
exclusiveMinimum: z.number().optional(),
|
||||
});
|
||||
export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
});
|
||||
export type FloatFieldType = z.infer<typeof zFloatFieldType>;
|
||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
export type FloatFieldOutputInstance = z.infer<typeof zFloatFieldOutputInstance>;
|
||||
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
|
||||
export type FloatFieldOutputTemplate = z.infer<typeof zFloatFieldOutputTemplate>;
|
||||
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
|
||||
zFloatFieldInputInstance.safeParse(val).success;
|
||||
export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate =>
|
||||
zFloatFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StringField
|
||||
export const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
});
|
||||
export const zStringFieldValue = z.string();
|
||||
export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
value: zStringFieldValue,
|
||||
});
|
||||
export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
default: zStringFieldValue,
|
||||
maxLength: z.number().int().optional(),
|
||||
minLength: z.number().int().optional(),
|
||||
});
|
||||
export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
|
||||
export type StringFieldType = z.infer<typeof zStringFieldType>;
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldOutputInstance = z.infer<typeof zStringFieldOutputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
export type StringFieldOutputTemplate = z.infer<typeof zStringFieldOutputTemplate>;
|
||||
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
|
||||
zStringFieldInputInstance.safeParse(val).success;
|
||||
export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate =>
|
||||
zStringFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region BooleanField
|
||||
export const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
});
|
||||
export const zBooleanFieldValue = z.boolean();
|
||||
export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
value: zBooleanFieldValue,
|
||||
});
|
||||
export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
});
|
||||
export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
default: zBooleanFieldValue,
|
||||
});
|
||||
export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
});
|
||||
export type BooleanFieldType = z.infer<typeof zBooleanFieldType>;
|
||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||
export type BooleanFieldOutputInstance = z.infer<typeof zBooleanFieldOutputInstance>;
|
||||
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
|
||||
export type BooleanFieldOutputTemplate = z.infer<typeof zBooleanFieldOutputTemplate>;
|
||||
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
|
||||
zBooleanFieldInputInstance.safeParse(val).success;
|
||||
export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate =>
|
||||
zBooleanFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region EnumField
|
||||
export const zEnumFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('EnumField'),
|
||||
});
|
||||
export const zEnumFieldValue = z.string();
|
||||
export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
value: zEnumFieldValue,
|
||||
});
|
||||
export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
});
|
||||
export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
default: zEnumFieldValue,
|
||||
options: z.array(z.string()),
|
||||
labels: z.record(z.string()).optional(),
|
||||
});
|
||||
export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
});
|
||||
export type EnumFieldType = z.infer<typeof zEnumFieldType>;
|
||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||
export type EnumFieldOutputInstance = z.infer<typeof zEnumFieldOutputInstance>;
|
||||
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
|
||||
export type EnumFieldOutputTemplate = z.infer<typeof zEnumFieldOutputTemplate>;
|
||||
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
|
||||
zEnumFieldInputInstance.safeParse(val).success;
|
||||
export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate =>
|
||||
zEnumFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ImageField
|
||||
export const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
});
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
value: zImageFieldValue,
|
||||
});
|
||||
export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
});
|
||||
export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
default: zImageFieldValue,
|
||||
});
|
||||
export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
});
|
||||
export type ImageFieldType = z.infer<typeof zImageFieldType>;
|
||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
export type ImageFieldOutputInstance = z.infer<typeof zImageFieldOutputInstance>;
|
||||
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
|
||||
export type ImageFieldOutputTemplate = z.infer<typeof zImageFieldOutputTemplate>;
|
||||
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
|
||||
zImageFieldInputInstance.safeParse(val).success;
|
||||
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
|
||||
zImageFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
export const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
});
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
value: zBoardFieldValue,
|
||||
});
|
||||
export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
});
|
||||
export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
default: zBoardFieldValue,
|
||||
});
|
||||
export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
});
|
||||
export type BoardFieldType = z.infer<typeof zBoardFieldType>;
|
||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||
export type BoardFieldOutputInstance = z.infer<typeof zBoardFieldOutputInstance>;
|
||||
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
|
||||
export type BoardFieldOutputTemplate = z.infer<typeof zBoardFieldOutputTemplate>;
|
||||
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
|
||||
zBoardFieldInputInstance.safeParse(val).success;
|
||||
export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate =>
|
||||
zBoardFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ColorField
|
||||
export const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
});
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
value: zColorFieldValue,
|
||||
});
|
||||
export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
});
|
||||
export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
default: zColorFieldValue,
|
||||
});
|
||||
export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
});
|
||||
export type ColorFieldType = z.infer<typeof zColorFieldType>;
|
||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||
export type ColorFieldOutputInstance = z.infer<typeof zColorFieldOutputInstance>;
|
||||
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
|
||||
export type ColorFieldOutputTemplate = z.infer<typeof zColorFieldOutputTemplate>;
|
||||
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
|
||||
zColorFieldInputInstance.safeParse(val).success;
|
||||
export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate =>
|
||||
zColorFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region MainModelField
|
||||
export const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
});
|
||||
export const zMainModelFieldValue = zMainModelField.optional();
|
||||
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
default: zMainModelFieldValue,
|
||||
});
|
||||
export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export type MainModelFieldType = z.infer<typeof zMainModelFieldType>;
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
export type MainModelFieldOutputInstance = z.infer<typeof zMainModelFieldOutputInstance>;
|
||||
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
|
||||
export type MainModelFieldOutputTemplate = z.infer<typeof zMainModelFieldOutputTemplate>;
|
||||
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
|
||||
zMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate =>
|
||||
zMainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
export const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
});
|
||||
export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
default: zSDXLMainModelFieldValue,
|
||||
});
|
||||
export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export type SDXLMainModelFieldType = z.infer<typeof zSDXLMainModelFieldType>;
|
||||
export type SDXLMainModelFieldValue = z.infer<typeof zSDXLMainModelFieldValue>;
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldOutputInstance = z.infer<typeof zSDXLMainModelFieldOutputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
export type SDXLMainModelFieldOutputTemplate = z.infer<typeof zSDXLMainModelFieldOutputTemplate>;
|
||||
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
|
||||
zSDXLMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate =>
|
||||
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
});
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
value: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export type SDXLRefinerModelFieldType = z.infer<typeof zSDXLRefinerModelFieldType>;
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
export type SDXLRefinerModelFieldOutputInstance = z.infer<typeof zSDXLRefinerModelFieldOutputInstance>;
|
||||
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
|
||||
export type SDXLRefinerModelFieldOutputTemplate = z.infer<typeof zSDXLRefinerModelFieldOutputTemplate>;
|
||||
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
|
||||
zSDXLRefinerModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate =>
|
||||
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
export const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
});
|
||||
export const zVAEModelFieldValue = zVAEModelField.optional();
|
||||
export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
default: zVAEModelFieldValue,
|
||||
});
|
||||
export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export type VAEModelFieldType = z.infer<typeof zVAEModelFieldType>;
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
export type VAEModelFieldOutputInstance = z.infer<typeof zVAEModelFieldOutputInstance>;
|
||||
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
|
||||
export type VAEModelFieldOutputTemplate = z.infer<typeof zVAEModelFieldOutputTemplate>;
|
||||
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
|
||||
zVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate =>
|
||||
zVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region LoRAModelField
|
||||
export const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
});
|
||||
export const zLoRAModelFieldValue = zLoRAModelField.optional();
|
||||
export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
default: zLoRAModelFieldValue,
|
||||
});
|
||||
export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export type LoRAModelFieldType = z.infer<typeof zLoRAModelFieldType>;
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
export type LoRAModelFieldOutputInstance = z.infer<typeof zLoRAModelFieldOutputInstance>;
|
||||
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
|
||||
export type LoRAModelFieldOutputTemplate = z.infer<typeof zLoRAModelFieldOutputTemplate>;
|
||||
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
|
||||
zLoRAModelFieldInputInstance.safeParse(val).success;
|
||||
export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate =>
|
||||
zLoRAModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
export const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
});
|
||||
export const zControlNetModelFieldValue = zControlNetModelField.optional();
|
||||
export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
default: zControlNetModelFieldValue,
|
||||
});
|
||||
export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export type ControlNetModelFieldType = z.infer<typeof zControlNetModelFieldType>;
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
export type ControlNetModelFieldOutputInstance = z.infer<typeof zControlNetModelFieldOutputInstance>;
|
||||
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
|
||||
export type ControlNetModelFieldOutputTemplate = z.infer<typeof zControlNetModelFieldOutputTemplate>;
|
||||
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
|
||||
zControlNetModelFieldInputInstance.safeParse(val).success;
|
||||
export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate =>
|
||||
zControlNetModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue =>
|
||||
zControlNetModelFieldValue.safeParse(v).success;
|
||||
// #endregion
|
||||
|
||||
// #region IPAdapterModelField
|
||||
export const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
});
|
||||
export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional();
|
||||
export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
default: zIPAdapterModelFieldValue,
|
||||
});
|
||||
export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export type IPAdapterModelFieldType = z.infer<typeof zIPAdapterModelFieldType>;
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
export type IPAdapterModelFieldOutputInstance = z.infer<typeof zIPAdapterModelFieldOutputInstance>;
|
||||
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
|
||||
export type IPAdapterModelFieldOutputTemplate = z.infer<typeof zIPAdapterModelFieldOutputTemplate>;
|
||||
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
|
||||
zIPAdapterModelFieldInputInstance.safeParse(val).success;
|
||||
export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate =>
|
||||
zIPAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue =>
|
||||
zIPAdapterModelFieldValue.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region T2IAdapterField
|
||||
export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
});
|
||||
export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional();
|
||||
export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export type T2IAdapterModelFieldType = z.infer<typeof zT2IAdapterModelFieldType>;
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
export type T2IAdapterModelFieldOutputInstance = z.infer<typeof zT2IAdapterModelFieldOutputInstance>;
|
||||
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
|
||||
export type T2IAdapterModelFieldOutputTemplate = z.infer<typeof zT2IAdapterModelFieldOutputTemplate>;
|
||||
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
|
||||
zT2IAdapterModelFieldInputInstance.safeParse(val).success;
|
||||
export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate =>
|
||||
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
export const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
});
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
value: zSchedulerFieldValue,
|
||||
});
|
||||
export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
});
|
||||
export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
default: zSchedulerFieldValue,
|
||||
});
|
||||
export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
});
|
||||
export type SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
|
||||
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
|
||||
export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
|
||||
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
|
||||
zSchedulerFieldInputInstance.safeParse(val).success;
|
||||
export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
|
||||
zSchedulerFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatelessField
|
||||
/**
|
||||
* StatelessField is a catchall for stateless fields with no UI input components. They do not
|
||||
* do not support "direct" input, instead only accepting connections from other fields.
|
||||
*
|
||||
* This field type serves as a "generic" field type.
|
||||
*
|
||||
* Examples include:
|
||||
* - Fields like UNetField or LatentsField where we do not allow direct UI input
|
||||
* - Reserved fields like IsIntermediate
|
||||
* - Any other field we don't have full-on schemas for
|
||||
*/
|
||||
export const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
value: zStatelessFieldValue,
|
||||
});
|
||||
export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
});
|
||||
export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
default: zStatelessFieldValue,
|
||||
input: z.literal('connection'), // stateless --> only accepts connection inputs
|
||||
});
|
||||
export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
});
|
||||
|
||||
export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
|
||||
export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
|
||||
export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
|
||||
export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
|
||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||
export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
|
||||
// #endregion
|
||||
|
||||
/**
|
||||
* Here we define the main field unions:
|
||||
* - FieldType
|
||||
* - FieldValue
|
||||
* - FieldInputInstance
|
||||
* - FieldOutputInstance
|
||||
* - FieldInputTemplate
|
||||
* - FieldOutputTemplate
|
||||
*
|
||||
* All stateful fields are unioned together, and then that union is unioned with StatelessField.
|
||||
*
|
||||
* This allows us to interact with stateful fields without needing to worry about "generic" handling
|
||||
* for all other StatelessFields.
|
||||
*/
|
||||
|
||||
// #region StatefulFieldType & FieldType
|
||||
export const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
export const isStatefulFieldType = (val: unknown): val is StatefulFieldType =>
|
||||
zStatefulFieldType.safeParse(val).success;
|
||||
|
||||
export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldValue & FieldValue
|
||||
export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
zFloatFieldValue,
|
||||
zStringFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zEnumFieldValue,
|
||||
zImageFieldValue,
|
||||
zBoardFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
export type StatefulFieldValue = z.infer<typeof zStatefulFieldValue>;
|
||||
export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue =>
|
||||
zStatefulFieldValue.safeParse(val).success;
|
||||
|
||||
export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]);
|
||||
export type FieldValue = z.infer<typeof zFieldValue>;
|
||||
export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputInstance & FieldInputInstance
|
||||
export const zStatefulFieldInputInstance = z.union([
|
||||
zIntegerFieldInputInstance,
|
||||
zFloatFieldInputInstance,
|
||||
zStringFieldInputInstance,
|
||||
zBooleanFieldInputInstance,
|
||||
zEnumFieldInputInstance,
|
||||
zImageFieldInputInstance,
|
||||
zBoardFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
export type StatefulFieldInputInstance = z.infer<typeof zStatefulFieldInputInstance>;
|
||||
export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance =>
|
||||
zStatefulFieldInputInstance.safeParse(val).success;
|
||||
|
||||
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
|
||||
export type FieldInputInstance = z.infer<typeof zFieldInputInstance>;
|
||||
export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
|
||||
zFieldInputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputInstance & FieldOutputInstance
|
||||
export const zStatefulFieldOutputInstance = z.union([
|
||||
zIntegerFieldOutputInstance,
|
||||
zFloatFieldOutputInstance,
|
||||
zStringFieldOutputInstance,
|
||||
zBooleanFieldOutputInstance,
|
||||
zEnumFieldOutputInstance,
|
||||
zImageFieldOutputInstance,
|
||||
zBoardFieldOutputInstance,
|
||||
zMainModelFieldOutputInstance,
|
||||
zSDXLMainModelFieldOutputInstance,
|
||||
zSDXLRefinerModelFieldOutputInstance,
|
||||
zVAEModelFieldOutputInstance,
|
||||
zLoRAModelFieldOutputInstance,
|
||||
zControlNetModelFieldOutputInstance,
|
||||
zIPAdapterModelFieldOutputInstance,
|
||||
zT2IAdapterModelFieldOutputInstance,
|
||||
zColorFieldOutputInstance,
|
||||
zSchedulerFieldOutputInstance,
|
||||
]);
|
||||
export type StatefulFieldOutputInstance = z.infer<typeof zStatefulFieldOutputInstance>;
|
||||
export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance =>
|
||||
zStatefulFieldOutputInstance.safeParse(val).success;
|
||||
|
||||
export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]);
|
||||
export type FieldOutputInstance = z.infer<typeof zFieldOutputInstance>;
|
||||
export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance =>
|
||||
zFieldOutputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputTemplate & FieldInputTemplate
|
||||
export const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
zFloatFieldInputTemplate,
|
||||
zStringFieldInputTemplate,
|
||||
zBooleanFieldInputTemplate,
|
||||
zEnumFieldInputTemplate,
|
||||
zImageFieldInputTemplate,
|
||||
zBoardFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
zLoRAModelFieldInputTemplate,
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
]);
|
||||
export type StatefulFieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
|
||||
export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate =>
|
||||
zStatefulFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
|
||||
export type FieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
|
||||
export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate =>
|
||||
zFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputTemplate & FieldOutputTemplate
|
||||
export const zStatefulFieldOutputTemplate = z.union([
|
||||
zIntegerFieldOutputTemplate,
|
||||
zFloatFieldOutputTemplate,
|
||||
zStringFieldOutputTemplate,
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
zLoRAModelFieldOutputTemplate,
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
]);
|
||||
export type StatefulFieldOutputTemplate = z.infer<typeof zStatefulFieldOutputTemplate>;
|
||||
export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate =>
|
||||
zStatefulFieldOutputTemplate.safeParse(val).success;
|
||||
|
||||
export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]);
|
||||
export type FieldOutputTemplate = z.infer<typeof zFieldOutputTemplate>;
|
||||
export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate =>
|
||||
zFieldOutputTemplate.safeParse(val).success;
|
||||
// #endregion
|
@ -0,0 +1,93 @@
|
||||
import type { Edge, Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zClassification, zProgressImage } from './common';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
// #region InvocationTemplate
|
||||
export const zInvocationTemplate = z.object({
|
||||
type: z.string(),
|
||||
title: z.string(),
|
||||
description: z.string(),
|
||||
tags: z.array(z.string().min(1)),
|
||||
inputs: z.record(zFieldInputTemplate),
|
||||
outputs: z.record(zFieldOutputTemplate),
|
||||
outputType: z.string().min(1),
|
||||
version: zSemVer,
|
||||
useCache: z.boolean(),
|
||||
nodePack: z.string().min(1).nullish(),
|
||||
classification: zClassification,
|
||||
});
|
||||
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
||||
// #endregion
|
||||
|
||||
// #region NodeData
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.string().trim().min(1),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
version: zSemVer,
|
||||
nodePack: z.string().min(1).nullish(),
|
||||
inputs: z.record(zFieldInputInstance),
|
||||
outputs: z.record(zFieldOutputInstance),
|
||||
});
|
||||
|
||||
export const zNotesNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
});
|
||||
export const zCurrentImageNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('current_image'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
});
|
||||
export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
export type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
export type AnyNodeData = z.infer<typeof zAnyNodeData>;
|
||||
|
||||
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
|
||||
export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type AnyNode = Node<AnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode =>
|
||||
Boolean(node && node.type === 'current_image');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']);
|
||||
export const zNodeExecutionState = z.object({
|
||||
nodeId: z.string().trim().min(1),
|
||||
status: zNodeStatus,
|
||||
progress: z.number().nullable(),
|
||||
progressImage: zProgressImage.nullable(),
|
||||
error: z.string().nullable(),
|
||||
outputs: z.array(z.any()),
|
||||
});
|
||||
export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
|
||||
export type NodeStatus = z.infer<typeof zNodeStatus>;
|
||||
// #endregion
|
||||
|
||||
// #region Edges
|
||||
export const zInvocationNodeEdgeExtra = z.object({
|
||||
type: z.union([z.literal('default'), z.literal('collapsed')]),
|
||||
});
|
||||
export type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
|
||||
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
|
||||
// #endregion
|
@ -0,0 +1,77 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterField,
|
||||
zVAEModelField,
|
||||
} from './common';
|
||||
|
||||
// #region Metadata-optimized versions of schemas
|
||||
// TODO: It's possible that `deepPartial` will be deprecated:
|
||||
// - https://github.com/colinhacks/zod/issues/2106
|
||||
// - https://github.com/colinhacks/zod/issues/2854
|
||||
export const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
||||
const zModelMetadataItem = zMainModelField.deepPartial();
|
||||
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
|
||||
export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
|
||||
export type SDXLRefinerModelMetadataItem = z.infer<typeof zSDXLRefinerModelMetadataItem>;
|
||||
export type ModelMetadataItem = z.infer<typeof zModelMetadataItem>;
|
||||
export type VAEModelMetadataItem = z.infer<typeof zVAEModelMetadataItem>;
|
||||
// #endregion
|
||||
|
||||
// #region CoreMetadata
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish().catch(null),
|
||||
generation_mode: z.string().nullish().catch(null),
|
||||
created_by: z.string().nullish().catch(null),
|
||||
positive_prompt: z.string().nullish().catch(null),
|
||||
negative_prompt: z.string().nullish().catch(null),
|
||||
width: z.number().int().nullish().catch(null),
|
||||
height: z.number().int().nullish().catch(null),
|
||||
seed: z.number().int().nullish().catch(null),
|
||||
rand_device: z.string().nullish().catch(null),
|
||||
cfg_scale: z.number().nullish().catch(null),
|
||||
cfg_rescale_multiplier: z.number().nullish().catch(null),
|
||||
steps: z.number().int().nullish().catch(null),
|
||||
scheduler: z.string().nullish().catch(null),
|
||||
clip_skip: z.number().int().nullish().catch(null),
|
||||
model: zModelMetadataItem.nullish().catch(null),
|
||||
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
|
||||
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
|
||||
t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
|
||||
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
|
||||
vae: zVAEModelMetadataItem.nullish().catch(null),
|
||||
strength: z.number().nullish().catch(null),
|
||||
hrf_enabled: z.boolean().nullish().catch(null),
|
||||
hrf_strength: z.number().nullish().catch(null),
|
||||
hrf_method: z.string().nullish().catch(null),
|
||||
init_image: z.string().nullish().catch(null),
|
||||
positive_style_prompt: z.string().nullish().catch(null),
|
||||
negative_style_prompt: z.string().nullish().catch(null),
|
||||
refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null),
|
||||
refiner_cfg_scale: z.number().nullish().catch(null),
|
||||
refiner_steps: z.number().int().nullish().catch(null),
|
||||
refiner_scheduler: z.string().nullish().catch(null),
|
||||
refiner_positive_aesthetic_score: z.number().nullish().catch(null),
|
||||
refiner_negative_aesthetic_score: z.number().nullish().catch(null),
|
||||
refiner_start: z.number().nullish().catch(null),
|
||||
})
|
||||
.passthrough();
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
// #endregion
|
86
invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts
Normal file
86
invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts
Normal file
@ -0,0 +1,86 @@
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import type {
|
||||
InputFieldJSONSchemaExtra,
|
||||
InvocationJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
} from 'services/api/types';
|
||||
|
||||
// Janky customization of OpenAPI Schema :/
|
||||
|
||||
export type InvocationSchemaExtra = InvocationJSONSchemaExtra & {
|
||||
output: OpenAPIV3_1.ReferenceObject; // the output of the invocation
|
||||
title: string;
|
||||
category?: string;
|
||||
tags?: string[];
|
||||
version: string;
|
||||
properties: Omit<
|
||||
NonNullable<OpenAPIV3_1.SchemaObject['properties']> & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra),
|
||||
'type'
|
||||
> & {
|
||||
type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||
default: string;
|
||||
};
|
||||
use_cache: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||
default: boolean;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type InvocationSchemaType = {
|
||||
default: string; // the type of the invocation
|
||||
};
|
||||
|
||||
export type InvocationBaseSchemaObject = Omit<OpenAPIV3_1.BaseSchemaObject, 'title' | 'type' | 'properties'> &
|
||||
InvocationSchemaExtra;
|
||||
|
||||
export type InvocationOutputSchemaObject = Omit<OpenAPIV3_1.SchemaObject, 'properties'> & {
|
||||
properties: OpenAPIV3_1.SchemaObject['properties'] & {
|
||||
type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||
default: string;
|
||||
};
|
||||
} & {
|
||||
class: 'output';
|
||||
};
|
||||
};
|
||||
|
||||
export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra;
|
||||
|
||||
export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
|
||||
|
||||
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
|
||||
type: OpenAPIV3_1.ArraySchemaObjectType;
|
||||
items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
|
||||
}
|
||||
export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
|
||||
type?: OpenAPIV3_1.NonArraySchemaObjectType;
|
||||
}
|
||||
|
||||
export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' };
|
||||
|
||||
export const isSchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||
|
||||
export const isArraySchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||
|
||||
export const isNonArraySchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||
|
||||
export const isRefObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||
|
||||
export const isInvocationSchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject
|
||||
): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation';
|
||||
|
||||
export const isInvocationOutputSchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject
|
||||
): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output';
|
||||
|
||||
export const isInvocationFieldSchema = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
|
||||
): obj is InvocationFieldSchema => !('$ref' in obj);
|
21
invokeai/frontend/web/src/features/nodes/types/v2/semver.ts
Normal file
21
invokeai/frontend/web/src/features/nodes/types/v2/semver.ts
Normal file
@ -0,0 +1,21 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
// Schemas and types for working with semver
|
||||
|
||||
const zVersionInt = z.coerce.number().int().min(0);
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success
|
||||
);
|
||||
});
|
||||
|
||||
export const zParsedSemver = zSemVer.transform((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return {
|
||||
major: Number(major),
|
||||
minor: Number(minor),
|
||||
patch: Number(patch),
|
||||
};
|
||||
});
|
@ -0,0 +1,89 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zFieldIdentifier } from './field';
|
||||
import { zInvocationNodeData, zNotesNodeData } from './invocation';
|
||||
|
||||
// #region Workflow misc
|
||||
export const zXYPosition = z
|
||||
.object({
|
||||
x: z.number(),
|
||||
y: z.number(),
|
||||
})
|
||||
.default({ x: 0, y: 0 });
|
||||
export type XYPosition = z.infer<typeof zXYPosition>;
|
||||
|
||||
export const zDimension = z.number().gt(0).nullish();
|
||||
export type Dimension = z.infer<typeof zDimension>;
|
||||
|
||||
export const zWorkflowCategory = z.enum(['user', 'default', 'project']);
|
||||
export type WorkflowCategory = z.infer<typeof zWorkflowCategory>;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow Nodes
|
||||
export const zWorkflowInvocationNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('invocation'),
|
||||
data: zInvocationNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNotesNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
data: zNotesNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
|
||||
|
||||
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
|
||||
export type WorkflowNotesNode = z.infer<typeof zWorkflowNotesNode>;
|
||||
export type WorkflowNode = z.infer<typeof zWorkflowNode>;
|
||||
|
||||
export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode =>
|
||||
zWorkflowInvocationNode.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow Edges
|
||||
export const zWorkflowEdgeBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
source: z.string().trim().min(1),
|
||||
target: z.string().trim().min(1),
|
||||
});
|
||||
export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({
|
||||
type: z.literal('default'),
|
||||
sourceHandle: z.string().trim().min(1),
|
||||
targetHandle: z.string().trim().min(1),
|
||||
});
|
||||
export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
|
||||
type: z.literal('collapsed'),
|
||||
});
|
||||
export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]);
|
||||
|
||||
export type WorkflowEdgeDefault = z.infer<typeof zWorkflowEdgeDefault>;
|
||||
export type WorkflowEdgeCollapsed = z.infer<typeof zWorkflowEdgeCollapsed>;
|
||||
export type WorkflowEdge = z.infer<typeof zWorkflowEdge>;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow
|
||||
export const zWorkflowV2 = z.object({
|
||||
id: z.string().min(1).optional(),
|
||||
name: z.string(),
|
||||
author: z.string(),
|
||||
description: z.string(),
|
||||
version: z.string(),
|
||||
contact: z.string(),
|
||||
tags: z.string(),
|
||||
notes: z.string(),
|
||||
nodes: z.array(zWorkflowNode),
|
||||
edges: z.array(zWorkflowEdge),
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('2.0.0'),
|
||||
}),
|
||||
});
|
||||
export type WorkflowV2 = z.infer<typeof zWorkflowV2>;
|
||||
// #endregion
|
@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('invocation'),
|
||||
data: zInvocationNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNotesNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
data: zNotesNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
|
||||
@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer<typeof zWorkflowEdge>;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow
|
||||
export const zWorkflowV2 = z.object({
|
||||
export const zWorkflowV3 = z.object({
|
||||
id: z.string().min(1).optional(),
|
||||
name: z.string(),
|
||||
author: z.string(),
|
||||
@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('2.0.0'),
|
||||
version: z.literal('3.0.0'),
|
||||
}),
|
||||
});
|
||||
export type WorkflowV2 = z.infer<typeof zWorkflowV2>;
|
||||
export type WorkflowV3 = z.infer<typeof zWorkflowV3>;
|
||||
// #endregion
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
|
||||
import { reduce } from 'lodash-es';
|
||||
@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe
|
||||
{} as Record<string, FieldInputInstance>
|
||||
);
|
||||
|
||||
const outputs = reduce(
|
||||
template.outputs,
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const outputFieldValue: FieldOutputInstance = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
fieldKind: 'output',
|
||||
};
|
||||
|
||||
outputsAccumulator[outputName] = outputFieldValue;
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, FieldOutputInstance>
|
||||
);
|
||||
|
||||
const node: InvocationNode = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe
|
||||
isIntermediate: type === 'save_image' ? false : true,
|
||||
useCache: template.useCache,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
|
||||
|
||||
// Remove any fields that are not in the template
|
||||
clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs));
|
||||
clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs));
|
||||
return clone;
|
||||
};
|
||||
|
@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
const fieldInstance: FieldInputInstance = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input' as const,
|
||||
value: template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name),
|
||||
};
|
||||
|
||||
|
@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import i18n from 'i18n';
|
||||
import { cloneDeep, pick } from 'lodash-es';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
@ -25,14 +25,14 @@ const workflowKeys = [
|
||||
'exposedFields',
|
||||
'meta',
|
||||
'id',
|
||||
] satisfies (keyof WorkflowV2)[];
|
||||
] satisfies (keyof WorkflowV3)[];
|
||||
|
||||
export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2;
|
||||
export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;
|
||||
|
||||
export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => {
|
||||
export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
|
||||
const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);
|
||||
|
||||
const newWorkflow: WorkflowV2 = {
|
||||
const newWorkflow: WorkflowV3 = {
|
||||
...clonedWorkflow,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
type: node.type,
|
||||
data: cloneDeep(node.data),
|
||||
position: { ...node.position },
|
||||
width: node.width,
|
||||
height: node.height,
|
||||
});
|
||||
} else if (isNotesNode(node) && node.type) {
|
||||
newWorkflow.nodes.push({
|
||||
@ -54,8 +52,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
type: node.type,
|
||||
data: cloneDeep(node.data),
|
||||
position: { ...node.position },
|
||||
width: node.width,
|
||||
height: node.height,
|
||||
});
|
||||
}
|
||||
});
|
||||
@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
return newWorkflow;
|
||||
};
|
||||
|
||||
export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => {
|
||||
export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => {
|
||||
// builds what really, really should be a valid workflow
|
||||
const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow });
|
||||
|
||||
// but bc we are storing this in the DB, let's be extra sure
|
||||
const result = zWorkflowV2.safeParse(workflowToValidate);
|
||||
const result = zWorkflowV3.safeParse(workflowToValidate);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
|
@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver';
|
||||
import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap';
|
||||
import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1';
|
||||
import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({
|
||||
* - Workflow schema version bumped to 2.0.0
|
||||
*/
|
||||
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
const invocationTemplates = $store.get()?.getState().nodeTemplates.templates;
|
||||
const invocationTemplates = $store.get()?.getState().nodes.templates;
|
||||
|
||||
if (!invocationTemplates) {
|
||||
throw new Error(t('app.storeNotInitialized'));
|
||||
@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
return zWorkflowV2.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => {
|
||||
// Bump version
|
||||
(workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0';
|
||||
// Parsing strips out any extra properties not in the latest version
|
||||
return zWorkflowV3.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses a workflow and migrates it to the latest version if necessary.
|
||||
*/
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
|
||||
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
|
||||
|
||||
if (!workflowVersionResult.success) {
|
||||
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
|
||||
}
|
||||
|
||||
const { version } = workflowVersionResult.data.meta;
|
||||
let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3;
|
||||
|
||||
if (version === '1.0.0') {
|
||||
const v1 = zWorkflowV1.parse(data);
|
||||
return migrateV1toV2(v1);
|
||||
if (workflow.meta.version === '1.0.0') {
|
||||
const v1 = zWorkflowV1.parse(workflow);
|
||||
workflow = migrateV1toV2(v1);
|
||||
}
|
||||
|
||||
if (version === '2.0.0') {
|
||||
return zWorkflowV2.parse(data);
|
||||
if (workflow.meta.version === '2.0.0') {
|
||||
const v2 = zWorkflowV2.parse(workflow);
|
||||
workflow = migrateV2toV3(v2);
|
||||
}
|
||||
|
||||
throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version }));
|
||||
return workflow as WorkflowV3;
|
||||
};
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { isWorkflowInvocationNode } from 'features/nodes/types/workflow';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
import { t } from 'i18next';
|
||||
@ -16,7 +16,7 @@ type WorkflowWarning = {
|
||||
};
|
||||
|
||||
type ValidateWorkflowResult = {
|
||||
workflow: WorkflowV2;
|
||||
workflow: WorkflowV3;
|
||||
warnings: WorkflowWarning[];
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { workflowUpdated } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = {
|
||||
|
||||
type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn;
|
||||
|
||||
export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required<WorkflowV2, 'id'> =>
|
||||
export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required<WorkflowV3, 'id'> =>
|
||||
Boolean(workflow.id);
|
||||
|
||||
export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => {
|
||||
|
Loading…
Reference in New Issue
Block a user