playing around with saving values to exposed fields but not the graph

This commit is contained in:
Mary Hipp 2024-02-12 14:02:32 -05:00
parent 1dd07fb1eb
commit ea64649135
7 changed files with 190 additions and 38 deletions

View File

@ -60,9 +60,10 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = { type InputFieldProps = {
nodeId: string; nodeId: string;
fieldName: string; fieldName: string;
saveToGraph?: boolean;
}; };
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { const InputFieldRenderer = ({ nodeId, fieldName, saveToGraph = true }: InputFieldProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const fieldInstance = useFieldInstance(nodeId, fieldName); const fieldInstance = useFieldInstance(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
@ -76,69 +77,181 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
} }
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<StringFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) { if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<BooleanFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if ( if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) || (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) { ) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) { if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<EnumFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) { if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<ImageFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) { if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<BoardFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) { if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) { if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) { if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<VAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) { if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) { if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) { if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) { if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) { if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) { if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
} }
if (fieldInstance && fieldTemplate) { if (fieldInstance && fieldTemplate) {

View File

@ -57,7 +57,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
icon={<PiTrashSimpleBold />} icon={<PiTrashSimpleBold />}
/> />
</Flex> </Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} /> <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} saveToGraph={false} />
<NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} /> <NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
</Flex> </Flex>
); );

View File

@ -8,7 +8,7 @@ import { memo, useCallback } from 'react';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
const StringFieldInputComponent = (props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => { const StringFieldInputComponent = (props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => {
const { nodeId, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate, saveToGraph } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleValueChanged = useCallback( const handleValueChanged = useCallback(
@ -18,10 +18,11 @@ const StringFieldInputComponent = (props: FieldComponentProps<StringFieldInputIn
nodeId, nodeId,
fieldName: field.name, fieldName: field.name,
value: e.target.value, value: e.target.value,
saveToGraph,
}) })
); );
}, },
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId, saveToGraph]
); );
if (fieldTemplate.ui_component === 'textarea') { if (fieldTemplate.ui_component === 'textarea') {

View File

@ -4,4 +4,5 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
nodeId: string; nodeId: string;
field: V; field: V;
fieldTemplate: T; fieldTemplate: T;
saveToGraph: boolean;
}; };

View File

