mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix a lot of model-related crashes/bugs
We were storing all types of models by their model ID, which is a format like `sd-1/main/deliberate`. This meant we had to do a lot of extra parsing, because nodes actually wants something like `{base_model: 'sd-1', model_name: 'deliberate'}`. Some of this parsing was done with zod's error-throwing `parse()` method, and in other places it was done with brittle string parsing. This commit refactors the state to use the object form of models. There is still a bit of string parsing done in the to construct the ID from the object form, but it's far less complicated. Also, the zod parsing is now done using `safeParse()`, which does not throw. This requires a few more conditional checks, but should prevent further crashes.
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
import VAESelect from 'features/system/components/VAESelect';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
||||
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
||||
import ParamScheduler from './ParamScheduler';
|
||||
|
||||
const ParamModelandVAEandScheduler = () => {
|
||||
@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
|
||||
return (
|
||||
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
|
||||
<Box w="full">
|
||||
<ModelSelect />
|
||||
<ParamMainModelSelect />
|
||||
</Box>
|
||||
<Flex gap={3} w="full">
|
||||
{isVaeEnabled && (
|
||||
<Box w="full">
|
||||
<VAESelect />
|
||||
<ParamVAEModelSelect />
|
||||
</Box>
|
||||
)}
|
||||
<Box w="full">
|
||||
|
@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
@ -0,0 +1,100 @@
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => ({ model: state.generation.model }),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [mainModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
||||
null,
|
||||
[mainModels?.entities, model]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelSelected(newModel));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamMainModelSelect);
|
@ -0,0 +1,109 @@
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ generation }) => {
|
||||
const { model, vae } = generation;
|
||||
return { model, vae };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { model, vae } = useAppSelector(selector);
|
||||
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!vaeModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'default',
|
||||
label: 'Default',
|
||||
group: 'Default',
|
||||
},
|
||||
];
|
||||
|
||||
forEach(vaeModels.entities, (vae, id) => {
|
||||
if (!vae) {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = model?.base_model !== vae.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: vae.model_name,
|
||||
group: MODEL_TYPE_MAP[vae.base_model],
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${vae.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [vaeModels, model?.base_model]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedVaeModel = useMemo(
|
||||
() =>
|
||||
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
|
||||
[vaeModels?.entities, vae]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v || v === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const newVaeModel = modelIdToVAEModelParam(v);
|
||||
|
||||
if (!newVaeModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(newVaeModel));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
tooltip={selectedVaeModel?.description}
|
||||
label={t('modelManager.vae')}
|
||||
value={selectedVaeModel?.id ?? 'default'}
|
||||
placeholder="Default"
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
disabled={data.length === 0}
|
||||
clearable
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamVAEModelSelect);
|
Reference in New Issue
Block a user