feat: Add VAESelect Component

This commit is contained in:
blessedcoolant 2023-06-29 08:57:50 +12:00 committed by psychedelicious
parent 6c62f41f2e
commit 8d5a953dcb
4 changed files with 110 additions and 0 deletions

View File

@ -335,6 +335,7 @@
"modelManager": {
"modelManager": "Model Manager",
"model": "Model",
"customVAE": "Custom VAE",
"allModels": "All Models",
"checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers",

View File

@ -14,6 +14,7 @@ import {
SeedParam,
StepsParam,
StrengthParam,
VAEParam,
WidthParam,
} from './parameterZodSchemas';
@ -47,6 +48,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: ModelParam;
vae: VAEParam;
shouldUseSeamless: boolean;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
@ -81,6 +83,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0,
model: '',
vae: '',
shouldUseSeamless: false,
seamlessXAxis: true,
seamlessYAxis: true,
@ -216,6 +219,9 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
@ -260,6 +266,7 @@ export const {
setVerticalSymmetrySteps,
initialImageChanged,
modelSelected,
vaeSelected,
setShouldUseNoiseSettings,
setSeamless,
setSeamlessXAxis,

View File

@ -135,6 +135,15 @@ export const zModel = z.string();
* Type alias for model parameter, inferred from its zod schema
*/
export type ModelParam = z.infer<typeof zModel>;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
/**
* Validates/type-guards a value as a model parameter
*/

View File

@ -0,0 +1,93 @@
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store';
import { vaeSelected } from 'features/parameters/store/generationSlice';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};
const VAESelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({
model_type: 'vae',
});
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.vae
);
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [
{
value: 'none',
label: 'None',
group: 'Default',
},
];
forEach(vaeModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [vaeModels]);
const selectedModel = useMemo(
() => vaeModels?.entities[selectedModelId],
[vaeModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(vaeSelected(v));
},
[dispatch]
);
useEffect(() => {
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) {
return;
}
handleChangeModel('none');
}, [handleChangeModel, vaeModels?.ids, selectedModelId]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={t('modelManager.customVAE')}
value={selectedModelId}
placeholder="Pick one"
data={data}
onChange={handleChangeModel}
/>
);
};
export default memo(VAESelect);