feat: Add VAESelect Component

This commit is contained in:
blessedcoolant 2023-06-29 08:57:50 +12:00 committed by psychedelicious
parent 6c62f41f2e
commit 8d5a953dcb
4 changed files with 110 additions and 0 deletions

View File

@ -335,6 +335,7 @@
"modelManager": { "modelManager": {
"modelManager": "Model Manager", "modelManager": "Model Manager",
"model": "Model", "model": "Model",
"customVAE": "Custom VAE",
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",

View File

@ -14,6 +14,7 @@ import {
SeedParam, SeedParam,
StepsParam, StepsParam,
StrengthParam, StrengthParam,
VAEParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
@ -47,6 +48,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
vae: VAEParam;
shouldUseSeamless: boolean; shouldUseSeamless: boolean;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
@ -81,6 +83,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: '', model: '',
vae: '',
shouldUseSeamless: false, shouldUseSeamless: false,
seamlessXAxis: true, seamlessXAxis: true,
seamlessYAxis: true, seamlessYAxis: true,
@ -216,6 +219,9 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; state.model = action.payload;
}, },
vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
@ -260,6 +266,7 @@ export const {
setVerticalSymmetrySteps, setVerticalSymmetrySteps,
initialImageChanged, initialImageChanged,
modelSelected, modelSelected,
vaeSelected,
setShouldUseNoiseSettings, setShouldUseNoiseSettings,
setSeamless, setSeamless,
setSeamlessXAxis, setSeamlessXAxis,

View File

@ -135,6 +135,15 @@ export const zModel = z.string();
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type ModelParam = z.infer<typeof zModel>; export type ModelParam = z.infer<typeof zModel>;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
/** /**
* Validates/type-guards a value as a model parameter * Validates/type-guards a value as a model parameter
*/ */

View File

@ -0,0 +1,93 @@
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store';
import { vaeSelected } from 'features/parameters/store/generationSlice';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};
const VAESelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({
model_type: 'vae',
});
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.vae
);
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [
{
value: 'none',
label: 'None',
group: 'Default',
},
];
forEach(vaeModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [vaeModels]);
const selectedModel = useMemo(
() => vaeModels?.entities[selectedModelId],
[vaeModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(vaeSelected(v));
},
[dispatch]
);
useEffect(() => {
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) {
return;
}
handleChangeModel('none');
}, [handleChangeModel, vaeModels?.ids, selectedModelId]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={t('modelManager.customVAE')}
value={selectedModelId}
placeholder="Pick one"
data={data}
onChange={handleChangeModel}
/>
);
};
export default memo(VAESelect);