From 06428fac672d43a064584c56830aeb7ad2baecd1 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 20:02:36 +1200 Subject: [PATCH] fix: Revert scheduler back to zod validation --- invokeai/frontend/web/src/app/constants.ts | 4 ++-- .../components/Parameters/Core/ParamScheduler.tsx | 9 +++++---- .../web/src/features/parameters/store/generationSlice.ts | 3 ++- .../src/features/parameters/store/parameterZodSchemas.ts | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index c757c8959f..ae362ec85b 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -41,9 +41,9 @@ export const SCHEDULER_ITEMS = [ 'heun', 'heun_k', 'unipc', -]; +] as const; -export type Scheduler = typeof SCHEDULERS; +export type Scheduler = (typeof SCHEDULER_ITEMS)[number]; // Valid upscaling levels export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index d749b33eea..50295b206e 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,4 +1,4 @@ -import { SCHEDULER_ITEMS } from 'app/constants'; +import { SCHEDULER_ITEMS, Scheduler } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; @@ -25,12 +25,13 @@ const ParamScheduler = () => { useEffect(() => { if (selectedSchedulers.length === 0) - dispatch(setSelectedSchedulers(SCHEDULER_ITEMS)); + dispatch(setSelectedSchedulers([...SCHEDULER_ITEMS])); const schedulerFound = activeSchedulers.find( (activeSchedulers) => activeSchedulers.label === scheduler ); - if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value)); + if (!schedulerFound) + dispatch(setScheduler(activeSchedulers[0].value as Scheduler)); }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]); const handleChange = useCallback( @@ -38,7 +39,7 @@ const ParamScheduler = () => { if (!v) { return; } - dispatch(setScheduler(v)); + dispatch(setScheduler(v as Scheduler)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 5f13bf59ea..59801b8c46 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,5 +1,6 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; +import { Scheduler } from 'app/constants'; import { configChanged } from 'features/system/store/configSlice'; import { clamp, sortBy } from 'lodash-es'; import { ImageDTO } from 'services/api'; @@ -132,7 +133,7 @@ export const generationSlice = createSlice({ setWidth: (state, action: PayloadAction) => { state.width = action.payload; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index d17dcbce10..b865faf121 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -1,4 +1,4 @@ -import { NUMPY_RAND_MAX } from 'app/constants'; +import { NUMPY_RAND_MAX, SCHEDULER_ITEMS } from 'app/constants'; import { z } from 'zod'; /** @@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam => /** * Zod schema for scheduler parameter */ -export const zScheduler = z.string(); +export const zScheduler = z.enum(SCHEDULER_ITEMS); /** * Type alias for scheduler parameter, inferred from its zod schema */