mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix a lot of model-related crashes/bugs
We were storing all types of models by their model ID, which is a format like `sd-1/main/deliberate`. This meant we had to do a lot of extra parsing, because nodes actually wants something like `{base_model: 'sd-1', model_name: 'deliberate'}`. Some of this parsing was done with zod's error-throwing `parse()` method, and in other places it was done with brittle string parsing. This commit refactors the state to use the object form of models. There is still a bit of string parsing done in the to construct the ID from the object form, but it's far less complicated. Also, the zod parsing is now done using `safeParse()`, which does not throw. This requires a few more conditional checks, but should prevent further crashes.
This commit is contained in:
parent
14587464d5
commit
a071873327
@ -1,6 +1,7 @@
|
|||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
|
||||||
|
|
||||||
// zod needs the array to be `as const` to infer the type correctly
|
// zod needs the array to be `as const` to infer the type correctly
|
||||||
|
|
||||||
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
// this is the source of the `SchedulerParam` type, which is generated by zod
|
// this is the source of the `SchedulerParam` type, which is generated by zod
|
||||||
export const SCHEDULER_NAMES_AS_CONST = [
|
export const SCHEDULER_NAMES_AS_CONST = [
|
||||||
'euler',
|
'euler',
|
||||||
|
@ -1,36 +1,70 @@
|
|||||||
import { makeToast } from 'app/components/Toaster';
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
modelChanged,
|
modelChanged,
|
||||||
vaeSelected,
|
vaeSelected,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
|
import { zMainModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const addModelSelectedListener = () => {
|
export const addModelSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: modelSelected,
|
actionCreator: modelSelected,
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: (action, { getState, dispatch }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const { base_model, model_name } = action.payload;
|
const result = zMainModel.safeParse(action.payload);
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{ error: result.error.format() },
|
||||||
|
'Failed to parse main model'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newModel = result.data;
|
||||||
|
|
||||||
|
const { base_model } = newModel;
|
||||||
|
|
||||||
if (state.generation.model?.base_model !== base_model) {
|
if (state.generation.model?.base_model !== base_model) {
|
||||||
|
// we may need to reset some incompatible submodels
|
||||||
|
let modelsCleared = 0;
|
||||||
|
|
||||||
|
// handle incompatible loras
|
||||||
|
forEach(state.lora.loras, (lora, id) => {
|
||||||
|
if (lora.base_model !== base_model) {
|
||||||
|
dispatch(loraRemoved(id));
|
||||||
|
modelsCleared += 1;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// handle incompatible vae
|
||||||
|
const { vae } = state.generation;
|
||||||
|
if (vae && vae.base_model !== base_model) {
|
||||||
|
dispatch(vaeSelected(null));
|
||||||
|
modelsCleared += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle incompatible controlnet; pending model manager support
|
||||||
|
if (modelsCleared > 0) {
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
makeToast({
|
makeToast({
|
||||||
title: 'Base model changed, clearing submodels',
|
title: `Base model changed, cleared ${modelsCleared} incompatible submodel${
|
||||||
|
modelsCleared === 1 ? '' : 's'
|
||||||
|
}`,
|
||||||
status: 'warning',
|
status: 'warning',
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
dispatch(vaeSelected(null));
|
|
||||||
dispatch(lorasCleared());
|
|
||||||
// TODO: controlnet cleared
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
const newModel = zMainModel.parse(action.payload);
|
|
||||||
|
|
||||||
dispatch(modelChanged(newModel));
|
dispatch(modelChanged(newModel));
|
||||||
},
|
},
|
||||||
|
@ -1,8 +1,19 @@
|
|||||||
import { modelChanged } from 'features/parameters/store/generationSlice';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { some } from 'lodash-es';
|
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||||
|
import {
|
||||||
|
modelChanged,
|
||||||
|
vaeSelected,
|
||||||
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
import {
|
||||||
|
zMainModel,
|
||||||
|
zVaeModel,
|
||||||
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
||||||
@ -31,12 +42,92 @@ export const addModelsLoadedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(
|
const result = zMainModel.safeParse(firstModel);
|
||||||
modelChanged({
|
|
||||||
base_model: firstModel.base_model,
|
if (!result.success) {
|
||||||
model_name: firstModel.model_name,
|
moduleLog.error(
|
||||||
})
|
{ error: result.error.format() },
|
||||||
|
'Failed to parse main model'
|
||||||
);
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(modelChanged(result.data));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
startAppListening({
|
||||||
|
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
// VAEs loaded, need to reset the VAE is it's no longer available
|
||||||
|
|
||||||
|
const currentVae = getState().generation.vae;
|
||||||
|
|
||||||
|
if (currentVae === null) {
|
||||||
|
// null is a valid VAE! it means "use the default with the main model"
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const isCurrentVAEAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === currentVae?.model_name &&
|
||||||
|
m?.base_model === currentVae?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isCurrentVAEAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModelId = action.payload.ids[0];
|
||||||
|
const firstModel = action.payload.entities[firstModelId];
|
||||||
|
|
||||||
|
if (!firstModel) {
|
||||||
|
// No custom VAEs loaded at all; use the default
|
||||||
|
dispatch(modelChanged(null));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zVaeModel.safeParse(firstModel);
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{ error: result.error.format() },
|
||||||
|
'Failed to parse VAE model'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(vaeSelected(result.data));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
startAppListening({
|
||||||
|
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
// LoRA models loaded - need to remove missing LoRAs from state
|
||||||
|
|
||||||
|
const loras = getState().lora.loras;
|
||||||
|
|
||||||
|
forEach(loras, (lora, id) => {
|
||||||
|
const isLoRAAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === lora?.model_name &&
|
||||||
|
m?.base_model === lora?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isLoRAAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(loraRemoved(id));
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
startAppListening({
|
||||||
|
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||||
|
// TODO: pending model manager controlnet support
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -11,7 +11,7 @@ import { RootState } from 'app/store/store';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
||||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
@ -5,14 +5,14 @@ import IAISlider from 'common/components/IAISlider';
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
import {
|
import {
|
||||||
Lora,
|
LoRA,
|
||||||
loraRemoved,
|
loraRemoved,
|
||||||
loraWeightChanged,
|
loraWeightChanged,
|
||||||
loraWeightReset,
|
loraWeightReset,
|
||||||
} from '../store/loraSlice';
|
} from '../store/loraSlice';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
lora: Lora;
|
lora: LoRA;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamLora = (props: Props) => {
|
const ParamLora = (props: Props) => {
|
||||||
|
@ -6,9 +6,9 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParamLoraList from './ParamLoraList';
|
|
||||||
import ParamLoraSelect from './ParamLoraSelect';
|
|
||||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||||
|
import ParamLoraList from './ParamLoraList';
|
||||||
|
import ParamLoRASelect from './ParamLoraSelect';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -33,7 +33,7 @@ const ParamLoraCollapse = () => {
|
|||||||
return (
|
return (
|
||||||
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
|
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
|
||||||
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
||||||
<ParamLoraSelect />
|
<ParamLoRASelect />
|
||||||
<ParamLoraList />
|
<ParamLoraList />
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAICollapse>
|
</IAICollapse>
|
||||||
|
@ -7,7 +7,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { loraAdded } from 'features/lora/store/loraSlice';
|
import { loraAdded } from 'features/lora/store/loraSlice';
|
||||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
@ -20,23 +20,23 @@ const selector = createSelector(
|
|||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamLoraSelect = () => {
|
const ParamLoRASelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { loras } = useAppSelector(selector);
|
const { loras } = useAppSelector(selector);
|
||||||
const { data: lorasQueryData } = useGetLoRAModelsQuery();
|
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||||
|
|
||||||
const currentMainModel = useAppSelector(
|
const currentMainModel = useAppSelector(
|
||||||
(state: RootState) => state.generation.model
|
(state: RootState) => state.generation.model
|
||||||
);
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!lorasQueryData) {
|
if (!loraModels) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(lorasQueryData.entities, (lora, id) => {
|
forEach(loraModels.entities, (lora, id) => {
|
||||||
if (!lora || Boolean(id in loras)) {
|
if (!lora || Boolean(id in loras)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -55,23 +55,25 @@ const ParamLoraSelect = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||||
}, [loras, lorasQueryData, currentMainModel?.base_model]);
|
}, [loras, loraModels, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null | undefined) => {
|
(v: string | null | undefined) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const loraEntity = lorasQueryData?.entities[v];
|
const loraEntity = loraModels?.entities[v];
|
||||||
|
|
||||||
if (!loraEntity) {
|
if (!loraEntity) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(loraAdded(loraEntity));
|
dispatch(loraAdded(loraEntity));
|
||||||
},
|
},
|
||||||
[dispatch, lorasQueryData?.entities]
|
[dispatch, loraModels?.entities]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (lorasQueryData?.ids.length === 0) {
|
if (loraModels?.ids.length === 0) {
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||||
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
|
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
|
||||||
@ -98,4 +100,4 @@ const ParamLoraSelect = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamLoraSelect;
|
export default ParamLoRASelect;
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
|
import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
|
||||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export type Lora = LoRAModelParam & {
|
export type LoRA = LoRAModelParam & {
|
||||||
weight: number;
|
weight: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ export const defaultLoRAConfig = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type LoraState = {
|
export type LoraState = {
|
||||||
loras: Record<string, LoRAModelParam & { weight: number }>;
|
loras: Record<string, LoRA>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const intialLoraState: LoraState = {
|
export const intialLoraState: LoraState = {
|
||||||
@ -24,7 +24,7 @@ export const loraSlice = createSlice({
|
|||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||||
const { model_name, id, base_model } = action.payload;
|
const { model_name, id, base_model } = action.payload;
|
||||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
state.loras[id] = { model_name, base_model, ...defaultLoRAConfig };
|
||||||
},
|
},
|
||||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
const id = action.payload;
|
const id = action.payload;
|
||||||
|
@ -6,7 +6,7 @@ import {
|
|||||||
VaeModelInputFieldTemplate,
|
VaeModelInputFieldTemplate,
|
||||||
VaeModelInputFieldValue,
|
VaeModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -46,7 +46,7 @@ const LoRAModelInputFieldComponent = (
|
|||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: model.model_name,
|
label: model.model_name,
|
||||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -88,8 +88,7 @@ const LoRAModelInputFieldComponent = (
|
|||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={
|
label={
|
||||||
selectedModel?.base_model &&
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
|
||||||
}
|
}
|
||||||
value={field.value}
|
value={field.value}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
|
@ -7,7 +7,7 @@ import {
|
|||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -39,7 +39,7 @@ const ModelInputFieldComponent = (
|
|||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: model.model_name,
|
label: model.model_name,
|
||||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -86,8 +86,7 @@ const ModelInputFieldComponent = (
|
|||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={
|
label={
|
||||||
selectedModel?.base_model &&
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
|
||||||
}
|
}
|
||||||
value={field.value}
|
value={field.value}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
|
@ -6,7 +6,7 @@ import {
|
|||||||
VaeModelInputFieldTemplate,
|
VaeModelInputFieldTemplate,
|
||||||
VaeModelInputFieldValue,
|
VaeModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -46,7 +46,7 @@ const VaeModelInputFieldComponent = (
|
|||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: model.model_name,
|
label: model.model_name,
|
||||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -81,8 +81,7 @@ const VaeModelInputFieldComponent = (
|
|||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={
|
label={
|
||||||
selectedModel?.base_model &&
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
|
||||||
}
|
}
|
||||||
value={field.value}
|
value={field.value}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
|
@ -5,7 +5,6 @@ import {
|
|||||||
LoraLoaderInvocation,
|
LoraLoaderInvocation,
|
||||||
MetadataAccumulatorInvocation,
|
MetadataAccumulatorInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
|
||||||
import {
|
import {
|
||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
@ -55,23 +54,22 @@ export const addLoRAsToGraph = (
|
|||||||
let currentLoraIndex = 0;
|
let currentLoraIndex = 0;
|
||||||
|
|
||||||
forEach(loras, (lora) => {
|
forEach(loras, (lora) => {
|
||||||
const { id, name, weight } = lora;
|
const { model_name, base_model, weight } = lora;
|
||||||
const loraField = modelIdToLoRAModelField(id);
|
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||||
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
|
|
||||||
'.',
|
|
||||||
'_'
|
|
||||||
)}`;
|
|
||||||
|
|
||||||
const loraLoaderNode: LoraLoaderInvocation = {
|
const loraLoaderNode: LoraLoaderInvocation = {
|
||||||
type: 'lora_loader',
|
type: 'lora_loader',
|
||||||
id: currentLoraNodeId,
|
id: currentLoraNodeId,
|
||||||
lora: loraField,
|
lora,
|
||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
// add the lora to the metadata accumulator
|
// add the lora to the metadata accumulator
|
||||||
if (metadataAccumulator) {
|
if (metadataAccumulator) {
|
||||||
metadataAccumulator.loras.push({ lora: loraField, weight });
|
metadataAccumulator.loras.push({
|
||||||
|
lora: { model_name, base_model },
|
||||||
|
weight,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add to graph
|
// add to graph
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
|
||||||
import {
|
import {
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
IMAGE_TO_LATENTS,
|
IMAGE_TO_LATENTS,
|
||||||
@ -19,7 +18,6 @@ export const addVAEToGraph = (
|
|||||||
graph: NonNullableGraph
|
graph: NonNullableGraph
|
||||||
): void => {
|
): void => {
|
||||||
const { vae } = state.generation;
|
const { vae } = state.generation;
|
||||||
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
|
||||||
|
|
||||||
const isAutoVae = !vae;
|
const isAutoVae = !vae;
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||||
@ -30,7 +28,7 @@ export const addVAEToGraph = (
|
|||||||
graph.nodes[VAE_LOADER] = {
|
graph.nodes[VAE_LOADER] = {
|
||||||
type: 'vae_loader',
|
type: 'vae_loader',
|
||||||
id: VAE_LOADER,
|
id: VAE_LOADER,
|
||||||
vae_model,
|
vae_model: vae,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,6 +72,6 @@ export const addVAEToGraph = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (vae && metadataAccumulator) {
|
if (vae && metadataAccumulator) {
|
||||||
metadataAccumulator.vae = vae_model;
|
metadataAccumulator.vae = vae;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InputFieldValue } from 'features/nodes/types/types';
|
import { InputFieldValue } from 'features/nodes/types/types';
|
||||||
|
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
|
||||||
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
|
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||||
import { Graph } from 'services/api/types';
|
import { Graph } from 'services/api/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
|
||||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
|
||||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -29,19 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
|
|
||||||
if (field.type === 'model') {
|
if (field.type === 'model') {
|
||||||
if (field.value) {
|
if (field.value) {
|
||||||
return modelIdToMainModelField(field.value);
|
return modelIdToMainModelParam(field.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (field.type === 'vae_model') {
|
if (field.type === 'vae_model') {
|
||||||
if (field.value) {
|
if (field.value) {
|
||||||
return modelIdToVAEModelField(field.value);
|
return modelIdToVAEModelParam(field.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (field.type === 'lora_model') {
|
if (field.type === 'lora_model') {
|
||||||
if (field.value) {
|
if (field.value) {
|
||||||
return modelIdToLoRAModelField(field.value);
|
return modelIdToLoRAModelParam(field.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
import { BaseModelType, LoRAModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
|
|
||||||
const [base_model, model_type, model_name] = loraId.split('/');
|
|
||||||
|
|
||||||
const field: LoRAModelField = {
|
|
||||||
base_model: base_model as BaseModelType,
|
|
||||||
model_name,
|
|
||||||
};
|
|
||||||
|
|
||||||
return field;
|
|
||||||
};
|
|
@ -1,16 +0,0 @@
|
|||||||
import { BaseModelType, MainModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Crudely converts a model id to a main model field
|
|
||||||
* TODO: Make better
|
|
||||||
*/
|
|
||||||
export const modelIdToMainModelField = (modelId: string): MainModelField => {
|
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
|
||||||
|
|
||||||
const field: MainModelField = {
|
|
||||||
base_model: base_model as BaseModelType,
|
|
||||||
model_name,
|
|
||||||
};
|
|
||||||
|
|
||||||
return field;
|
|
||||||
};
|
|
@ -1,16 +0,0 @@
|
|||||||
import { BaseModelType, VAEModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Crudely converts a model id to a main model field
|
|
||||||
* TODO: Make better
|
|
||||||
*/
|
|
||||||
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
|
||||||
|
|
||||||
const field: VAEModelField = {
|
|
||||||
base_model: base_model as BaseModelType,
|
|
||||||
model_name,
|
|
||||||
};
|
|
||||||
|
|
||||||
return field;
|
|
||||||
};
|
|
@ -1,8 +1,8 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import ModelSelect from 'features/system/components/ModelSelect';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import VAESelect from 'features/system/components/VAESelect';
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
||||||
|
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
||||||
import ParamScheduler from './ParamScheduler';
|
import ParamScheduler from './ParamScheduler';
|
||||||
|
|
||||||
const ParamModelandVAEandScheduler = () => {
|
const ParamModelandVAEandScheduler = () => {
|
||||||
@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
|
|||||||
return (
|
return (
|
||||||
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
|
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
<ModelSelect />
|
<ParamMainModelSelect />
|
||||||
</Box>
|
</Box>
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
{isVaeEnabled && (
|
{isVaeEnabled && (
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
<VAESelect />
|
<ParamVAEModelSelect />
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
|
@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
@ -8,27 +8,23 @@ import { SelectItem } from '@mantine/core';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField';
|
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export const MODEL_TYPE_MAP = {
|
|
||||||
'sd-1': 'Stable Diffusion 1.x',
|
|
||||||
'sd-2': 'Stable Diffusion 2.x',
|
|
||||||
};
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
(state) => ({ currentModel: state.generation.model }),
|
(state) => ({ model: state.generation.model }),
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ModelSelect = () => {
|
const ParamMainModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { currentModel } = useAppSelector(selector);
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||||
|
|
||||||
@ -54,12 +50,13 @@ const ModelSelect = () => {
|
|||||||
return data;
|
return data;
|
||||||
}, [mainModels]);
|
}, [mainModels]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() =>
|
() =>
|
||||||
mainModels?.entities[
|
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
||||||
`${currentModel?.base_model}/main/${currentModel?.model_name}`
|
null,
|
||||||
],
|
[mainModels?.entities, model]
|
||||||
[mainModels?.entities, currentModel]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -68,8 +65,13 @@ const ModelSelect = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelField = modelIdToMainModelField(v);
|
const newModel = modelIdToMainModelParam(v);
|
||||||
dispatch(modelSelected(modelField));
|
|
||||||
|
if (!newModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(modelSelected(newModel));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -95,4 +97,4 @@ const ModelSelect = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ModelSelect);
|
export default memo(ParamMainModelSelect);
|
@ -1,4 +1,4 @@
|
|||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -8,26 +8,30 @@ import { SelectItem } from '@mantine/core';
|
|||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { zVaeModel } from 'features/parameters/store/parameterZodSchemas';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { MODEL_TYPE_MAP } from './ModelSelect';
|
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||||
|
|
||||||
const VAESelect = () => {
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ generation }) => {
|
||||||
|
const { model, vae } = generation;
|
||||||
|
return { model, vae };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamVAEModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { model, vae } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||||
|
|
||||||
const currentMainModel = useAppSelector(
|
|
||||||
(state: RootState) => state.generation.model
|
|
||||||
);
|
|
||||||
|
|
||||||
const selectedVae = useAppSelector(
|
|
||||||
(state: RootState) => state.generation.vae
|
|
||||||
);
|
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!vaeModels) {
|
if (!vaeModels) {
|
||||||
return [];
|
return [];
|
||||||
@ -41,30 +45,32 @@ const VAESelect = () => {
|
|||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
forEach(vaeModels.entities, (model, id) => {
|
forEach(vaeModels.entities, (vae, id) => {
|
||||||
if (!model) {
|
if (!vae) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const disabled = currentMainModel?.base_model !== model.base_model;
|
const disabled = model?.base_model !== vae.base_model;
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: model.model_name,
|
label: vae.model_name,
|
||||||
group: MODEL_TYPE_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[vae.base_model],
|
||||||
disabled,
|
disabled,
|
||||||
tooltip: disabled
|
tooltip: disabled
|
||||||
? `Incompatible base model: ${model.base_model}`
|
? `Incompatible base model: ${vae.base_model}`
|
||||||
: undefined,
|
: undefined,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||||
}, [vaeModels, currentMainModel?.base_model]);
|
}, [vaeModels, model?.base_model]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
const selectedVaeModel = useMemo(
|
const selectedVaeModel = useMemo(
|
||||||
() => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null),
|
() =>
|
||||||
[vaeModels?.entities, selectedVae]
|
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
|
||||||
|
[vaeModels?.entities, vae]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -74,32 +80,23 @@ const VAESelect = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const [base_model, type, name] = v.split('/');
|
const newVaeModel = modelIdToVAEModelParam(v);
|
||||||
|
|
||||||
const model = zVaeModel.parse({
|
if (!newVaeModel) {
|
||||||
id: v,
|
return;
|
||||||
name,
|
}
|
||||||
base_model,
|
|
||||||
});
|
|
||||||
|
|
||||||
dispatch(vaeSelected(model));
|
dispatch(vaeSelected(newVaeModel));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dispatch(vaeSelected(null));
|
|
||||||
}, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
tooltip={selectedVaeModel?.description}
|
tooltip={selectedVaeModel?.description}
|
||||||
label={t('modelManager.vae')}
|
label={t('modelManager.vae')}
|
||||||
value={selectedVae?.id ?? 'default'}
|
value={selectedVaeModel?.id ?? 'default'}
|
||||||
placeholder="Default"
|
placeholder="Default"
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
@ -109,4 +106,4 @@ const VAESelect = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(VAESelect);
|
export default memo(ParamVAEModelSelect);
|
@ -28,7 +28,7 @@ import {
|
|||||||
isValidSteps,
|
isValidSteps,
|
||||||
isValidStrength,
|
isValidStrength,
|
||||||
isValidWidth,
|
isValidWidth,
|
||||||
} from '../store/parameterZodSchemas';
|
} from '../types/parameterSchemas';
|
||||||
|
|
||||||
export const useRecallParameters = () => {
|
export const useRecallParameters = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
@ -13,6 +13,7 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
|||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
|
MainModelParam,
|
||||||
NegativePromptParam,
|
NegativePromptParam,
|
||||||
PositivePromptParam,
|
PositivePromptParam,
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
@ -22,7 +23,7 @@ import {
|
|||||||
VaeModelParam,
|
VaeModelParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
zMainModel,
|
zMainModel,
|
||||||
} from './parameterZodSchemas';
|
} from '../types/parameterSchemas';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: CfgScaleParam;
|
cfgScale: CfgScaleParam;
|
||||||
@ -226,18 +227,19 @@ export const generationSlice = createSlice({
|
|||||||
const { image_name, width, height } = action.payload;
|
const { image_name, width, height } = action.payload;
|
||||||
state.initialImage = { imageName: image_name, width, height };
|
state.initialImage = { imageName: image_name, width, height };
|
||||||
},
|
},
|
||||||
modelChanged: (state, action: PayloadAction<MainModelField | null>) => {
|
modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
|
||||||
if (!action.payload) {
|
state.model = action.payload;
|
||||||
state.model = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
state.model = zMainModel.parse(action.payload);
|
if (state.model === null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Clamp ClipSkip Based On Selected Model
|
// Clamp ClipSkip Based On Selected Model
|
||||||
const { maxClip } = clipSkipMap[state.model.base_model];
|
const { maxClip } = clipSkipMap[state.model.base_model];
|
||||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||||
},
|
},
|
||||||
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
||||||
|
// null is a valid VAE!
|
||||||
state.vae = action.payload;
|
state.vae = action.payload;
|
||||||
},
|
},
|
||||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||||
@ -253,11 +255,15 @@ export const generationSlice = createSlice({
|
|||||||
|
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||||
state.model = zMainModel.parse({
|
|
||||||
id: defaultModel,
|
const result = zMainModel.safeParse({
|
||||||
name: model_name,
|
model_name,
|
||||||
base_model,
|
base_model,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
state.model = result.data;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
export const MODEL_TYPE_MAP = {
|
||||||
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
|
'sd-2': 'Stable Diffusion 2.x',
|
||||||
|
};
|
@ -135,7 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
|
|||||||
* TODO: Make this a dynamically generated enum?
|
* TODO: Make this a dynamically generated enum?
|
||||||
*/
|
*/
|
||||||
export const zMainModel = z.object({
|
export const zMainModel = z.object({
|
||||||
model_name: z.string(),
|
model_name: z.string().min(1),
|
||||||
base_model: zBaseModel,
|
base_model: zBaseModel,
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -152,8 +152,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam =>
|
|||||||
* Zod schema for VAE parameter
|
* Zod schema for VAE parameter
|
||||||
*/
|
*/
|
||||||
export const zVaeModel = z.object({
|
export const zVaeModel = z.object({
|
||||||
id: z.string(),
|
model_name: z.string().min(1),
|
||||||
name: z.string(),
|
|
||||||
base_model: zBaseModel,
|
base_model: zBaseModel,
|
||||||
});
|
});
|
||||||
/**
|
/**
|
||||||
@ -169,8 +168,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
|
|||||||
* Zod schema for LoRA
|
* Zod schema for LoRA
|
||||||
*/
|
*/
|
||||||
export const zLoRAModel = z.object({
|
export const zLoRAModel = z.object({
|
||||||
id: z.string(),
|
model_name: z.string().min(1),
|
||||||
model_name: z.string(),
|
|
||||||
base_model: zBaseModel,
|
base_model: zBaseModel,
|
||||||
});
|
});
|
||||||
/**
|
/**
|
@ -0,0 +1,18 @@
|
|||||||
|
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
||||||
|
|
||||||
|
export const modelIdToLoRAModelParam = (
|
||||||
|
loraId: string
|
||||||
|
): LoRAModelParam | undefined => {
|
||||||
|
const [base_model, model_type, model_name] = loraId.split('/');
|
||||||
|
|
||||||
|
const result = zLoRAModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -0,0 +1,21 @@
|
|||||||
|
import {
|
||||||
|
MainModelParam,
|
||||||
|
zMainModel,
|
||||||
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
|
export const modelIdToMainModelParam = (
|
||||||
|
modelId: string
|
||||||
|
): MainModelParam | undefined => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const result = zMainModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -0,0 +1,18 @@
|
|||||||
|
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
||||||
|
|
||||||
|
export const modelIdToVAEModelParam = (
|
||||||
|
modelId: string
|
||||||
|
): VaeModelParam | undefined => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const result = zVaeModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -2,7 +2,7 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
|
import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store';
|
|||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { S } from 'services/api/types';
|
import { S } from 'services/api/types';
|
||||||
import ModelConvert from './ModelConvert';
|
import ModelConvert from './ModelConvert';
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store';
|
|||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { S } from 'services/api/types';
|
import { S } from 'services/api/types';
|
||||||
|
|
||||||
type DiffusersModel =
|
type DiffusersModel =
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
import { setActiveTabReducer } from './extraReducers';
|
import { setActiveTabReducer } from './extraReducers';
|
||||||
import { InvokeTabName } from './tabMap';
|
import { InvokeTabName } from './tabMap';
|
||||||
import { AddNewModelType, UIState } from './uiTypes';
|
import { AddNewModelType, UIState } from './uiTypes';
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
||||||
|
|
||||||
|
@ -2,8 +2,8 @@ export { default as InvokeAIUI } from './app/components/InvokeAIUI';
|
|||||||
export type { PartialAppConfig } from './app/types/invokeai';
|
export type { PartialAppConfig } from './app/types/invokeai';
|
||||||
export { default as IAIIconButton } from './common/components/IAIIconButton';
|
export { default as IAIIconButton } from './common/components/IAIIconButton';
|
||||||
export { default as IAIPopover } from './common/components/IAIPopover';
|
export { default as IAIPopover } from './common/components/IAIPopover';
|
||||||
|
export { default as ParamMainModelSelect } from './features/parameters/components/Parameters/MainModel/ParamMainModelSelect';
|
||||||
|
export { default as ColorModeButton } from './features/system/components/ColorModeButton';
|
||||||
export { default as InvokeAiLogoComponent } from './features/system/components/InvokeAILogoComponent';
|
export { default as InvokeAiLogoComponent } from './features/system/components/InvokeAILogoComponent';
|
||||||
export { default as ModelSelect } from './features/system/components/ModelSelect';
|
|
||||||
export { default as SettingsModal } from './features/system/components/SettingsModal/SettingsModal';
|
export { default as SettingsModal } from './features/system/components/SettingsModal/SettingsModal';
|
||||||
export { default as StatusIndicator } from './features/system/components/StatusIndicator';
|
export { default as StatusIndicator } from './features/system/components/StatusIndicator';
|
||||||
export { default as ColorModeButton } from './features/system/components/ColorModeButton';
|
|
||||||
|
Loading…
Reference in New Issue
Block a user