feat(ui): fix a lot of model-related crashes/bugs

We were storing all types of models by their model ID, which is a format like `sd-1/main/deliberate`.

This meant we had to do a lot of extra parsing, because nodes actually wants something like `{base_model: 'sd-1', model_name: 'deliberate'}`.

Some of this parsing was done with zod's error-throwing `parse()` method, and in other places it was done with brittle string parsing.

This commit refactors the state to use the object form of models.

There is still a bit of string parsing done in the to construct the ID from the object form, but it's far less complicated.

Also, the zod parsing is now done using `safeParse()`, which does not throw. This requires a few more conditional checks, but should prevent further crashes.
This commit is contained in:
psychedelicious 2023-07-14 14:14:03 +10:00
parent 14587464d5
commit a071873327
34 changed files with 342 additions and 201 deletions

View File

@ -1,6 +1,7 @@
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
// zod needs the array to be `as const` to infer the type correctly // zod needs the array to be `as const` to infer the type correctly
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
// this is the source of the `SchedulerParam` type, which is generated by zod // this is the source of the `SchedulerParam` type, which is generated by zod
export const SCHEDULER_NAMES_AS_CONST = [ export const SCHEDULER_NAMES_AS_CONST = [
'euler', 'euler',

View File

@ -1,36 +1,70 @@
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { log } from 'app/logging/useLogger';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { import {
modelChanged, modelChanged,
vaeSelected, vaeSelected,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { zMainModel } from 'features/parameters/store/parameterZodSchemas'; import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
const moduleLog = log.child({ module: 'models' });
export const addModelSelectedListener = () => { export const addModelSelectedListener = () => {
startAppListening({ startAppListening({
actionCreator: modelSelected, actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const { base_model, model_name } = action.payload; const result = zMainModel.safeParse(action.payload);
if (!result.success) {
moduleLog.error(
{ error: result.error.format() },
'Failed to parse main model'
);
return;
}
const newModel = result.data;
const { base_model } = newModel;
if (state.generation.model?.base_model !== base_model) { if (state.generation.model?.base_model !== base_model) {
// we may need to reset some incompatible submodels
let modelsCleared = 0;
// handle incompatible loras
forEach(state.lora.loras, (lora, id) => {
if (lora.base_model !== base_model) {
dispatch(loraRemoved(id));
modelsCleared += 1;
}
});
// handle incompatible vae
const { vae } = state.generation;
if (vae && vae.base_model !== base_model) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// TODO: handle incompatible controlnet; pending model manager support
if (modelsCleared > 0) {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: 'Base model changed, clearing submodels', title: `Base model changed, cleared ${modelsCleared} incompatible submodel${
modelsCleared === 1 ? '' : 's'
}`,
status: 'warning', status: 'warning',
}) })
) )
); );
dispatch(vaeSelected(null));
dispatch(lorasCleared());
// TODO: controlnet cleared
} }
}
const newModel = zMainModel.parse(action.payload);
dispatch(modelChanged(newModel)); dispatch(modelChanged(newModel));
}, },

View File

@ -1,8 +1,19 @@
import { modelChanged } from 'features/parameters/store/generationSlice'; import { log } from 'app/logging/useLogger';
import { some } from 'lodash-es'; import { loraRemoved } from 'features/lora/store/loraSlice';
import {
modelChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
zMainModel,
zVaeModel,
} from 'features/parameters/types/parameterSchemas';
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
const moduleLog = log.child({ module: 'models' });
export const addModelsLoadedListener = () => { export const addModelsLoadedListener = () => {
startAppListening({ startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled, matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
@ -31,12 +42,92 @@ export const addModelsLoadedListener = () => {
return; return;
} }
dispatch( const result = zMainModel.safeParse(firstModel);
modelChanged({
base_model: firstModel.base_model, if (!result.success) {
model_name: firstModel.model_name, moduleLog.error(
}) { error: result.error.format() },
'Failed to parse main model'
); );
return;
}
dispatch(modelChanged(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// VAEs loaded, need to reset the VAE is it's no longer available
const currentVae = getState().generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const isCurrentVAEAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentVae?.model_name &&
m?.base_model === currentVae?.base_model
);
if (isCurrentVAEAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(modelChanged(null));
return;
}
const result = zVaeModel.safeParse(firstModel);
if (!result.success) {
moduleLog.error(
{ error: result.error.format() },
'Failed to parse VAE model'
);
return;
}
dispatch(vaeSelected(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// LoRA models loaded - need to remove missing LoRAs from state
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === lora?.model_name &&
m?.base_model === lora?.base_model
);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support
}, },
}); });
}; };

