From 076284c26f2835a06233e5c06595500e0d0aa683 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 9 Dec 2023 17:08:34 +1100 Subject: [PATCH] fix(ui): add validation to field value reducers Insurance against invalid inputs. Closes #5250 --- .../src/features/nodes/store/nodesSlice.ts | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 8b79753dbc..a9c5434289 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -20,6 +20,22 @@ import { StringFieldValue, T2IAdapterModelFieldValue, VAEModelFieldValue, + zBoardFieldValue, + zBooleanFieldValue, + zColorFieldValue, + zControlNetModelFieldValue, + zEnumFieldValue, + zFloatFieldValue, + zImageFieldValue, + zIntegerFieldValue, + zIPAdapterModelFieldValue, + zLoRAModelFieldValue, + zMainModelFieldValue, + zSchedulerFieldValue, + zSDXLRefinerModelFieldValue, + zStringFieldValue, + zT2IAdapterModelFieldValue, + zVAEModelFieldValue, } from 'features/nodes/types/field'; import { AnyNode, @@ -58,6 +74,7 @@ import { appSocketQueueItemStatusChanged, } from 'services/events/actions'; import { v4 as uuidv4 } from 'uuid'; +import { z } from 'zod'; import { NodesState } from './types'; import { findConnectionToValidHandle } from './util/findConnectionToValidHandle'; import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; @@ -106,7 +123,8 @@ type FieldValueAction = PayloadAction<{ const fieldValueReducer = ( state: NodesState, - action: FieldValueAction + action: FieldValueAction, + schema: z.ZodTypeAny ) => { const { nodeId, fieldName, value } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); @@ -115,12 +133,10 @@ const fieldValueReducer = ( return; } const input = node.data?.inputs[fieldName]; - if (!input) { + if (!input || nodeIndex < 0 || !schema.safeParse(value).success) { return; } - if (nodeIndex > -1) { - input.value = value; - } + input.value = value; }; const nodesSlice = createSlice({ @@ -527,91 +543,91 @@ const nodesSlice = createSlice({ state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zStringFieldValue); }, fieldNumberValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue)); }, fieldBooleanValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zBooleanFieldValue); }, fieldBoardValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zBoardFieldValue); }, fieldImageValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zImageFieldValue); }, fieldColorValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zColorFieldValue); }, fieldMainModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zMainModelFieldValue); }, fieldRefinerModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zSDXLRefinerModelFieldValue); }, fieldVaeModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zVAEModelFieldValue); }, fieldLoRAModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zLoRAModelFieldValue); }, fieldControlNetModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zControlNetModelFieldValue); }, fieldIPAdapterModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zIPAdapterModelFieldValue); }, fieldT2IAdapterModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zT2IAdapterModelFieldValue); }, fieldEnumModelValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zEnumFieldValue); }, fieldSchedulerValueChanged: ( state, action: FieldValueAction ) => { - fieldValueReducer(state, action); + fieldValueReducer(state, action, zSchedulerFieldValue); }, notesNodeValueChanged: ( state,