playing around with saving values to exposed fields but not the graph

This commit is contained in:
Mary Hipp 2024-02-12 14:02:32 -05:00
parent 1dd07fb1eb
commit ea64649135
7 changed files with 190 additions and 38 deletions

View File

@ -60,9 +60,10 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
fieldName: string;
saveToGraph?: boolean;
};
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const InputFieldRenderer = ({ nodeId, fieldName, saveToGraph = true }: InputFieldProps) => {
const { t } = useTranslation();
const fieldInstance = useFieldInstance(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
@ -76,69 +77,181 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<StringFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<BooleanFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<EnumFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ImageFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<BoardFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<VAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
saveToGraph={saveToGraph}
/>
);
}
if (fieldInstance && fieldTemplate) {

View File

@ -57,7 +57,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
icon={<PiTrashSimpleBold />}
/>
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} saveToGraph={false} />
<NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
</Flex>
);

View File

@ -8,7 +8,7 @@ import { memo, useCallback } from 'react';
import type { FieldComponentProps } from './types';
const StringFieldInputComponent = (props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => {
const { nodeId, field, fieldTemplate } = props;
const { nodeId, field, fieldTemplate, saveToGraph } = props;
const dispatch = useAppDispatch();
const handleValueChanged = useCallback(
@ -18,10 +18,11 @@ const StringFieldInputComponent = (props: FieldComponentProps<StringFieldInputIn
nodeId,
fieldName: field.name,
value: e.target.value,
saveToGraph,
})
);
},
[dispatch, field.name, nodeId]
[dispatch, field.name, nodeId, saveToGraph]
);
if (fieldTemplate.ui_component === 'textarea') {

View File

@ -4,4 +4,5 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
nodeId: string;
field: V;
fieldTemplate: T;
saveToGraph: boolean;
};

View File

@ -118,6 +118,7 @@ type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
saveToGraph: boolean;
}>;
const fieldValueReducer = <T extends FieldValue>(
@ -482,49 +483,51 @@ export const nodesSlice = createSlice({
state.selectedEdges = action.payload;
},
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
if (action.payload.saveToGraph) {
fieldValueReducer(state, action, zStringFieldValue);
}
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue)) };
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zBooleanFieldValue) }
},
fieldBoardValueChanged: (state, action: FieldValueAction<BoardFieldValue>) => {
fieldValueReducer(state, action, zBoardFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zBoardFieldValue) }
},
fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => {
fieldValueReducer(state, action, zImageFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zImageFieldValue) }
},
fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => {
fieldValueReducer(state, action, zColorFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zColorFieldValue) }
},
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
fieldValueReducer(state, action, zMainModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zMainModelFieldValue) }
},
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zSDXLRefinerModelFieldValue) }
},
fieldVaeModelValueChanged: (state, action: FieldValueAction<VAEModelFieldValue>) => {
fieldValueReducer(state, action, zVAEModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zVAEModelFieldValue) }
},
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
fieldValueReducer(state, action, zLoRAModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zLoRAModelFieldValue) }
},
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
fieldValueReducer(state, action, zControlNetModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zControlNetModelFieldValue) }
},
fieldIPAdapterModelValueChanged: (state, action: FieldValueAction<IPAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zIPAdapterModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zIPAdapterModelFieldValue) }
},
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zT2IAdapterModelFieldValue) }
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zEnumFieldValue) }
},
fieldSchedulerValueChanged: (state, action: FieldValueAction<SchedulerFieldValue>) => {
fieldValueReducer(state, action, zSchedulerFieldValue);
if (action.payload.saveToGraph) { fieldValueReducer(state, action, zSchedulerFieldValue) }
},
notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => {
const { nodeId, value } = action.payload;

View File

@ -2,11 +2,12 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type { WorkflowsState as WorkflowState } from 'features/nodes/store/types';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { fieldStringValueChanged, isAnyNodeOrEdgeMutation, nodeEditorReset, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type { NodesState, WorkflowsState as WorkflowState } from 'features/nodes/store/types';
import { zStringFieldValue, type FieldIdentifier, type FieldValue } from 'features/nodes/types/field';
import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow';
import { cloneDeep, isEqual, uniqBy } from 'lodash-es';
import { z } from 'zod';
export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = {
name: '',
@ -27,6 +28,29 @@ export const initialWorkflowState: WorkflowState = {
...blankWorkflow,
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
saveToGraph: boolean;
}>;
const exposedFieldValueReducer = <T extends FieldValue>(
state: WorkflowState,
action: FieldValueAction<T>,
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
const exposedField = state.exposedFields.find(field => field.nodeId === nodeId)
const result = schema.safeParse(value);
if (!result || !exposedField || !result.success) {
return;
}
exposedField.value = schema.safeParse(value);
};
export const workflowSlice = createSlice({
name: 'workflow',
initialState: initialWorkflowState,
@ -42,6 +66,7 @@ export const workflowSlice = createSlice({
state.exposedFields = state.exposedFields.filter((field) => !isEqual(field, action.payload));
state.isTouched = true;
},
workflowNameChanged: (state, action: PayloadAction<string>) => {
state.name = action.payload;
state.isTouched = true;
@ -97,9 +122,17 @@ export const workflowSlice = createSlice({
builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState));
builder.addCase(fieldStringValueChanged, (state, action) => {
if (!action.payload.saveToGraph) {
exposedFieldValueReducer(state, action, zStringFieldValue);
}
})
builder.addMatcher(isAnyNodeOrEdgeMutation, (state) => {
state.isTouched = true;
});
},
});

View File

@ -91,6 +91,7 @@ export const zFieldTypeBase = z.object({
export const zFieldIdentifier = z.object({
nodeId: z.string().trim().min(1),
fieldName: z.string().trim().min(1),
value: z.any().optional()
});
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success;