mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): finalize base model compatibility for lora, ti, vae
This commit is contained in:
parent
a9a4081f51
commit
8457fcf7d3
@ -1,11 +1,12 @@
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
modelChanged,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
||||
|
||||
export const addModelSelectedListener = () => {
|
||||
@ -24,12 +25,18 @@ export const addModelSelectedListener = () => {
|
||||
})
|
||||
)
|
||||
);
|
||||
dispatch(vaeSelected('auto'));
|
||||
dispatch(vaeSelected(null));
|
||||
dispatch(lorasCleared());
|
||||
// TODO: controlnet cleared
|
||||
}
|
||||
|
||||
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
|
||||
const newModel = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
});
|
||||
|
||||
dispatch(modelChanged(newModel));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -66,6 +66,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
'&[data-disabled]': {
|
||||
backgroundColor: mode(base300, base700)(colorMode),
|
||||
color: mode(base600, base400)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
value: {
|
||||
@ -108,6 +109,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
color: mode('white', base50)(colorMode),
|
||||
},
|
||||
},
|
||||
'&[data-disabled]': {
|
||||
color: mode(base500, base600)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 24,
|
||||
|
@ -67,6 +67,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
'&[data-disabled]': {
|
||||
backgroundColor: mode(base300, base700)(colorMode),
|
||||
color: mode(base600, base400)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
value: {
|
||||
@ -109,6 +110,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
color: mode('white', base50)(colorMode),
|
||||
},
|
||||
},
|
||||
'&[data-disabled]': {
|
||||
color: mode(base500, base600)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 32,
|
||||
|
@ -1,37 +1,28 @@
|
||||
import { Tooltip, Text } from '@mantine/core';
|
||||
import { Box, Tooltip } from '@chakra-ui/react';
|
||||
import { Text } from '@mantine/core';
|
||||
import { forwardRef, memo } from 'react';
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
label: string;
|
||||
description?: string;
|
||||
tooltip?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, tooltip, description, ...others }: ItemProps, ref) => (
|
||||
<div ref={ref} {...others}>
|
||||
{tooltip ? (
|
||||
<Tooltip.Floating label={tooltip}>
|
||||
<div>
|
||||
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Box ref={ref} {...others}>
|
||||
<Box>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</Tooltip.Floating>
|
||||
) : (
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Box>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
)
|
||||
);
|
||||
|
||||
|
@ -6,20 +6,16 @@ import {
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
import { RootState } from '../../../app/store/store';
|
||||
import { useAppSelector } from '../../../app/store/storeHooks';
|
||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||
|
||||
type EmbeddingSelectItem = {
|
||||
label: string;
|
||||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
onSelect: (v: string) => void;
|
||||
@ -41,22 +37,27 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: EmbeddingSelectItem[] = [];
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(embeddingQueryData.entities, (embedding, _) => {
|
||||
if (!embedding) return;
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = currentMainModel?.base_model !== embedding.base_model;
|
||||
|
||||
data.push({
|
||||
value: embedding.name,
|
||||
label: embedding.name,
|
||||
description: embedding.description,
|
||||
...(currentMainModel?.base_model !== embedding.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
group: MODEL_TYPE_MAP[embedding.base_model],
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${embedding.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [embeddingQueryData, currentMainModel?.base_model]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
@ -114,8 +115,10 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
nothingFound="No Matching Embeddings"
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
filter={(value, selected, item: SelectItem) =>
|
||||
item.label
|
||||
?.toLowerCase()
|
||||
.includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
|
@ -1,20 +1,16 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState, stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { loraAdded } from 'features/lora/store/loraSlice';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { loraAdded } from '../store/loraSlice';
|
||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||
|
||||
type LoraSelectItem = {
|
||||
label: string;
|
||||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@ -38,24 +34,27 @@ const ParamLoraSelect = () => {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: LoraSelectItem[] = [];
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(lorasQueryData.entities, (lora, id) => {
|
||||
if (!lora || Boolean(id in loras)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = currentMainModel?.base_model !== lora.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: lora.name,
|
||||
description: lora.description,
|
||||
...(currentMainModel?.base_model !== lora.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
disabled,
|
||||
group: MODEL_TYPE_MAP[lora.base_model],
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${lora.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [loras, lorasQueryData, currentMainModel?.base_model]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
@ -88,8 +87,8 @@ const ParamLoraSelect = () => {
|
||||
nothingFound="No matching LoRAs"
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: LoraSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
filter={(value, selected, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
|
@ -1,18 +1,21 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import { BaseModelType } from 'services/api/types';
|
||||
|
||||
export type Lora = {
|
||||
id: string;
|
||||
base_model: BaseModelType;
|
||||
name: string;
|
||||
weight: number;
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
||||
export const defaultLoRAConfig = {
|
||||
weight: 0.75,
|
||||
};
|
||||
|
||||
export type LoraState = {
|
||||
loras: Record<string, Lora>;
|
||||
loras: Record<string, LoRAModelParam & { weight: number }>;
|
||||
};
|
||||
|
||||
export const intialLoraState: LoraState = {
|
||||
@ -24,14 +27,14 @@ export const loraSlice = createSlice({
|
||||
initialState: intialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||
const { name, id } = action.payload;
|
||||
state.loras[id] = { id, name, ...defaultLoRAConfig };
|
||||
const { name, id, base_model } = action.payload;
|
||||
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
delete state.loras[id];
|
||||
},
|
||||
lorasCleared: (state, action: PayloadAction<>) => {
|
||||
lorasCleared: (state) => {
|
||||
state.loras = {};
|
||||
},
|
||||
loraWeightChanged: (
|
||||
|
@ -19,7 +19,7 @@ export const addVAEToGraph = (
|
||||
const { vae } = state.generation;
|
||||
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
||||
|
||||
const isAutoVae = vae?.id === 'auto';
|
||||
const isAutoVae = !vae;
|
||||
|
||||
if (!isAutoVae) {
|
||||
graph.nodes[VAE_LOADER] = {
|
||||
|
@ -1,6 +1,10 @@
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
@ -12,14 +16,10 @@ import {
|
||||
setSteps,
|
||||
setWidth,
|
||||
} from '../store/generationSlice';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
isValidCfgScale,
|
||||
isValidHeight,
|
||||
isValidModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
isValidScheduler,
|
||||
@ -158,7 +158,7 @@ export const useRecallParameters = () => {
|
||||
*/
|
||||
const recallModel = useCallback(
|
||||
(model: unknown) => {
|
||||
if (!isValidModel(model)) {
|
||||
if (!isValidMainModel(model)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
@ -295,7 +295,7 @@ export const useRecallParameters = () => {
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
if (isValidModel(model)) {
|
||||
if (isValidMainModel(model)) {
|
||||
dispatch(modelSelected(model));
|
||||
}
|
||||
if (isValidPositivePrompt(positive_conditioning)) {
|
||||
|
@ -9,14 +9,16 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
ModelParam,
|
||||
MainModelParam,
|
||||
NegativePromptParam,
|
||||
PositivePromptParam,
|
||||
SchedulerParam,
|
||||
SeedParam,
|
||||
StepsParam,
|
||||
StrengthParam,
|
||||
VaeModelParam,
|
||||
WidthParam,
|
||||
zMainModel,
|
||||
} from './parameterZodSchemas';
|
||||
|
||||
export interface GenerationState {
|
||||
@ -48,8 +50,8 @@ export interface GenerationState {
|
||||
shouldUseSymmetry: boolean;
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: ModelParam;
|
||||
vae: VAEParam;
|
||||
model: MainModelParam | null;
|
||||
vae: VaeModelParam | null;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
clipSkip: number;
|
||||
@ -84,7 +86,7 @@ export const initialGenerationState: GenerationState = {
|
||||
horizontalSymmetrySteps: 0,
|
||||
verticalSymmetrySteps: 0,
|
||||
model: null,
|
||||
vae: '',
|
||||
vae: null,
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
@ -221,12 +223,17 @@ export const generationSlice = createSlice({
|
||||
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
|
||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||
|
||||
state.model = { id: action.payload, base_model, name, type };
|
||||
state.model = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
type,
|
||||
});
|
||||
},
|
||||
modelChanged: (state, action: PayloadAction<ModelParam>) => {
|
||||
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
||||
state.vae = action.payload;
|
||||
},
|
||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||
@ -236,14 +243,14 @@ export const generationSlice = createSlice({
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
const defaultModel = action.payload.sd?.defaultModel;
|
||||
|
||||
if (defaultModel && !state.model) {
|
||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||
state.model = {
|
||||
state.model = zMainModel.parse({
|
||||
id: defaultModel,
|
||||
name: model_name,
|
||||
type: model_type,
|
||||
base_model: base_model,
|
||||
};
|
||||
base_model,
|
||||
});
|
||||
}
|
||||
});
|
||||
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
||||
|
@ -126,35 +126,63 @@ export type HeightParam = z.infer<typeof zHeight>;
|
||||
export const isValidHeight = (val: unknown): val is HeightParam =>
|
||||
zHeight.safeParse(val).success;
|
||||
|
||||
const zBaseModel = z.enum(['sd-1', 'sd-2']);
|
||||
|
||||
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
|
||||
/**
|
||||
* Zod schema for model parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
const zModel = z.object({
|
||||
export const zMainModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
type: z.string(),
|
||||
base_model: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ModelParam = z.infer<typeof zModel> | null;
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zVAE = z.string();
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type VAEParam = z.infer<typeof zVAE>;
|
||||
export type MainModelParam = z.infer<typeof zMainModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidModel = (val: unknown): val is ModelParam =>
|
||||
zModel.safeParse(val).success;
|
||||
export const isValidMainModel = (val: unknown): val is MainModelParam =>
|
||||
zMainModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
*/
|
||||
export const zVaeModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type VaeModelParam = z.infer<typeof zVaeModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
|
||||
zVaeModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for LoRA
|
||||
*/
|
||||
export const zLoRAModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||
zLoRAModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
|
@ -6,9 +6,9 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { modelSelected } from '../../parameters/store/actions';
|
||||
|
||||
export const MODEL_TYPE_MAP = {
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
@ -63,6 +63,16 @@ const ModelSelect = () => {
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
// return early here to avoid resetting model selection before we've loaded the available models
|
||||
return;
|
||||
}
|
||||
|
||||
if (selectedModel && mainModels?.ids.includes(selectedModel?.id)) {
|
||||
// the selected model is an available model, no need to change it
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = mainModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
@ -70,7 +80,7 @@ const ModelSelect = () => {
|
||||
}
|
||||
|
||||
handleChangeModel(firstModel);
|
||||
}, [handleChangeModel, mainModels?.ids]);
|
||||
}, [handleChangeModel, isLoading, mainModels?.ids, selectedModel]);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSelect
|
||||
|
@ -9,9 +9,10 @@ import { forEach } from 'lodash-es';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { zVaeModel } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { MODEL_TYPE_MAP } from './ModelSelect';
|
||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||
|
||||
const VAESelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -34,8 +35,8 @@ const VAESelect = () => {
|
||||
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'auto',
|
||||
label: 'Automatic',
|
||||
value: 'default',
|
||||
label: 'Default',
|
||||
group: 'Default',
|
||||
},
|
||||
];
|
||||
@ -45,50 +46,65 @@ const VAESelect = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = currentMainModel?.base_model !== model.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
...(currentMainModel?.base_model !== model.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${model.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [vaeModels, currentMainModel?.base_model]);
|
||||
|
||||
const selectedVaeModel = useMemo(
|
||||
() => vaeModels?.entities[selectedVae],
|
||||
() => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null),
|
||||
[vaeModels?.entities, selectedVae]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
if (!v || v === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(v));
|
||||
|
||||
const [base_model, type, name] = v.split('/');
|
||||
|
||||
const model = zVaeModel.parse({
|
||||
id: v,
|
||||
name,
|
||||
base_model,
|
||||
});
|
||||
|
||||
dispatch(vaeSelected(model));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
|
||||
if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
|
||||
return;
|
||||
}
|
||||
handleChangeModel('auto');
|
||||
}, [handleChangeModel, vaeModels?.ids, selectedVae]);
|
||||
dispatch(vaeSelected(null));
|
||||
}, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
tooltip={selectedVaeModel?.description}
|
||||
label={t('modelManager.vae')}
|
||||
value={selectedVae}
|
||||
placeholder="Pick one"
|
||||
value={selectedVae?.id ?? 'default'}
|
||||
placeholder="Default"
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
disabled={data.length === 0}
|
||||
clearable
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user