@ -118,6 +118,7 @@ type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string; nodeId: string;
fieldName: string; fieldName: string;
value: T; value: T;
saveToGraph: boolean;
}>; }>;
const fieldValueReducer = <T extends FieldValue>( const fieldValueReducer = <T extends FieldValue>(
@ -482,49 +483,51 @@ export const nodesSlice = createSlice({
state.selectedEdges = action.payload; state.selectedEdges = action.payload;
}, },
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => { fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
if (action.payload.saveToGraph) {
fieldValueReducer(state, action, zStringFieldValue); fieldValueReducer(state, action, zStringFieldValue);
}
}, },
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => { fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue)); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue)) };
}, },
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => { fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zBooleanFieldValue) }
}, },
fieldBoardValueChanged: (state, action: FieldValueAction<BoardFieldValue>) => { fieldBoardValueChanged: (state, action: FieldValueAction<BoardFieldValue>) => {
fieldValueReducer(state, action, zBoardFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zBoardFieldValue) }
}, },
fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => { fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => {
fieldValueReducer(state, action, zImageFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zImageFieldValue) }
}, },
fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => { fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => {
fieldValueReducer(state, action, zColorFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zColorFieldValue) }
}, },
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => { fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
fieldValueReducer(state, action, zMainModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zMainModelFieldValue) }
}, },
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => { fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zSDXLRefinerModelFieldValue) }
}, },
fieldVaeModelValueChanged: (state, action: FieldValueAction<VAEModelFieldValue>) => { fieldVaeModelValueChanged: (state, action: FieldValueAction<VAEModelFieldValue>) => {
fieldValueReducer(state, action, zVAEModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zVAEModelFieldValue) }
}, },
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => { fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
fieldValueReducer(state, action, zLoRAModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zLoRAModelFieldValue) }
}, },
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => { fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
fieldValueReducer(state, action, zControlNetModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zControlNetModelFieldValue) }
}, },
fieldIPAdapterModelValueChanged: (state, action: FieldValueAction<IPAdapterModelFieldValue>) => { fieldIPAdapterModelValueChanged: (state, action: FieldValueAction<IPAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zIPAdapterModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zIPAdapterModelFieldValue) }
}, },
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => { fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zT2IAdapterModelFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zT2IAdapterModelFieldValue) }
}, },
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => { fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zEnumFieldValue) }
}, },
fieldSchedulerValueChanged: (state, action: FieldValueAction<SchedulerFieldValue>) => { fieldSchedulerValueChanged: (state, action: FieldValueAction<SchedulerFieldValue>) => {
fieldValueReducer(state, action, zSchedulerFieldValue); if (action.payload.saveToGraph) { fieldValueReducer(state, action, zSchedulerFieldValue) }
}, },
notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => { notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => {
const { nodeId, value } = action.payload; const { nodeId, value } = action.payload;

View File

@ -2,11 +2,12 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { workflowLoaded } from 'features/nodes/store/actions'; import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesDeleted } from 'features/nodes/store/nodesSlice'; import { fieldStringValueChanged, isAnyNodeOrEdgeMutation, nodeEditorReset, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type { WorkflowsState as WorkflowState } from 'features/nodes/store/types'; import type { NodesState, WorkflowsState as WorkflowState } from 'features/nodes/store/types';
import type { FieldIdentifier } from 'features/nodes/types/field'; import { zStringFieldValue, type FieldIdentifier, type FieldValue } from 'features/nodes/types/field';
import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow'; import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow';
import { cloneDeep, isEqual, uniqBy } from 'lodash-es'; import { cloneDeep, isEqual, uniqBy } from 'lodash-es';
import { z } from 'zod';
export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = { export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = {
name: '', name: '',
@ -27,6 +28,29 @@ export const initialWorkflowState: WorkflowState = {
...blankWorkflow, ...blankWorkflow,
}; };
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
saveToGraph: boolean;
}>;
const exposedFieldValueReducer = <T extends FieldValue>(
state: WorkflowState,
action: FieldValueAction<T>,
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
const exposedField = state.exposedFields.find(field => field.nodeId === nodeId)
const result = schema.safeParse(value);
if (!result || !exposedField || !result.success) {
return;
}
exposedField.value = schema.safeParse(value);
};
export const workflowSlice = createSlice({ export const workflowSlice = createSlice({
name: 'workflow', name: 'workflow',
initialState: initialWorkflowState, initialState: initialWorkflowState,
@ -42,6 +66,7 @@ export const workflowSlice = createSlice({
state.exposedFields = state.exposedFields.filter((field) => !isEqual(field, action.payload)); state.exposedFields = state.exposedFields.filter((field) => !isEqual(field, action.payload));
state.isTouched = true; state.isTouched = true;
}, },
workflowNameChanged: (state, action: PayloadAction<string>) => { workflowNameChanged: (state, action: PayloadAction<string>) => {
state.name = action.payload; state.name = action.payload;
state.isTouched = true; state.isTouched = true;
@ -97,9 +122,17 @@ export const workflowSlice = createSlice({
builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState)); builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState));
builder.addCase(fieldStringValueChanged, (state, action) => {
if (!action.payload.saveToGraph) {
exposedFieldValueReducer(state, action, zStringFieldValue);
}
})
builder.addMatcher(isAnyNodeOrEdgeMutation, (state) => { builder.addMatcher(isAnyNodeOrEdgeMutation, (state) => {
state.isTouched = true; state.isTouched = true;
}); });
}, },
}); });

View File

@ -91,6 +91,7 @@ export const zFieldTypeBase = z.object({
export const zFieldIdentifier = z.object({ export const zFieldIdentifier = z.object({
nodeId: z.string().trim().min(1), nodeId: z.string().trim().min(1),
fieldName: z.string().trim().min(1), fieldName: z.string().trim().min(1),
value: z.any().optional()
}); });
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>; export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success; export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success;