feat(ui): finalize base model compatibility for lora, ti, vae

This commit is contained in:
psychedelicious 2023-07-07 21:23:03 +10:00
parent a9a4081f51
commit 8457fcf7d3
13 changed files with 187 additions and 113 deletions

View File

@ -1,11 +1,12 @@
import { makeToast } from 'app/components/Toaster';
import { modelSelected } from 'features/parameters/store/actions';
import { import {
modelChanged, modelChanged,
vaeSelected, vaeSelected,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { modelSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
import { lorasCleared } from '../../../../../features/lora/store/loraSlice'; import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
export const addModelSelectedListener = () => { export const addModelSelectedListener = () => {
@ -24,12 +25,18 @@ export const addModelSelectedListener = () => {
}) })
) )
); );
dispatch(vaeSelected('auto')); dispatch(vaeSelected(null));
dispatch(lorasCleared()); dispatch(lorasCleared());
// TODO: controlnet cleared // TODO: controlnet cleared
} }
dispatch(modelChanged({ id: action.payload, base_model, name, type })); const newModel = zMainModel.parse({
id: action.payload,
base_model,
name,
});
dispatch(modelChanged(newModel));
}, },
}); });
}; };

View File

@ -66,6 +66,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
'&[data-disabled]': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
}, },
}, },
value: { value: {
@ -108,6 +109,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
color: mode('white', base50)(colorMode), color: mode('white', base50)(colorMode),
}, },
}, },
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
}, },
rightSection: { rightSection: {
width: 24, width: 24,

View File

@ -67,6 +67,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&[data-disabled]': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
}, },
}, },
value: { value: {
@ -109,6 +110,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
color: mode('white', base50)(colorMode), color: mode('white', base50)(colorMode),
}, },
}, },
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
}, },
rightSection: { rightSection: {
width: 32, width: 32,

View File

@ -1,37 +1,28 @@
import { Tooltip, Text } from '@mantine/core'; import { Box, Tooltip } from '@chakra-ui/react';
import { Text } from '@mantine/core';
import { forwardRef, memo } from 'react'; import { forwardRef, memo } from 'react';
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
label: string; label: string;
description?: string; description?: string;
tooltip?: string; tooltip?: string;
disabled?: boolean;
} }
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>( const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, ...others }: ItemProps, ref) => ( ({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
<div ref={ref} {...others}> <Tooltip label={tooltip} placement="top" hasArrow>
{tooltip ? ( <Box ref={ref} {...others}>
<Tooltip.Floating label={tooltip}> <Box>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</Tooltip.Floating>
) : (
<div>
<Text>{label}</Text> <Text>{label}</Text>
{description && ( {description && (
<Text size="xs" color="base.600"> <Text size="xs" color="base.600">
{description} {description}
</Text> </Text>
)} )}
</div> </Box>
)} </Box>
</div> </Tooltip>
) )
); );

View File

@ -6,20 +6,16 @@ import {
PopoverTrigger, PopoverTrigger,
Text, Text,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react'; import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants'; import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import { RootState } from '../../../app/store/store';
import { useAppSelector } from '../../../app/store/storeHooks';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type EmbeddingSelectItem = {
label: string;
value: string;
description?: string;
};
type Props = PropsWithChildren & { type Props = PropsWithChildren & {
onSelect: (v: string) => void; onSelect: (v: string) => void;
@ -41,22 +37,27 @@ const ParamEmbeddingPopover = (props: Props) => {
return []; return [];
} }
const data: EmbeddingSelectItem[] = []; const data: SelectItem[] = [];
forEach(embeddingQueryData.entities, (embedding, _) => { forEach(embeddingQueryData.entities, (embedding, _) => {
if (!embedding) return; if (!embedding) {
return;
}
const disabled = currentMainModel?.base_model !== embedding.base_model;
data.push({ data.push({
value: embedding.name, value: embedding.name,
label: embedding.name, label: embedding.name,
description: embedding.description, group: MODEL_TYPE_MAP[embedding.base_model],
...(currentMainModel?.base_model !== embedding.base_model disabled,
? { disabled: true, tooltip: 'Incompatible base model' } tooltip: disabled
: {}), ? `Incompatible base model: ${embedding.base_model}`
: undefined,
}); });
}); });
return data; return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [embeddingQueryData, currentMainModel?.base_model]); }, [embeddingQueryData, currentMainModel?.base_model]);
const handleChange = useCallback( const handleChange = useCallback(
@ -114,8 +115,10 @@ const ParamEmbeddingPopover = (props: Props) => {
nothingFound="No Matching Embeddings" nothingFound="No Matching Embeddings"
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0} disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) => filter={(value, selected, item: SelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) || item.label
?.toLowerCase()
.includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) item.value.toLowerCase().includes(value.toLowerCase().trim())
} }
onChange={handleChange} onChange={handleChange}

