feat(ui): fix a lot of model-related crashes/bugs

We were storing all types of models by their model ID, which is a format like `sd-1/main/deliberate`.

This meant we had to do a lot of extra parsing, because nodes actually wants something like `{base_model: 'sd-1', model_name: 'deliberate'}`.

Some of this parsing was done with zod's error-throwing `parse()` method, and in other places it was done with brittle string parsing.

This commit refactors the state to use the object form of models.

There is still a bit of string parsing done in the to construct the ID from the object form, but it's far less complicated.

Also, the zod parsing is now done using `safeParse()`, which does not throw. This requires a few more conditional checks, but should prevent further crashes.
This commit is contained in:
psychedelicious
2023-07-14 14:14:03 +10:00
parent 14587464d5
commit a071873327
34 changed files with 342 additions and 201 deletions

View File

@ -1,8 +1,8 @@
import { Box, Flex } from '@chakra-ui/react';
import ModelSelect from 'features/system/components/ModelSelect';
import VAESelect from 'features/system/components/VAESelect';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
import ParamScheduler from './ParamScheduler';
const ParamModelandVAEandScheduler = () => {
@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
return (
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
<Box w="full">
<ModelSelect />
<ParamMainModelSelect />
</Box>
<Flex gap={3} w="full">
{isVaeEnabled && (
<Box w="full">
<VAESelect />
<ParamVAEModelSelect />
</Box>
)}
<Box w="full">

View File

@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

View File

@ -0,0 +1,100 @@
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
stateSelector,
(state) => ({ model: state.generation.model }),
defaultSelectorOptions
);
const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const data = useMemo(() => {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [mainModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
null,
[mainModels?.entities, model]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(modelSelected(newModel));
},
[dispatch]
);
return isLoading ? (
<IAIMantineSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
);
};
export default memo(ParamMainModelSelect);

View File

@ -0,0 +1,109 @@
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { vaeSelected } from 'features/parameters/store/generationSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { model, vae } = generation;
return { model, vae };
},
defaultSelectorOptions
);
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { model, vae } = useAppSelector(selector);
const { data: vaeModels } = useGetVaeModelsQuery();
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [
{
value: 'default',
label: 'Default',
group: 'Default',
},
];
forEach(vaeModels.entities, (vae, id) => {
if (!vae) {
return;
}
const disabled = model?.base_model !== vae.base_model;
data.push({
value: id,
label: vae.model_name,
group: MODEL_TYPE_MAP[vae.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${vae.base_model}`
: undefined,
});
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels, model?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedVaeModel = useMemo(
() =>
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
[vaeModels?.entities, vae]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v || v === 'default') {
dispatch(vaeSelected(null));
return;
}
const newVaeModel = modelIdToVAEModelParam(v);
if (!newVaeModel) {
return;
}
dispatch(vaeSelected(newVaeModel));
},
[dispatch]
);
return (
<IAIMantineSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')}
value={selectedVaeModel?.id ?? 'default'}
placeholder="Default"
data={data}
onChange={handleChangeModel}
disabled={data.length === 0}
clearable
/>
);
};
export default memo(ParamVAEModelSelect);

View File

@ -28,7 +28,7 @@ import {
isValidSteps,
isValidStrength,
isValidWidth,
} from '../store/parameterZodSchemas';
} from '../types/parameterSchemas';
export const useRecallParameters = () => {
const dispatch = useAppDispatch();

View File

@ -13,6 +13,7 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
CfgScaleParam,
HeightParam,
MainModelParam,
NegativePromptParam,
PositivePromptParam,
SchedulerParam,
@ -22,7 +23,7 @@ import {
VaeModelParam,
WidthParam,
zMainModel,
} from './parameterZodSchemas';
} from '../types/parameterSchemas';
export interface GenerationState {
cfgScale: CfgScaleParam;
@ -226,18 +227,19 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
},
modelChanged: (state, action: PayloadAction<MainModelField | null>) => {
if (!action.payload) {
state.model = null;
}
modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
state.model = action.payload;
state.model = zMainModel.parse(action.payload);
if (state.model === null) {
return;
}
// Clamp ClipSkip Based On Selected Model
const { maxClip } = clipSkipMap[state.model.base_model];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
},
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
// null is a valid VAE!
state.vae = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
@ -253,11 +255,15 @@ export const generationSlice = createSlice({
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = zMainModel.parse({
id: defaultModel,
name: model_name,
const result = zMainModel.safeParse({
model_name,
base_model,
});
if (result.success) {
state.model = result.data;
}
}
});
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -0,0 +1,4 @@
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};

View File

@ -135,7 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
* TODO: Make this a dynamically generated enum?
*/
export const zMainModel = z.object({
model_name: z.string(),
model_name: z.string().min(1),
base_model: zBaseModel,
});
@ -152,8 +152,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam =>
* Zod schema for VAE parameter
*/
export const zVaeModel = z.object({
id: z.string(),
name: z.string(),
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
@ -169,8 +168,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
* Zod schema for LoRA
*/
export const zLoRAModel = z.object({
id: z.string(),
model_name: z.string(),
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**

View File

@ -0,0 +1,18 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
export const modelIdToLoRAModelParam = (
loraId: string
): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/');
const result = zLoRAModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
return;
}
return result.data;
};

View File

@ -0,0 +1,21 @@
import {
MainModelParam,
zMainModel,
} from 'features/parameters/types/parameterSchemas';
export const modelIdToMainModelParam = (
modelId: string
): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const result = zMainModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
return;
}
return result.data;
};

View File

@ -0,0 +1,18 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
export const modelIdToVAEModelParam = (
modelId: string
): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const result = zVaeModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
return;
}
return result.data;
};