From 8d5a953dcb55b59d286173c0f25435f06913e926 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Thu, 29 Jun 2023 08:57:50 +1200 Subject: [PATCH] feat: Add VAESelect Component --- invokeai/frontend/web/public/locales/en.json | 1 + .../parameters/store/generationSlice.ts | 7 ++ .../parameters/store/parameterZodSchemas.ts | 9 ++ .../features/system/components/VAESelect.tsx | 93 +++++++++++++++++++ 4 files changed, 110 insertions(+) create mode 100644 invokeai/frontend/web/src/features/system/components/VAESelect.tsx diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index eb87ea6420..9cfb3379ae 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -335,6 +335,7 @@ "modelManager": { "modelManager": "Model Manager", "model": "Model", + "customVAE": "Custom VAE", "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index c8e65314da..209cf4b639 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -14,6 +14,7 @@ import { SeedParam, StepsParam, StrengthParam, + VAEParam, WidthParam, } from './parameterZodSchemas'; @@ -47,6 +48,7 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; + vae: VAEParam; shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; @@ -81,6 +83,7 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', + vae: '', shouldUseSeamless: false, seamlessXAxis: true, seamlessYAxis: true, @@ -216,6 +219,9 @@ export const generationSlice = createSlice({ modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, + vaeSelected: (state, action: PayloadAction) => { + state.vae = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(configChanged, (state, action) => { @@ -260,6 +266,7 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, + vaeSelected, setShouldUseNoiseSettings, setSeamless, setSeamlessXAxis, diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 48eb309e7d..12d77beeb9 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -135,6 +135,15 @@ export const zModel = z.string(); * Type alias for model parameter, inferred from its zod schema */ export type ModelParam = z.infer; +/** + * Zod schema for VAE parameter + * TODO: Make this a dynamically generated enum? + */ +export const zVAE = z.string(); +/** + * Type alias for model parameter, inferred from its zod schema + */ +export type VAEParam = z.infer; /** * Validates/type-guards a value as a model parameter */ diff --git a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx b/invokeai/frontend/web/src/features/system/components/VAESelect.tsx new file mode 100644 index 0000000000..b8d6ccdfc3 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/components/VAESelect.tsx @@ -0,0 +1,93 @@ +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; + +import { SelectItem } from '@mantine/core'; +import { forEach } from 'lodash-es'; +import { useListModelsQuery } from 'services/api/endpoints/models'; + +import { RootState } from 'app/store/store'; +import { vaeSelected } from 'features/parameters/store/generationSlice'; + +export const MODEL_TYPE_MAP = { + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', +}; + +const VAESelect = () => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: vaeModels } = useListModelsQuery({ + model_type: 'vae', + }); + + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.vae + ); + + const data = useMemo(() => { + if (!vaeModels) { + return []; + } + + const data: SelectItem[] = [ + { + value: 'none', + label: 'None', + group: 'Default', + }, + ]; + + forEach(vaeModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [vaeModels]); + + const selectedModel = useMemo( + () => vaeModels?.entities[selectedModelId], + [vaeModels?.entities, selectedModelId] + ); + + const handleChangeModel = useCallback( + (v: string | null) => { + if (!v) { + return; + } + dispatch(vaeSelected(v)); + }, + [dispatch] + ); + + useEffect(() => { + if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) { + return; + } + handleChangeModel('none'); + }, [handleChangeModel, vaeModels?.ids, selectedModelId]); + + return ( + + ); +}; + +export default memo(VAESelect);