View File

@ -11,7 +11,7 @@ import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react'; import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';

View File

@ -5,14 +5,14 @@ import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa'; import { FaTrash } from 'react-icons/fa';
import { import {
Lora, LoRA,
loraRemoved, loraRemoved,
loraWeightChanged, loraWeightChanged,
loraWeightReset, loraWeightReset,
} from '../store/loraSlice'; } from '../store/loraSlice';
type Props = { type Props = {
lora: Lora; lora: LoRA;
}; };
const ParamLora = (props: Props) => { const ParamLora = (props: Props) => {

View File

@ -6,9 +6,9 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { memo } from 'react'; import { memo } from 'react';
import ParamLoraList from './ParamLoraList';
import ParamLoraSelect from './ParamLoraSelect';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus'; import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import ParamLoraList from './ParamLoraList';
import ParamLoRASelect from './ParamLoraSelect';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
@ -33,7 +33,7 @@ const ParamLoraCollapse = () => {
return ( return (
<IAICollapse label={'LoRA'} activeLabel={activeLabel}> <IAICollapse label={'LoRA'} activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}> <Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoraSelect /> <ParamLoRASelect />
<ParamLoraList /> <ParamLoraList />
</Flex> </Flex>
</IAICollapse> </IAICollapse>

View File

@ -7,7 +7,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice'; import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
@ -20,23 +20,23 @@ const selector = createSelector(
defaultSelectorOptions defaultSelectorOptions
); );
const ParamLoraSelect = () => { const ParamLoRASelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { loras } = useAppSelector(selector); const { loras } = useAppSelector(selector);
const { data: lorasQueryData } = useGetLoRAModelsQuery(); const { data: loraModels } = useGetLoRAModelsQuery();
const currentMainModel = useAppSelector( const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model (state: RootState) => state.generation.model
); );
const data = useMemo(() => { const data = useMemo(() => {
if (!lorasQueryData) { if (!loraModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(lorasQueryData.entities, (lora, id) => { forEach(loraModels.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) { if (!lora || Boolean(id in loras)) {
return; return;
} }
@ -55,23 +55,25 @@ const ParamLoraSelect = () => {
}); });
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loras, lorasQueryData, currentMainModel?.base_model]); }, [loras, loraModels, currentMainModel?.base_model]);
const handleChange = useCallback( const handleChange = useCallback(
(v: string | null | undefined) => { (v: string | null | undefined) => {
if (!v) { if (!v) {
return; return;
} }
const loraEntity = lorasQueryData?.entities[v]; const loraEntity = loraModels?.entities[v];
if (!loraEntity) { if (!loraEntity) {
return; return;
} }
dispatch(loraAdded(loraEntity)); dispatch(loraAdded(loraEntity));
}, },
[dispatch, lorasQueryData?.entities] [dispatch, loraModels?.entities]
); );
if (lorasQueryData?.ids.length === 0) { if (loraModels?.ids.length === 0) {
return ( return (
<Flex sx={{ justifyContent: 'center', p: 2 }}> <Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}> <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
@ -98,4 +100,4 @@ const ParamLoraSelect = () => {
); );
}; };
export default ParamLoraSelect; export default ParamLoRASelect;

View File

@ -1,8 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas'; import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type Lora = LoRAModelParam & { export type LoRA = LoRAModelParam & {
weight: number; weight: number;
}; };
@ -11,7 +11,7 @@ export const defaultLoRAConfig = {
}; };
export type LoraState = { export type LoraState = {
loras: Record<string, LoRAModelParam & { weight: number }>; loras: Record<string, LoRA>;
}; };
export const intialLoraState: LoraState = { export const intialLoraState: LoraState = {
@ -24,7 +24,7 @@ export const loraSlice = createSlice({
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { model_name, id, base_model } = action.payload; const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; state.loras[id] = { model_name, base_model, ...defaultLoRAConfig };
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload; const id = action.payload;

View File

@ -6,7 +6,7 @@ import {
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
VaeModelInputFieldValue, VaeModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es'; import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -46,7 +46,7 @@ const LoRAModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.model_name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
}); });
}); });
@ -88,8 +88,7 @@ const LoRAModelInputFieldComponent = (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
} }
value={field.value} value={field.value}
placeholder="Pick one" placeholder="Pick one"