View File

@ -1,20 +1,16 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store'; import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type LoraSelectItem = {
label: string;
value: string;
description?: string;
};
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
@ -38,24 +34,27 @@ const ParamLoraSelect = () => {
return []; return [];
} }
const data: LoraSelectItem[] = []; const data: SelectItem[] = [];
forEach(lorasQueryData.entities, (lora, id) => { forEach(lorasQueryData.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) { if (!lora || Boolean(id in loras)) {
return; return;
} }
const disabled = currentMainModel?.base_model !== lora.base_model;
data.push({ data.push({
value: id, value: id,
label: lora.name, label: lora.name,
description: lora.description, disabled,
...(currentMainModel?.base_model !== lora.base_model group: MODEL_TYPE_MAP[lora.base_model],
? { disabled: true, tooltip: 'Incompatible base model' } tooltip: disabled
: {}), ? `Incompatible base model: ${lora.base_model}`
: undefined,
}); });
}); });
return data; return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loras, lorasQueryData, currentMainModel?.base_model]); }, [loras, lorasQueryData, currentMainModel?.base_model]);
const handleChange = useCallback( const handleChange = useCallback(
@ -88,8 +87,8 @@ const ParamLoraSelect = () => {
nothingFound="No matching LoRAs" nothingFound="No matching LoRAs"
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0} disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) => filter={(value, selected, item: SelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) || item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) item.value.toLowerCase().includes(value.toLowerCase().trim())
} }
onChange={handleChange} onChange={handleChange}

View File

