feat(ui): finalize base model compatibility for lora, ti, vae

This commit is contained in:
psychedelicious 2023-07-07 21:23:03 +10:00
parent a9a4081f51
commit 8457fcf7d3
13 changed files with 187 additions and 113 deletions

View File

@ -1,11 +1,12 @@
import { makeToast } from 'app/components/Toaster';
import { modelSelected } from 'features/parameters/store/actions';
import {
modelChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { modelSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
export const addModelSelectedListener = () => {
@ -24,12 +25,18 @@ export const addModelSelectedListener = () => {
})
)
);
dispatch(vaeSelected('auto'));
dispatch(vaeSelected(null));
dispatch(lorasCleared());
// TODO: controlnet cleared
}
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
const newModel = zMainModel.parse({
id: action.payload,
base_model,
name,
});
dispatch(modelChanged(newModel));
},
});
};

View File

@ -66,6 +66,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
@ -108,6 +109,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,

View File

@ -67,6 +67,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
@ -109,6 +110,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,

View File

@ -1,37 +1,28 @@
import { Tooltip, Text } from '@mantine/core';
import { Box, Tooltip } from '@chakra-ui/react';
import { Text } from '@mantine/core';
import { forwardRef, memo } from 'react';
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
label: string;
description?: string;
tooltip?: string;
disabled?: boolean;
}
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, ...others }: ItemProps, ref) => (
<div ref={ref} {...others}>
{tooltip ? (
<Tooltip.Floating label={tooltip}>
<div>
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
<Tooltip label={tooltip} placement="top" hasArrow>
<Box ref={ref} {...others}>
<Box>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</Tooltip.Floating>
) : (
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
)}
</div>
</Box>
</Box>
</Tooltip>
)
);

View File