View File

@ -7,7 +7,7 @@ import {
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es'; import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -39,7 +39,7 @@ const ModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.model_name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
}); });
}); });
@ -86,8 +86,7 @@ const ModelInputFieldComponent = (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
} }
value={field.value} value={field.value}
placeholder="Pick one" placeholder="Pick one"

View File

@ -6,7 +6,7 @@ import {
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
VaeModelInputFieldValue, VaeModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -46,7 +46,7 @@ const VaeModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.model_name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
}); });
}); });
@ -81,8 +81,7 @@ const VaeModelInputFieldComponent = (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
} }
value={field.value} value={field.value}
placeholder="Pick one" placeholder="Pick one"

View File

@ -5,7 +5,6 @@ import {
LoraLoaderInvocation, LoraLoaderInvocation,
MetadataAccumulatorInvocation, MetadataAccumulatorInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { import {
CLIP_SKIP, CLIP_SKIP,
LORA_LOADER, LORA_LOADER,
@ -55,23 +54,22 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0; let currentLoraIndex = 0;
forEach(loras, (lora) => { forEach(loras, (lora) => {
const { id, name, weight } = lora; const { model_name, base_model, weight } = lora;
const loraField = modelIdToLoRAModelField(id); const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
'.',
'_'
)}`;
const loraLoaderNode: LoraLoaderInvocation = { const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader', type: 'lora_loader',
id: currentLoraNodeId, id: currentLoraNodeId,
lora: loraField, lora,
weight, weight,
}; };
// add the lora to the metadata accumulator // add the lora to the metadata accumulator
if (metadataAccumulator) { if (metadataAccumulator) {
metadataAccumulator.loras.push({ lora: loraField, weight }); metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
});
} }
// add to graph // add to graph

View File

@ -1,7 +1,6 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { MetadataAccumulatorInvocation } from 'services/api/types'; import { MetadataAccumulatorInvocation } from 'services/api/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import { import {
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS, IMAGE_TO_LATENTS,
@ -19,7 +18,6 @@ export const addVAEToGraph = (
graph: NonNullableGraph graph: NonNullableGraph
): void => { ): void => {
const { vae } = state.generation; const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae; const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
@ -30,7 +28,7 @@ export const addVAEToGraph = (
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {
type: 'vae_loader', type: 'vae_loader',
id: VAE_LOADER, id: VAE_LOADER,
vae_model, vae_model: vae,
}; };
} }
@ -74,6 +72,6 @@ export const addVAEToGraph = (
} }
if (vae && metadataAccumulator) { if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae_model; metadataAccumulator.vae = vae;
} }
}; };

View File

@ -1,12 +1,12 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { cloneDeep, omit, reduce } from 'lodash-es'; import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types'; import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
/** /**
* We need to do special handling for some fields * We need to do special handling for some fields
@ -29,19 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') { if (field.type === 'model') {
if (field.value) { if (field.value) {
return modelIdToMainModelField(field.value); return modelIdToMainModelParam(field.value);
} }
} }
if (field.type === 'vae_model') { if (field.type === 'vae_model') {
if (field.value) { if (field.value) {
return modelIdToVAEModelField(field.value); return modelIdToVAEModelParam(field.value);
} }
} }
if (field.type === 'lora_model') { if (field.type === 'lora_model') {
if (field.value) { if (field.value) {
return modelIdToLoRAModelField(field.value); return modelIdToLoRAModelParam(field.value);
} }
} }

View File

@ -1,12 +0,0 @@
import { BaseModelType, LoRAModelField } from 'services/api/types';
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
const [base_model, model_type, model_name] = loraId.split('/');
const field: LoRAModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,16 +0,0 @@
import { BaseModelType, MainModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,16 +0,0 @@
import { BaseModelType, VAEModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: VAEModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,8 +1,8 @@
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import ModelSelect from 'features/system/components/ModelSelect'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import VAESelect from 'features/system/components/VAESelect';
import { memo } from 'react'; import { memo } from 'react';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
import ParamScheduler from './ParamScheduler'; import ParamScheduler from './ParamScheduler';
const ParamModelandVAEandScheduler = () => { const ParamModelandVAEandScheduler = () => {
@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
return ( return (
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}> <Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
<Box w="full"> <Box w="full">
<ModelSelect /> <ParamMainModelSelect />
</Box> </Box>
<Flex gap={3} w="full"> <Flex gap={3} w="full">
{isVaeEnabled && ( {isVaeEnabled && (
<Box w="full"> <Box w="full">
<VAESelect /> <ParamVAEModelSelect />
</Box> </Box>
)} )}
<Box w="full"> <Box w="full">

View File

@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';

View File

@ -8,27 +8,23 @@ import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => ({ currentModel: state.generation.model }), (state) => ({ model: state.generation.model }),
defaultSelectorOptions defaultSelectorOptions
); );
const ModelSelect = () => { const ParamMainModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { currentModel } = useAppSelector(selector); const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery(); const { data: mainModels, isLoading } = useGetMainModelsQuery();
@ -54,12 +50,13 @@ const ModelSelect = () => {
return data; return data;
}, [mainModels]); }, [mainModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo( const selectedModel = useMemo(
() => () =>
mainModels?.entities[ mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
`${currentModel?.base_model}/main/${currentModel?.model_name}` null,
], [mainModels?.entities, model]
[mainModels?.entities, currentModel]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -68,8 +65,13 @@ const ModelSelect = () => {
return; return;
} }
const modelField = modelIdToMainModelField(v); const newModel = modelIdToMainModelParam(v);
dispatch(modelSelected(modelField));
if (!newModel) {
return;
}
dispatch(modelSelected(newModel));
}, },
[dispatch] [dispatch]
); );
@ -95,4 +97,4 @@ const ModelSelect = () => {
); );
}; };
export default memo(ModelSelect); export default memo(ParamMainModelSelect);

View File

@ -1,4 +1,4 @@
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -8,26 +8,30 @@ import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { vaeSelected } from 'features/parameters/store/generationSlice'; import { vaeSelected } from 'features/parameters/store/generationSlice';
import { zVaeModel } from 'features/parameters/store/parameterZodSchemas'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { MODEL_TYPE_MAP } from './ModelSelect'; import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
const VAESelect = () => { const selector = createSelector(
stateSelector,
({ generation }) => {
const { model, vae } = generation;
return { model, vae };
},
defaultSelectorOptions
);
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { model, vae } = useAppSelector(selector);
const { data: vaeModels } = useGetVaeModelsQuery(); const { data: vaeModels } = useGetVaeModelsQuery();
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const selectedVae = useAppSelector(
(state: RootState) => state.generation.vae
);
const data = useMemo(() => { const data = useMemo(() => {
if (!vaeModels) { if (!vaeModels) {
return []; return [];
@ -41,30 +45,32 @@ const VAESelect = () => {
}, },
]; ];
forEach(vaeModels.entities, (model, id) => { forEach(vaeModels.entities, (vae, id) => {
if (!model) { if (!vae) {
return; return;
} }
const disabled = currentMainModel?.base_model !== model.base_model; const disabled = model?.base_model !== vae.base_model;
data.push({ data.push({
value: id, value: id,
label: model.model_name, label: vae.model_name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[vae.base_model],
disabled, disabled,
tooltip: disabled tooltip: disabled
? `Incompatible base model: ${model.base_model}` ? `Incompatible base model: ${vae.base_model}`
: undefined, : undefined,
}); });
}); });
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels, currentMainModel?.base_model]); }, [vaeModels, model?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedVaeModel = useMemo( const selectedVaeModel = useMemo(
() => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null), () =>
[vaeModels?.entities, selectedVae] vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
[vaeModels?.entities, vae]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -74,32 +80,23 @@ const VAESelect = () => {
return; return;
} }
const [base_model, type, name] = v.split('/'); const newVaeModel = modelIdToVAEModelParam(v);
const model = zVaeModel.parse({ if (!newVaeModel) {
id: v, return;
name, }
base_model,
});
dispatch(vaeSelected(model)); dispatch(vaeSelected(newVaeModel));
}, },
[dispatch] [dispatch]
); );
useEffect(() => {
if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
return;
}
dispatch(vaeSelected(null));
}, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')} label={t('modelManager.vae')}
value={selectedVae?.id ?? 'default'} value={selectedVaeModel?.id ?? 'default'}
placeholder="Default" placeholder="Default"
data={data} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}
@ -109,4 +106,4 @@ const VAESelect = () => {
); );
}; };
export default memo(VAESelect); export default memo(ParamVAEModelSelect);

View File

@ -28,7 +28,7 @@ import {
isValidSteps, isValidSteps,
isValidStrength, isValidStrength,
isValidWidth, isValidWidth,
} from '../store/parameterZodSchemas'; } from '../types/parameterSchemas';
export const useRecallParameters = () => { export const useRecallParameters = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -13,6 +13,7 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
MainModelParam,
NegativePromptParam, NegativePromptParam,
PositivePromptParam, PositivePromptParam,
SchedulerParam, SchedulerParam,
@ -22,7 +23,7 @@ import {
VaeModelParam, VaeModelParam,
WidthParam, WidthParam,
zMainModel, zMainModel,
} from './parameterZodSchemas'; } from '../types/parameterSchemas';
export interface GenerationState { export interface GenerationState {
cfgScale: CfgScaleParam; cfgScale: CfgScaleParam;
@ -226,18 +227,19 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload; const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height }; state.initialImage = { imageName: image_name, width, height };
}, },
modelChanged: (state, action: PayloadAction<MainModelField | null>) => { modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
if (!action.payload) { state.model = action.payload;
state.model = null;
}
state.model = zMainModel.parse(action.payload); if (state.model === null) {
return;
}
// Clamp ClipSkip Based On Selected Model // Clamp ClipSkip Based On Selected Model
const { maxClip } = clipSkipMap[state.model.base_model]; const { maxClip } = clipSkipMap[state.model.base_model];
state.clipSkip = clamp(state.clipSkip, 0, maxClip); state.clipSkip = clamp(state.clipSkip, 0, maxClip);
}, },
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => { vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
// null is a valid VAE!
state.vae = action.payload; state.vae = action.payload;
}, },
setClipSkip: (state, action: PayloadAction<number>) => { setClipSkip: (state, action: PayloadAction<number>) => {
@ -253,11 +255,15 @@ export const generationSlice = createSlice({
if (defaultModel && !state.model) { if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/'); const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = zMainModel.parse({
id: defaultModel, const result = zMainModel.safeParse({
name: model_name, model_name,
base_model, base_model,
}); });
if (result.success) {
state.model = result.data;
}
} }
}); });
builder.addCase(setShouldShowAdvancedOptions, (state, action) => { builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -0,0 +1,4 @@
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};

View File

@ -135,7 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
* TODO: Make this a dynamically generated enum? * TODO: Make this a dynamically generated enum?
*/ */
export const zMainModel = z.object({ export const zMainModel = z.object({
model_name: z.string(), model_name: z.string().min(1),
base_model: zBaseModel, base_model: zBaseModel,
}); });
@ -152,8 +152,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam =>
* Zod schema for VAE parameter * Zod schema for VAE parameter
*/ */
export const zVaeModel = z.object({ export const zVaeModel = z.object({
id: z.string(), model_name: z.string().min(1),
name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
}); });
/** /**
@ -169,8 +168,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
* Zod schema for LoRA * Zod schema for LoRA
*/ */
export const zLoRAModel = z.object({ export const zLoRAModel = z.object({
id: z.string(), model_name: z.string().min(1),
model_name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
}); });
/** /**

View File

@ -0,0 +1,18 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
export const modelIdToLoRAModelParam = (
loraId: string
): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/');
const result = zLoRAModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
return;
}
return result.data;
};

View File

@ -0,0 +1,21 @@
import {
MainModelParam,
zMainModel,
} from 'features/parameters/types/parameterSchemas';
export const modelIdToMainModelParam = (
modelId: string
): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const result = zMainModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
return;
}
return result.data;
};

View File

@ -0,0 +1,18 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
export const modelIdToVAEModelParam = (
modelId: string
): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const result = zVaeModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
return;
}
return result.data;
};

View File

@ -2,7 +2,7 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice'; import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useCallback } from 'react'; import { useCallback } from 'react';

View File

@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types'; import { S } from 'services/api/types';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';

View File

@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types'; import { S } from 'services/api/types';
type DiffusersModel = type DiffusersModel =

View File

@ -1,7 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { setActiveTabReducer } from './extraReducers'; import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap'; import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes'; import { AddNewModelType, UIState } from './uiTypes';

View File

@ -1,4 +1,4 @@
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
export type AddNewModelType = 'ckpt' | 'diffusers' | null; export type AddNewModelType = 'ckpt' | 'diffusers' | null;

View File

@ -2,8 +2,8 @@ export { default as InvokeAIUI } from './app/components/InvokeAIUI';
export type { PartialAppConfig } from './app/types/invokeai'; export type { PartialAppConfig } from './app/types/invokeai';
export { default as IAIIconButton } from './common/components/IAIIconButton'; export { default as IAIIconButton } from './common/components/IAIIconButton';
export { default as IAIPopover } from './common/components/IAIPopover'; export { default as IAIPopover } from './common/components/IAIPopover';
export { default as ParamMainModelSelect } from './features/parameters/components/Parameters/MainModel/ParamMainModelSelect';
export { default as ColorModeButton } from './features/system/components/ColorModeButton';
export { default as InvokeAiLogoComponent } from './features/system/components/InvokeAILogoComponent'; export { default as InvokeAiLogoComponent } from './features/system/components/InvokeAILogoComponent';
export { default as ModelSelect } from './features/system/components/ModelSelect';
export { default as SettingsModal } from './features/system/components/SettingsModal/SettingsModal'; export { default as SettingsModal } from './features/system/components/SettingsModal/SettingsModal';
export { default as StatusIndicator } from './features/system/components/StatusIndicator'; export { default as StatusIndicator } from './features/system/components/StatusIndicator';
export { default as ColorModeButton } from './features/system/components/ColorModeButton';