@ -1,18 +1,21 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import { BaseModelType } from 'services/api/types';
export type Lora = { export type Lora = {
id: string; id: string;
base_model: BaseModelType;
name: string; name: string;
weight: number; weight: number;
}; };
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = { export const defaultLoRAConfig = {
weight: 0.75, weight: 0.75,
}; };
export type LoraState = { export type LoraState = {
loras: Record<string, Lora>; loras: Record<string, LoRAModelParam & { weight: number }>;
}; };
export const intialLoraState: LoraState = { export const intialLoraState: LoraState = {
@ -24,14 +27,14 @@ export const loraSlice = createSlice({
initialState: intialLoraState, initialState: intialLoraState,
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id } = action.payload; const { name, id, base_model } = action.payload;
state.loras[id] = { id, name, ...defaultLoRAConfig }; state.loras[id] = { id, name, base_model, ...defaultLoRAConfig };
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload; const id = action.payload;
delete state.loras[id]; delete state.loras[id];
}, },
lorasCleared: (state, action: PayloadAction<>) => { lorasCleared: (state) => {
state.loras = {}; state.loras = {};
}, },
loraWeightChanged: ( loraWeightChanged: (

View File

@ -19,7 +19,7 @@ export const addVAEToGraph = (
const { vae } = state.generation; const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vae?.id || ''); const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = vae?.id === 'auto'; const isAutoVae = !vae;
if (!isAutoVae) { if (!isAutoVae) {
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {

View File

@ -1,6 +1,10 @@
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { isImageField } from 'services/api/guards';
import { ImageDTO } from 'services/api/types';
import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
setCfgScale, setCfgScale,
setHeight, setHeight,
@ -12,14 +16,10 @@ import {
setSteps, setSteps,
setWidth, setWidth,
} from '../store/generationSlice'; } from '../store/generationSlice';
import { isImageField } from 'services/api/guards';
import { initialImageSelected, modelSelected } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types';
import { import {
isValidCfgScale, isValidCfgScale,
isValidHeight, isValidHeight,
isValidModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
isValidScheduler, isValidScheduler,
@ -158,7 +158,7 @@ export const useRecallParameters = () => {
*/ */
const recallModel = useCallback( const recallModel = useCallback(
(model: unknown) => { (model: unknown) => {
if (!isValidModel(model)) { if (!isValidMainModel(model)) {
parameterNotSetToast(); parameterNotSetToast();
return; return;
} }
@ -295,7 +295,7 @@ export const useRecallParameters = () => {
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale)); dispatch(setCfgScale(cfg_scale));
} }
if (isValidModel(model)) { if (isValidMainModel(model)) {
dispatch(modelSelected(model)); dispatch(modelSelected(model));
} }
if (isValidPositivePrompt(positive_conditioning)) { if (isValidPositivePrompt(positive_conditioning)) {

View File

@ -9,14 +9,16 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
ModelParam, MainModelParam,
NegativePromptParam, NegativePromptParam,
PositivePromptParam, PositivePromptParam,
SchedulerParam, SchedulerParam,
SeedParam, SeedParam,
StepsParam, StepsParam,
StrengthParam, StrengthParam,
VaeModelParam,
WidthParam, WidthParam,
zMainModel,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
export interface GenerationState { export interface GenerationState {
@ -48,8 +50,8 @@ export interface GenerationState {
shouldUseSymmetry: boolean; shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: MainModelParam | null;
vae: VAEParam; vae: VaeModelParam | null;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
clipSkip: number; clipSkip: number;
@ -84,7 +86,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: null, model: null,
vae: '', vae: null,
seamlessXAxis: false, seamlessXAxis: false,
seamlessYAxis: false, seamlessYAxis: false,
clipSkip: 0, clipSkip: 0,
@ -221,12 +223,17 @@ export const generationSlice = createSlice({
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap]; const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
state.clipSkip = clamp(state.clipSkip, 0, maxClip); state.clipSkip = clamp(state.clipSkip, 0, maxClip);
state.model = { id: action.payload, base_model, name, type }; state.model = zMainModel.parse({
id: action.payload,
base_model,
name,
type,
});
}, },
modelChanged: (state, action: PayloadAction<ModelParam>) => { modelChanged: (state, action: PayloadAction<MainModelParam>) => {
state.model = action.payload; state.model = action.payload;
}, },
vaeSelected: (state, action: PayloadAction<string>) => { vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
state.vae = action.payload; state.vae = action.payload;
}, },
setClipSkip: (state, action: PayloadAction<number>) => { setClipSkip: (state, action: PayloadAction<number>) => {
@ -236,14 +243,14 @@ export const generationSlice = createSlice({
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel; const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) { if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/'); const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = { state.model = zMainModel.parse({
id: defaultModel, id: defaultModel,
name: model_name, name: model_name,
type: model_type, base_model,
base_model: base_model, });
};
} }
}); });
builder.addCase(setShouldShowAdvancedOptions, (state, action) => { builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -126,35 +126,63 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam => export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success; zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
export type BaseModelParam = z.infer<typeof zBaseModel>;
/** /**
* Zod schema for model parameter * Zod schema for model parameter
* TODO: Make this a dynamically generated enum? * TODO: Make this a dynamically generated enum?
*/ */
const zModel = z.object({ export const zMainModel = z.object({
id: z.string(), id: z.string(),
name: z.string(), name: z.string(),
type: z.string(), base_model: zBaseModel,
base_model: z.string(),
}); });
/** /**
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type ModelParam = z.infer<typeof zModel> | null; export type MainModelParam = z.infer<typeof zMainModel>;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
/** /**
* Validates/type-guards a value as a model parameter * Validates/type-guards a value as a model parameter
*/ */
export const isValidModel = (val: unknown): val is ModelParam => export const isValidMainModel = (val: unknown): val is MainModelParam =>
zModel.safeParse(val).success; zMainModel.safeParse(val).success;
/**
* Zod schema for VAE parameter
*/
export const zVaeModel = z.object({
id: z.string(),
name: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VaeModelParam = z.infer<typeof zVaeModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
zVaeModel.safeParse(val).success;
/**
* Zod schema for LoRA
*/
export const zLoRAModel = z.object({
id: z.string(),
name: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type LoRAModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter

View File

@ -6,9 +6,9 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/actions';
import { forEach, isString } from 'lodash-es'; import { forEach, isString } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { modelSelected } from '../../parameters/store/actions';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x', 'sd-1': 'Stable Diffusion 1.x',
@ -63,6 +63,16 @@ const ModelSelect = () => {
); );
useEffect(() => { useEffect(() => {
if (isLoading) {
// return early here to avoid resetting model selection before we've loaded the available models
return;
}
if (selectedModel && mainModels?.ids.includes(selectedModel?.id)) {
// the selected model is an available model, no need to change it
return;
}
const firstModel = mainModels?.ids[0]; const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) { if (!isString(firstModel)) {
@ -70,7 +80,7 @@ const ModelSelect = () => {
} }
handleChangeModel(firstModel); handleChangeModel(firstModel);
}, [handleChangeModel, mainModels?.ids]); }, [handleChangeModel, isLoading, mainModels?.ids, selectedModel]);
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSelect

View File

@ -9,9 +9,10 @@ import { forEach } from 'lodash-es';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { vaeSelected } from 'features/parameters/store/generationSlice'; import { vaeSelected } from 'features/parameters/store/generationSlice';
import { zVaeModel } from 'features/parameters/store/parameterZodSchemas';
import { MODEL_TYPE_MAP } from './ModelSelect'; import { MODEL_TYPE_MAP } from './ModelSelect';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
const VAESelect = () => { const VAESelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -34,8 +35,8 @@ const VAESelect = () => {
const data: SelectItem[] = [ const data: SelectItem[] = [
{ {
value: 'auto', value: 'default',
label: 'Automatic', label: 'Default',
group: 'Default', group: 'Default',
}, },
]; ];
@ -45,50 +46,65 @@ const VAESelect = () => {
return; return;
} }
const disabled = currentMainModel?.base_model !== model.base_model;
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
...(currentMainModel?.base_model !== model.base_model disabled,
? { disabled: true, tooltip: 'Incompatible base model' } tooltip: disabled
: {}), ? `Incompatible base model: ${model.base_model}`
: undefined,
}); });
}); });
return data; return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels, currentMainModel?.base_model]); }, [vaeModels, currentMainModel?.base_model]);
const selectedVaeModel = useMemo( const selectedVaeModel = useMemo(
() => vaeModels?.entities[selectedVae], () => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null),
[vaeModels?.entities, selectedVae] [vaeModels?.entities, selectedVae]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v || v === 'default') {
dispatch(vaeSelected(null));
return; return;
} }
dispatch(vaeSelected(v));
const [base_model, type, name] = v.split('/');
const model = zVaeModel.parse({
id: v,
name,
base_model,
});
dispatch(vaeSelected(model));
}, },
[dispatch] [dispatch]
); );
useEffect(() => { useEffect(() => {
if (selectedVae && vaeModels?.ids.includes(selectedVae)) { if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
return; return;
} }
handleChangeModel('auto'); dispatch(vaeSelected(null));
}, [handleChangeModel, vaeModels?.ids, selectedVae]); }, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')} label={t('modelManager.vae')}
value={selectedVae} value={selectedVae?.id ?? 'default'}
placeholder="Pick one" placeholder="Default"
data={data} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}
disabled={data.length === 0}
clearable
/> />
); );
}; };