@ -6,20 +6,16 @@ import {
PopoverTrigger,
Text,
} from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { forEach } from 'lodash-es';
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import { RootState } from '../../../app/store/store';
import { useAppSelector } from '../../../app/store/storeHooks';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type EmbeddingSelectItem = {
label: string;
value: string;
description?: string;
};
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
@ -41,22 +37,27 @@ const ParamEmbeddingPopover = (props: Props) => {
return [];
}
const data: EmbeddingSelectItem[] = [];
const data: SelectItem[] = [];
forEach(embeddingQueryData.entities, (embedding, _) => {
if (!embedding) return;
if (!embedding) {
return;
}
const disabled = currentMainModel?.base_model !== embedding.base_model;
data.push({
value: embedding.name,
label: embedding.name,
description: embedding.description,
...(currentMainModel?.base_model !== embedding.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
group: MODEL_TYPE_MAP[embedding.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${embedding.base_model}`
: undefined,
});
});
return data;
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [embeddingQueryData, currentMainModel?.base_model]);
const handleChange = useCallback(
@ -114,8 +115,10 @@ const ParamEmbeddingPopover = (props: Props) => {
nothingFound="No Matching Embeddings"
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
filter={(value, selected, item: SelectItem) =>
item.label
?.toLowerCase()
.includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}

View File

@ -1,20 +1,16 @@
import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { forEach } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type LoraSelectItem = {
label: string;
value: string;
description?: string;
};
const selector = createSelector(
stateSelector,
@ -38,24 +34,27 @@ const ParamLoraSelect = () => {
return [];
}
const data: LoraSelectItem[] = [];
const data: SelectItem[] = [];
forEach(lorasQueryData.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) {
return;
}
const disabled = currentMainModel?.base_model !== lora.base_model;
data.push({
value: id,
label: lora.name,
description: lora.description,
...(currentMainModel?.base_model !== lora.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
disabled,
group: MODEL_TYPE_MAP[lora.base_model],
tooltip: disabled
? `Incompatible base model: ${lora.base_model}`
: undefined,
});
});
return data;
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loras, lorasQueryData, currentMainModel?.base_model]);
const handleChange = useCallback(
@ -88,8 +87,8 @@ const ParamLoraSelect = () => {
nothingFound="No matching LoRAs"
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
filter={(value, selected, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}

View File

@ -1,18 +1,21 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import { BaseModelType } from 'services/api/types';
export type Lora = {
id: string;
base_model: BaseModelType;
name: string;
weight: number;
};
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
export const defaultLoRAConfig = {
weight: 0.75,
};
export type LoraState = {
loras: Record<string, Lora>;
loras: Record<string, LoRAModelParam & { weight: number }>;
};
export const intialLoraState: LoraState = {
@ -24,14 +27,14 @@ export const loraSlice = createSlice({
initialState: intialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id } = action.payload;
state.loras[id] = { id, name, ...defaultLoRAConfig };
const { name, id, base_model } = action.payload;
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;
delete state.loras[id];
},
lorasCleared: (state, action: PayloadAction<>) => {
lorasCleared: (state) => {
state.loras = {};
},
loraWeightChanged: (

View File

@ -19,7 +19,7 @@ export const addVAEToGraph = (
const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = vae?.id === 'auto';
const isAutoVae = !vae;
if (!isAutoVae) {
graph.nodes[VAE_LOADER] = {

View File

@ -1,6 +1,10 @@
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { isImageField } from 'services/api/guards';
import { ImageDTO } from 'services/api/types';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
setCfgScale,
setHeight,
@ -12,14 +16,10 @@ import {
setSteps,
setWidth,
} from '../store/generationSlice';
import { isImageField } from 'services/api/guards';
import { initialImageSelected, modelSelected } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types';
import {
isValidCfgScale,
isValidHeight,
isValidModel,
isValidMainModel,
isValidNegativePrompt,
isValidPositivePrompt,
isValidScheduler,
@ -158,7 +158,7 @@ export const useRecallParameters = () => {
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isValidModel(model)) {
if (!isValidMainModel(model)) {
parameterNotSetToast();
return;
}
@ -295,7 +295,7 @@ export const useRecallParameters = () => {
if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
if (isValidModel(model)) {
if (isValidMainModel(model)) {
dispatch(modelSelected(model));
}
if (isValidPositivePrompt(positive_conditioning)) {

View File

@ -9,14 +9,16 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
CfgScaleParam,
HeightParam,
ModelParam,
MainModelParam,
NegativePromptParam,
PositivePromptParam,
SchedulerParam,
SeedParam,
StepsParam,
StrengthParam,
VaeModelParam,
WidthParam,
zMainModel,
} from './parameterZodSchemas';
export interface GenerationState {
@ -48,8 +50,8 @@ export interface GenerationState {
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: ModelParam;
vae: VAEParam;
model: MainModelParam | null;
vae: VaeModelParam | null;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
clipSkip: number;
@ -84,7 +86,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0,
model: null,
vae: '',
vae: null,
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
@ -221,12 +223,17 @@ export const generationSlice = createSlice({
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
state.model = { id: action.payload, base_model, name, type };
state.model = zMainModel.parse({
id: action.payload,
base_model,
name,
type,
});
},
modelChanged: (state, action: PayloadAction<ModelParam>) => {
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<string>) => {
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
state.vae = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
@ -236,14 +243,14 @@ export const generationSlice = createSlice({
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = {
state.model = zMainModel.parse({
id: defaultModel,
name: model_name,
type: model_type,
base_model: base_model,
};
base_model,
});
}
});
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -126,35 +126,63 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
export type BaseModelParam = z.infer<typeof zBaseModel>;
/**
* Zod schema for model parameter
* TODO: Make this a dynamically generated enum?
*/
const zModel = z.object({
export const zMainModel = z.object({
id: z.string(),
name: z.string(),
type: z.string(),
base_model: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ModelParam = z.infer<typeof zModel> | null;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
export type MainModelParam = z.infer<typeof zMainModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidModel = (val: unknown): val is ModelParam =>
zModel.safeParse(val).success;
export const isValidMainModel = (val: unknown): val is MainModelParam =>
zMainModel.safeParse(val).success;
/**
* Zod schema for VAE parameter
*/
export const zVaeModel = z.object({
id: z.string(),
name: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VaeModelParam = z.infer<typeof zVaeModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
zVaeModel.safeParse(val).success;
/**
* Zod schema for LoRA
*/
export const zLoRAModel = z.object({
id: z.string(),
name: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type LoRAModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success;
/**
* Zod schema for l2l strength parameter

View File

@ -6,9 +6,9 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/actions';
import { forEach, isString } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { modelSelected } from '../../parameters/store/actions';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
@ -63,6 +63,16 @@ const ModelSelect = () => {
);
useEffect(() => {
if (isLoading) {
// return early here to avoid resetting model selection before we've loaded the available models
return;
}
if (selectedModel && mainModels?.ids.includes(selectedModel?.id)) {
// the selected model is an available model, no need to change it
return;
}
const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) {
@ -70,7 +80,7 @@ const ModelSelect = () => {
}
handleChangeModel(firstModel);
}, [handleChangeModel, mainModels?.ids]);
}, [handleChangeModel, isLoading, mainModels?.ids, selectedModel]);
return isLoading ? (
<IAIMantineSelect

View File

@ -9,9 +9,10 @@ import { forEach } from 'lodash-es';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { vaeSelected } from 'features/parameters/store/generationSlice';
import { zVaeModel } from 'features/parameters/store/parameterZodSchemas';
import { MODEL_TYPE_MAP } from './ModelSelect';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
const VAESelect = () => {
const dispatch = useAppDispatch();
@ -34,8 +35,8 @@ const VAESelect = () => {
const data: SelectItem[] = [
{
value: 'auto',
label: 'Automatic',
value: 'default',
label: 'Default',
group: 'Default',
},
];
@ -45,50 +46,65 @@ const VAESelect = () => {
return;
}
const disabled = currentMainModel?.base_model !== model.base_model;
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
...(currentMainModel?.base_model !== model.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels, currentMainModel?.base_model]);
const selectedVaeModel = useMemo(
() => vaeModels?.entities[selectedVae],
() => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null),
[vaeModels?.entities, selectedVae]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
if (!v || v === 'default') {
dispatch(vaeSelected(null));
return;
}
dispatch(vaeSelected(v));
const [base_model, type, name] = v.split('/');
const model = zVaeModel.parse({
id: v,
name,
base_model,
});
dispatch(vaeSelected(model));
},
[dispatch]
);
useEffect(() => {
if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
return;
}
handleChangeModel('auto');
}, [handleChangeModel, vaeModels?.ids, selectedVae]);
dispatch(vaeSelected(null));
}, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
return (
<IAIMantineSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')}
value={selectedVae}
placeholder="Pick one"
value={selectedVae?.id ?? 'default'}
placeholder="Default"
data={data}
onChange={handleChangeModel}
disabled={data.length === 0}
clearable
/>
);
};