diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts
index 5fd413d915..440d9330d6 100644
--- a/invokeai/frontend/web/src/app/constants.ts
+++ b/invokeai/frontend/web/src/app/constants.ts
@@ -1,6 +1,7 @@
-import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
-
// zod needs the array to be `as const` to infer the type correctly
+
+import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
+
// this is the source of the `SchedulerParam` type, which is generated by zod
export const SCHEDULER_NAMES_AS_CONST = [
'euler',
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
index 5ab30570d9..ee879a8915 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
@@ -1,36 +1,70 @@
import { makeToast } from 'app/components/Toaster';
+import { log } from 'app/logging/useLogger';
+import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
modelChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
-import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
+import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
+import { forEach } from 'lodash-es';
import { startAppListening } from '..';
-import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
+
+const moduleLog = log.child({ module: 'models' });
export const addModelSelectedListener = () => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
const state = getState();
- const { base_model, model_name } = action.payload;
+ const result = zMainModel.safeParse(action.payload);
- if (state.generation.model?.base_model !== base_model) {
- dispatch(
- addToast(
- makeToast({
- title: 'Base model changed, clearing submodels',
- status: 'warning',
- })
- )
+ if (!result.success) {
+ moduleLog.error(
+ { error: result.error.format() },
+ 'Failed to parse main model'
);
- dispatch(vaeSelected(null));
- dispatch(lorasCleared());
- // TODO: controlnet cleared
+ return;
}
- const newModel = zMainModel.parse(action.payload);
+ const newModel = result.data;
+
+ const { base_model } = newModel;
+
+ if (state.generation.model?.base_model !== base_model) {
+ // we may need to reset some incompatible submodels
+ let modelsCleared = 0;
+
+ // handle incompatible loras
+ forEach(state.lora.loras, (lora, id) => {
+ if (lora.base_model !== base_model) {
+ dispatch(loraRemoved(id));
+ modelsCleared += 1;
+ }
+ });
+
+ // handle incompatible vae
+ const { vae } = state.generation;
+ if (vae && vae.base_model !== base_model) {
+ dispatch(vaeSelected(null));
+ modelsCleared += 1;
+ }
+
+ // TODO: handle incompatible controlnet; pending model manager support
+ if (modelsCleared > 0) {
+ dispatch(
+ addToast(
+ makeToast({
+ title: `Base model changed, cleared ${modelsCleared} incompatible submodel${
+ modelsCleared === 1 ? '' : 's'
+ }`,
+ status: 'warning',
+ })
+ )
+ );
+ }
+ }
dispatch(modelChanged(newModel));
},
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
index ee02028848..f8abcfa758 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
@@ -1,8 +1,19 @@
-import { modelChanged } from 'features/parameters/store/generationSlice';
-import { some } from 'lodash-es';
+import { log } from 'app/logging/useLogger';
+import { loraRemoved } from 'features/lora/store/loraSlice';
+import {
+ modelChanged,
+ vaeSelected,
+} from 'features/parameters/store/generationSlice';
+import {
+ zMainModel,
+ zVaeModel,
+} from 'features/parameters/types/parameterSchemas';
+import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
+const moduleLog = log.child({ module: 'models' });
+
export const addModelsLoadedListener = () => {
startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
@@ -31,12 +42,92 @@ export const addModelsLoadedListener = () => {
return;
}
- dispatch(
- modelChanged({
- base_model: firstModel.base_model,
- model_name: firstModel.model_name,
- })
+ const result = zMainModel.safeParse(firstModel);
+
+ if (!result.success) {
+ moduleLog.error(
+ { error: result.error.format() },
+ 'Failed to parse main model'
+ );
+ return;
+ }
+
+ dispatch(modelChanged(result.data));
+ },
+ });
+ startAppListening({
+ matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
+ effect: async (action, { getState, dispatch }) => {
+ // VAEs loaded, need to reset the VAE is it's no longer available
+
+ const currentVae = getState().generation.vae;
+
+ if (currentVae === null) {
+ // null is a valid VAE! it means "use the default with the main model"
+ return;
+ }
+
+ const isCurrentVAEAvailable = some(
+ action.payload.entities,
+ (m) =>
+ m?.model_name === currentVae?.model_name &&
+ m?.base_model === currentVae?.base_model
);
+
+ if (isCurrentVAEAvailable) {
+ return;
+ }
+
+ const firstModelId = action.payload.ids[0];
+ const firstModel = action.payload.entities[firstModelId];
+
+ if (!firstModel) {
+ // No custom VAEs loaded at all; use the default
+ dispatch(modelChanged(null));
+ return;
+ }
+
+ const result = zVaeModel.safeParse(firstModel);
+
+ if (!result.success) {
+ moduleLog.error(
+ { error: result.error.format() },
+ 'Failed to parse VAE model'
+ );
+ return;
+ }
+
+ dispatch(vaeSelected(result.data));
+ },
+ });
+ startAppListening({
+ matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
+ effect: async (action, { getState, dispatch }) => {
+ // LoRA models loaded - need to remove missing LoRAs from state
+
+ const loras = getState().lora.loras;
+
+ forEach(loras, (lora, id) => {
+ const isLoRAAvailable = some(
+ action.payload.entities,
+ (m) =>
+ m?.model_name === lora?.model_name &&
+ m?.base_model === lora?.base_model
+ );
+
+ if (isLoRAAvailable) {
+ return;
+ }
+
+ dispatch(loraRemoved(id));
+ });
+ },
+ });
+ startAppListening({
+ matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
+ effect: async (action, { getState, dispatch }) => {
+ // ControlNet models loaded - need to remove missing ControlNets from state
+ // TODO: pending model manager controlnet support
},
});
};
diff --git a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx
index 32822125d2..89302b78d4 100644
--- a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx
+++ b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx
@@ -11,7 +11,7 @@ import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
-import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx
index 7dba2aa6ed..a1584ca13a 100644
--- a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx
+++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx
@@ -5,14 +5,14 @@ import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import {
- Lora,
+ LoRA,
loraRemoved,
loraWeightChanged,
loraWeightReset,
} from '../store/loraSlice';
type Props = {
- lora: Lora;
+ lora: LoRA;
};
const ParamLora = (props: Props) => {
diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx
index 436c32f46b..e212efbfa2 100644
--- a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx
+++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx
@@ -6,9 +6,9 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { size } from 'lodash-es';
import { memo } from 'react';
-import ParamLoraList from './ParamLoraList';
-import ParamLoraSelect from './ParamLoraSelect';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
+import ParamLoraList from './ParamLoraList';
+import ParamLoRASelect from './ParamLoraSelect';
const selector = createSelector(
stateSelector,
@@ -33,7 +33,7 @@ const ParamLoraCollapse = () => {
return (
-
+
diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx
index ebceeb34db..f0aa252339 100644
--- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx
+++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx
@@ -7,7 +7,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice';
-import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
@@ -20,23 +20,23 @@ const selector = createSelector(
defaultSelectorOptions
);
-const ParamLoraSelect = () => {
+const ParamLoRASelect = () => {
const dispatch = useAppDispatch();
const { loras } = useAppSelector(selector);
- const { data: lorasQueryData } = useGetLoRAModelsQuery();
+ const { data: loraModels } = useGetLoRAModelsQuery();
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const data = useMemo(() => {
- if (!lorasQueryData) {
+ if (!loraModels) {
return [];
}
const data: SelectItem[] = [];
- forEach(lorasQueryData.entities, (lora, id) => {
+ forEach(loraModels.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) {
return;
}
@@ -55,23 +55,25 @@ const ParamLoraSelect = () => {
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
- }, [loras, lorasQueryData, currentMainModel?.base_model]);
+ }, [loras, loraModels, currentMainModel?.base_model]);
const handleChange = useCallback(
(v: string | null | undefined) => {
if (!v) {
return;
}
- const loraEntity = lorasQueryData?.entities[v];
+ const loraEntity = loraModels?.entities[v];
+
if (!loraEntity) {
return;
}
+
dispatch(loraAdded(loraEntity));
},
- [dispatch, lorasQueryData?.entities]
+ [dispatch, loraModels?.entities]
);
- if (lorasQueryData?.ids.length === 0) {
+ if (loraModels?.ids.length === 0) {
return (
@@ -98,4 +100,4 @@ const ParamLoraSelect = () => {
);
};
-export default ParamLoraSelect;
+export default ParamLoRASelect;
diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
index a97a0887a5..2dc739a737 100644
--- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
+++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
@@ -1,8 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
-import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
+import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
-export type Lora = LoRAModelParam & {
+export type LoRA = LoRAModelParam & {
weight: number;
};
@@ -11,7 +11,7 @@ export const defaultLoRAConfig = {
};
export type LoraState = {
- loras: Record;
+ loras: Record;
};
export const intialLoraState: LoraState = {
@@ -24,7 +24,7 @@ export const loraSlice = createSlice({
reducers: {
loraAdded: (state, action: PayloadAction) => {
const { model_name, id, base_model } = action.payload;
- state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
+ state.loras[id] = { model_name, base_model, ...defaultLoRAConfig };
},
loraRemoved: (state, action: PayloadAction) => {
const id = action.payload;
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
index 271408b817..861f919b33 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
@@ -6,7 +6,7 @@ import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
-import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -46,7 +46,7 @@ const LoRAModelInputFieldComponent = (
data.push({
value: id,
label: model.model_name,
- group: BASE_MODEL_NAME_MAP[model.base_model],
+ group: MODEL_TYPE_MAP[model.base_model],
});
});
@@ -88,8 +88,7 @@ const LoRAModelInputFieldComponent = (
{
- const { id, name, weight } = lora;
- const loraField = modelIdToLoRAModelField(id);
- const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
- '.',
- '_'
- )}`;
+ const { model_name, base_model, weight } = lora;
+ const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
- lora: loraField,
+ lora,
weight,
};
// add the lora to the metadata accumulator
if (metadataAccumulator) {
- metadataAccumulator.loras.push({ lora: loraField, weight });
+ metadataAccumulator.loras.push({
+ lora: { model_name, base_model },
+ weight,
+ });
}
// add to graph
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts
index d76fec093c..8574dc4e46 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts
@@ -1,7 +1,6 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { MetadataAccumulatorInvocation } from 'services/api/types';
-import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
@@ -19,7 +18,6 @@ export const addVAEToGraph = (
graph: NonNullableGraph
): void => {
const { vae } = state.generation;
- const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
@@ -30,7 +28,7 @@ export const addVAEToGraph = (
graph.nodes[VAE_LOADER] = {
type: 'vae_loader',
id: VAE_LOADER,
- vae_model,
+ vae_model: vae,
};
}
@@ -74,6 +72,6 @@ export const addVAEToGraph = (
}
if (vae && metadataAccumulator) {
- metadataAccumulator.vae = vae_model;
+ metadataAccumulator.vae = vae;
}
};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
index ea68ac50fb..6963cf16b8 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
@@ -1,12 +1,12 @@
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
+import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
+import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
+import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
-import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
-import { modelIdToMainModelField } from '../modelIdToMainModelField';
-import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
/**
* We need to do special handling for some fields
@@ -29,19 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') {
if (field.value) {
- return modelIdToMainModelField(field.value);
+ return modelIdToMainModelParam(field.value);
}
}
if (field.type === 'vae_model') {
if (field.value) {
- return modelIdToVAEModelField(field.value);
+ return modelIdToVAEModelParam(field.value);
}
}
if (field.type === 'lora_model') {
if (field.value) {
- return modelIdToLoRAModelField(field.value);
+ return modelIdToLoRAModelParam(field.value);
}
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts
deleted file mode 100644
index 052b58484b..0000000000
--- a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts
+++ /dev/null
@@ -1,12 +0,0 @@
-import { BaseModelType, LoRAModelField } from 'services/api/types';
-
-export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
- const [base_model, model_type, model_name] = loraId.split('/');
-
- const field: LoRAModelField = {
- base_model: base_model as BaseModelType,
- model_name,
- };
-
- return field;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts
deleted file mode 100644
index 6bb0f776b2..0000000000
--- a/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts
+++ /dev/null
@@ -1,16 +0,0 @@
-import { BaseModelType, MainModelField } from 'services/api/types';
-
-/**
- * Crudely converts a model id to a main model field
- * TODO: Make better
- */
-export const modelIdToMainModelField = (modelId: string): MainModelField => {
- const [base_model, model_type, model_name] = modelId.split('/');
-
- const field: MainModelField = {
- base_model: base_model as BaseModelType,
- model_name,
- };
-
- return field;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts
deleted file mode 100644
index 0cb608a936..0000000000
--- a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts
+++ /dev/null
@@ -1,16 +0,0 @@
-import { BaseModelType, VAEModelField } from 'services/api/types';
-
-/**
- * Crudely converts a model id to a main model field
- * TODO: Make better
- */
-export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
- const [base_model, model_type, model_name] = modelId.split('/');
-
- const field: VAEModelField = {
- base_model: base_model as BaseModelType,
- model_name,
- };
-
- return field;
-};
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx
index 9054afcca2..74418de1d3 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx
@@ -1,8 +1,8 @@
import { Box, Flex } from '@chakra-ui/react';
-import ModelSelect from 'features/system/components/ModelSelect';
-import VAESelect from 'features/system/components/VAESelect';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
-import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
+import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
+import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
import ParamScheduler from './ParamScheduler';
const ParamModelandVAEandScheduler = () => {
@@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
return (
-
+
{isVaeEnabled && (
-
+
)}
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
index 8818dcba9b..be8db632bc 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
@@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
-import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
+import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx
similarity index 72%
rename from invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
rename to invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx
index bc3da20b06..dbe732fc55 100644
--- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx
@@ -8,27 +8,23 @@ import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
-import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField';
import { modelSelected } from 'features/parameters/store/actions';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
+import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
-export const MODEL_TYPE_MAP = {
- 'sd-1': 'Stable Diffusion 1.x',
- 'sd-2': 'Stable Diffusion 2.x',
-};
-
const selector = createSelector(
stateSelector,
- (state) => ({ currentModel: state.generation.model }),
+ (state) => ({ model: state.generation.model }),
defaultSelectorOptions
);
-const ModelSelect = () => {
+const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const { currentModel } = useAppSelector(selector);
+ const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
@@ -54,12 +50,13 @@ const ModelSelect = () => {
return data;
}, [mainModels]);
+ // grab the full model entity from the RTK Query cache
+ // TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
- mainModels?.entities[
- `${currentModel?.base_model}/main/${currentModel?.model_name}`
- ],
- [mainModels?.entities, currentModel]
+ mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
+ null,
+ [mainModels?.entities, model]
);
const handleChangeModel = useCallback(
@@ -68,8 +65,13 @@ const ModelSelect = () => {
return;
}
- const modelField = modelIdToMainModelField(v);
- dispatch(modelSelected(modelField));
+ const newModel = modelIdToMainModelParam(v);
+
+ if (!newModel) {
+ return;
+ }
+
+ dispatch(modelSelected(newModel));
},
[dispatch]
);
@@ -95,4 +97,4 @@ const ModelSelect = () => {
);
};
-export default memo(ModelSelect);
+export default memo(ParamMainModelSelect);
diff --git a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx
similarity index 54%
rename from invokeai/frontend/web/src/features/system/components/VAESelect.tsx
rename to invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx
index bed1b72123..d1e040e181 100644
--- a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx
@@ -1,4 +1,4 @@
-import { memo, useCallback, useEffect, useMemo } from 'react';
+import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -8,26 +8,30 @@ import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
-import { RootState } from 'app/store/store';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { vaeSelected } from 'features/parameters/store/generationSlice';
-import { zVaeModel } from 'features/parameters/store/parameterZodSchemas';
-import { MODEL_TYPE_MAP } from './ModelSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
+import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
-const VAESelect = () => {
+const selector = createSelector(
+ stateSelector,
+ ({ generation }) => {
+ const { model, vae } = generation;
+ return { model, vae };
+ },
+ defaultSelectorOptions
+);
+
+const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
+ const { model, vae } = useAppSelector(selector);
const { data: vaeModels } = useGetVaeModelsQuery();
- const currentMainModel = useAppSelector(
- (state: RootState) => state.generation.model
- );
-
- const selectedVae = useAppSelector(
- (state: RootState) => state.generation.vae
- );
-
const data = useMemo(() => {
if (!vaeModels) {
return [];
@@ -41,30 +45,32 @@ const VAESelect = () => {
},
];
- forEach(vaeModels.entities, (model, id) => {
- if (!model) {
+ forEach(vaeModels.entities, (vae, id) => {
+ if (!vae) {
return;
}
- const disabled = currentMainModel?.base_model !== model.base_model;
+ const disabled = model?.base_model !== vae.base_model;
data.push({
value: id,
- label: model.model_name,
- group: MODEL_TYPE_MAP[model.base_model],
+ label: vae.model_name,
+ group: MODEL_TYPE_MAP[vae.base_model],
disabled,
tooltip: disabled
- ? `Incompatible base model: ${model.base_model}`
+ ? `Incompatible base model: ${vae.base_model}`
: undefined,
});
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
- }, [vaeModels, currentMainModel?.base_model]);
+ }, [vaeModels, model?.base_model]);
+ // grab the full model entity from the RTK Query cache
const selectedVaeModel = useMemo(
- () => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null),
- [vaeModels?.entities, selectedVae]
+ () =>
+ vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
+ [vaeModels?.entities, vae]
);
const handleChangeModel = useCallback(
@@ -74,32 +80,23 @@ const VAESelect = () => {
return;
}
- const [base_model, type, name] = v.split('/');
+ const newVaeModel = modelIdToVAEModelParam(v);
- const model = zVaeModel.parse({
- id: v,
- name,
- base_model,
- });
+ if (!newVaeModel) {
+ return;
+ }
- dispatch(vaeSelected(model));
+ dispatch(vaeSelected(newVaeModel));
},
[dispatch]
);
- useEffect(() => {
- if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
- return;
- }
- dispatch(vaeSelected(null));
- }, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
-
return (
{
);
};
-export default memo(VAESelect);
+export default memo(ParamVAEModelSelect);
diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
index 9e4f5aeff0..6329d9d677 100644
--- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
+++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
@@ -28,7 +28,7 @@ import {
isValidSteps,
isValidStrength,
isValidWidth,
-} from '../store/parameterZodSchemas';
+} from '../types/parameterSchemas';
export const useRecallParameters = () => {
const dispatch = useAppDispatch();
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
index dff277ae7e..c5ec7930a4 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
@@ -13,6 +13,7 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
CfgScaleParam,
HeightParam,
+ MainModelParam,
NegativePromptParam,
PositivePromptParam,
SchedulerParam,
@@ -22,7 +23,7 @@ import {
VaeModelParam,
WidthParam,
zMainModel,
-} from './parameterZodSchemas';
+} from '../types/parameterSchemas';
export interface GenerationState {
cfgScale: CfgScaleParam;
@@ -226,18 +227,19 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
},
- modelChanged: (state, action: PayloadAction) => {
- if (!action.payload) {
- state.model = null;
- }
+ modelChanged: (state, action: PayloadAction) => {
+ state.model = action.payload;
- state.model = zMainModel.parse(action.payload);
+ if (state.model === null) {
+ return;
+ }
// Clamp ClipSkip Based On Selected Model
const { maxClip } = clipSkipMap[state.model.base_model];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
},
vaeSelected: (state, action: PayloadAction) => {
+ // null is a valid VAE!
state.vae = action.payload;
},
setClipSkip: (state, action: PayloadAction) => {
@@ -253,11 +255,15 @@ export const generationSlice = createSlice({
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
- state.model = zMainModel.parse({
- id: defaultModel,
- name: model_name,
+
+ const result = zMainModel.safeParse({
+ model_name,
base_model,
});
+
+ if (result.success) {
+ state.model = result.data;
+ }
}
});
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts
new file mode 100644
index 0000000000..56f808738d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts
@@ -0,0 +1,4 @@
+export const MODEL_TYPE_MAP = {
+ 'sd-1': 'Stable Diffusion 1.x',
+ 'sd-2': 'Stable Diffusion 2.x',
+};
diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
similarity index 98%
rename from invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
rename to invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
index 16fbf0e155..aa2c60f3a8 100644
--- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
+++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
@@ -135,7 +135,7 @@ export type BaseModelParam = z.infer;
* TODO: Make this a dynamically generated enum?
*/
export const zMainModel = z.object({
- model_name: z.string(),
+ model_name: z.string().min(1),
base_model: zBaseModel,
});
@@ -152,8 +152,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam =>
* Zod schema for VAE parameter
*/
export const zVaeModel = z.object({
- id: z.string(),
- name: z.string(),
+ model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
@@ -169,8 +168,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
* Zod schema for LoRA
*/
export const zLoRAModel = z.object({
- id: z.string(),
- model_name: z.string(),
+ model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts
new file mode 100644
index 0000000000..2ea7cacb5d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts
@@ -0,0 +1,18 @@
+import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
+
+export const modelIdToLoRAModelParam = (
+ loraId: string
+): LoRAModelParam | undefined => {
+ const [base_model, model_type, model_name] = loraId.split('/');
+
+ const result = zLoRAModel.safeParse({
+ base_model,
+ model_name,
+ });
+
+ if (!result.success) {
+ return;
+ }
+
+ return result.data;
+};
diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts
new file mode 100644
index 0000000000..b73d3c5f0d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts
@@ -0,0 +1,21 @@
+import {
+ MainModelParam,
+ zMainModel,
+} from 'features/parameters/types/parameterSchemas';
+
+export const modelIdToMainModelParam = (
+ modelId: string
+): MainModelParam | undefined => {
+ const [base_model, model_type, model_name] = modelId.split('/');
+
+ const result = zMainModel.safeParse({
+ base_model,
+ model_name,
+ });
+
+ if (!result.success) {
+ return;
+ }
+
+ return result.data;
+};
diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts
new file mode 100644
index 0000000000..49856531d6
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts
@@ -0,0 +1,18 @@
+import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
+
+export const modelIdToVAEModelParam = (
+ modelId: string
+): VaeModelParam | undefined => {
+ const [base_model, model_type, model_name] = modelId.split('/');
+
+ const result = zVaeModel.safeParse({
+ base_model,
+ model_name,
+ });
+
+ if (!result.success) {
+ return;
+ }
+
+ return result.data;
+};
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
index 26c11604e1..959559548e 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
@@ -2,7 +2,7 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
-import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
+import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
import { map } from 'lodash-es';
import { useCallback } from 'react';
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx
index 586be4566e..c101b68d45 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx
@@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
-import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types';
import ModelConvert from './ModelConvert';
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx
index f0ed12d361..e5b6fd625f 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx
@@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
-import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types';
type DiffusersModel =
diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
index 4f38f84fe2..ccce10f5c4 100644
--- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
+++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
@@ -1,7 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
-import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
+import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes';
diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts
index e574f0ab79..4c72bd6239 100644
--- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts
+++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts
@@ -1,4 +1,4 @@
-import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
+import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
diff --git a/invokeai/frontend/web/src/index.ts b/invokeai/frontend/web/src/index.ts
index e70e756ed9..add4999b6d 100644
--- a/invokeai/frontend/web/src/index.ts
+++ b/invokeai/frontend/web/src/index.ts
@@ -2,8 +2,8 @@ export { default as InvokeAIUI } from './app/components/InvokeAIUI';
export type { PartialAppConfig } from './app/types/invokeai';
export { default as IAIIconButton } from './common/components/IAIIconButton';
export { default as IAIPopover } from './common/components/IAIPopover';
+export { default as ParamMainModelSelect } from './features/parameters/components/Parameters/MainModel/ParamMainModelSelect';
+export { default as ColorModeButton } from './features/system/components/ColorModeButton';
export { default as InvokeAiLogoComponent } from './features/system/components/InvokeAILogoComponent';
-export { default as ModelSelect } from './features/system/components/ModelSelect';
export { default as SettingsModal } from './features/system/components/SettingsModal/SettingsModal';
export { default as StatusIndicator } from './features/system/components/StatusIndicator';
-export { default as ColorModeButton } from './features/system/components/ColorModeButton';