diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 5fd413d915..440d9330d6 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,6 +1,7 @@ -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; - // zod needs the array to be `as const` to infer the type correctly + +import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; + // this is the source of the `SchedulerParam` type, which is generated by zod export const SCHEDULER_NAMES_AS_CONST = [ 'euler', diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 5ab30570d9..ee879a8915 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -1,36 +1,70 @@ import { makeToast } from 'app/components/Toaster'; +import { log } from 'app/logging/useLogger'; +import { loraRemoved } from 'features/lora/store/loraSlice'; import { modelSelected } from 'features/parameters/store/actions'; import { modelChanged, vaeSelected, } from 'features/parameters/store/generationSlice'; -import { zMainModel } from 'features/parameters/store/parameterZodSchemas'; +import { zMainModel } from 'features/parameters/types/parameterSchemas'; import { addToast } from 'features/system/store/systemSlice'; +import { forEach } from 'lodash-es'; import { startAppListening } from '..'; -import { lorasCleared } from '../../../../../features/lora/store/loraSlice'; + +const moduleLog = log.child({ module: 'models' }); export const addModelSelectedListener = () => { startAppListening({ actionCreator: modelSelected, effect: (action, { getState, dispatch }) => { const state = getState(); - const { base_model, model_name } = action.payload; + const result = zMainModel.safeParse(action.payload); - if (state.generation.model?.base_model !== base_model) { - dispatch( - addToast( - makeToast({ - title: 'Base model changed, clearing submodels', - status: 'warning', - }) - ) + if (!result.success) { + moduleLog.error( + { error: result.error.format() }, + 'Failed to parse main model' ); - dispatch(vaeSelected(null)); - dispatch(lorasCleared()); - // TODO: controlnet cleared + return; } - const newModel = zMainModel.parse(action.payload); + const newModel = result.data; + + const { base_model } = newModel; + + if (state.generation.model?.base_model !== base_model) { + // we may need to reset some incompatible submodels + let modelsCleared = 0; + + // handle incompatible loras + forEach(state.lora.loras, (lora, id) => { + if (lora.base_model !== base_model) { + dispatch(loraRemoved(id)); + modelsCleared += 1; + } + }); + + // handle incompatible vae + const { vae } = state.generation; + if (vae && vae.base_model !== base_model) { + dispatch(vaeSelected(null)); + modelsCleared += 1; + } + + // TODO: handle incompatible controlnet; pending model manager support + if (modelsCleared > 0) { + dispatch( + addToast( + makeToast({ + title: `Base model changed, cleared ${modelsCleared} incompatible submodel${ + modelsCleared === 1 ? '' : 's' + }`, + status: 'warning', + }) + ) + ); + } + } dispatch(modelChanged(newModel)); }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index ee02028848..f8abcfa758 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -1,8 +1,19 @@ -import { modelChanged } from 'features/parameters/store/generationSlice'; -import { some } from 'lodash-es'; +import { log } from 'app/logging/useLogger'; +import { loraRemoved } from 'features/lora/store/loraSlice'; +import { + modelChanged, + vaeSelected, +} from 'features/parameters/store/generationSlice'; +import { + zMainModel, + zVaeModel, +} from 'features/parameters/types/parameterSchemas'; +import { forEach, some } from 'lodash-es'; import { modelsApi } from 'services/api/endpoints/models'; import { startAppListening } from '..'; +const moduleLog = log.child({ module: 'models' }); + export const addModelsLoadedListener = () => { startAppListening({ matcher: modelsApi.endpoints.getMainModels.matchFulfilled, @@ -31,12 +42,92 @@ export const addModelsLoadedListener = () => { return; } - dispatch( - modelChanged({ - base_model: firstModel.base_model, - model_name: firstModel.model_name, - }) + const result = zMainModel.safeParse(firstModel); + + if (!result.success) { + moduleLog.error( + { error: result.error.format() }, + 'Failed to parse main model' + ); + return; + } + + dispatch(modelChanged(result.data)); + }, + }); + startAppListening({ + matcher: modelsApi.endpoints.getVaeModels.matchFulfilled, + effect: async (action, { getState, dispatch }) => { + // VAEs loaded, need to reset the VAE is it's no longer available + + const currentVae = getState().generation.vae; + + if (currentVae === null) { + // null is a valid VAE! it means "use the default with the main model" + return; + } + + const isCurrentVAEAvailable = some( + action.payload.entities, + (m) => + m?.model_name === currentVae?.model_name && + m?.base_model === currentVae?.base_model ); + + if (isCurrentVAEAvailable) { + return; + } + + const firstModelId = action.payload.ids[0]; + const firstModel = action.payload.entities[firstModelId]; + + if (!firstModel) { + // No custom VAEs loaded at all; use the default + dispatch(modelChanged(null)); + return; + } + + const result = zVaeModel.safeParse(firstModel); + + if (!result.success) { + moduleLog.error( + { error: result.error.format() }, + 'Failed to parse VAE model' + ); + return; + } + + dispatch(vaeSelected(result.data)); + }, + }); + startAppListening({ + matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled, + effect: async (action, { getState, dispatch }) => { + // LoRA models loaded - need to remove missing LoRAs from state + + const loras = getState().lora.loras; + + forEach(loras, (lora, id) => { + const isLoRAAvailable = some( + action.payload.entities, + (m) => + m?.model_name === lora?.model_name && + m?.base_model === lora?.base_model + ); + + if (isLoRAAvailable) { + return; + } + + dispatch(loraRemoved(id)); + }); + }, + }); + startAppListening({ + matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, + effect: async (action, { getState, dispatch }) => { + // ControlNet models loaded - need to remove missing ControlNets from state + // TODO: pending model manager controlnet support }, }); }; diff --git a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx index 32822125d2..89302b78d4 100644 --- a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx +++ b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx @@ -11,7 +11,7 @@ import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; -import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { forEach } from 'lodash-es'; import { PropsWithChildren, useCallback, useMemo, useRef } from 'react'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx index 7dba2aa6ed..a1584ca13a 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -5,14 +5,14 @@ import IAISlider from 'common/components/IAISlider'; import { memo, useCallback } from 'react'; import { FaTrash } from 'react-icons/fa'; import { - Lora, + LoRA, loraRemoved, loraWeightChanged, loraWeightReset, } from '../store/loraSlice'; type Props = { - lora: Lora; + lora: LoRA; }; const ParamLora = (props: Props) => { diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx index 436c32f46b..e212efbfa2 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -6,9 +6,9 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; import { size } from 'lodash-es'; import { memo } from 'react'; -import ParamLoraList from './ParamLoraList'; -import ParamLoraSelect from './ParamLoraSelect'; import { useFeatureStatus } from '../../system/hooks/useFeatureStatus'; +import ParamLoraList from './ParamLoraList'; +import ParamLoRASelect from './ParamLoraSelect'; const selector = createSelector( stateSelector, @@ -33,7 +33,7 @@ const ParamLoraCollapse = () => { return ( - + diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx index ebceeb34db..f0aa252339 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -7,7 +7,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { loraAdded } from 'features/lora/store/loraSlice'; -import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { forEach } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; @@ -20,23 +20,23 @@ const selector = createSelector( defaultSelectorOptions ); -const ParamLoraSelect = () => { +const ParamLoRASelect = () => { const dispatch = useAppDispatch(); const { loras } = useAppSelector(selector); - const { data: lorasQueryData } = useGetLoRAModelsQuery(); + const { data: loraModels } = useGetLoRAModelsQuery(); const currentMainModel = useAppSelector( (state: RootState) => state.generation.model ); const data = useMemo(() => { - if (!lorasQueryData) { + if (!loraModels) { return []; } const data: SelectItem[] = []; - forEach(lorasQueryData.entities, (lora, id) => { + forEach(loraModels.entities, (lora, id) => { if (!lora || Boolean(id in loras)) { return; } @@ -55,23 +55,25 @@ const ParamLoraSelect = () => { }); return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); - }, [loras, lorasQueryData, currentMainModel?.base_model]); + }, [loras, loraModels, currentMainModel?.base_model]); const handleChange = useCallback( (v: string | null | undefined) => { if (!v) { return; } - const loraEntity = lorasQueryData?.entities[v]; + const loraEntity = loraModels?.entities[v]; + if (!loraEntity) { return; } + dispatch(loraAdded(loraEntity)); }, - [dispatch, lorasQueryData?.entities] + [dispatch, loraModels?.entities] ); - if (lorasQueryData?.ids.length === 0) { + if (loraModels?.ids.length === 0) { return ( @@ -98,4 +100,4 @@ const ParamLoraSelect = () => { ); }; -export default ParamLoraSelect; +export default ParamLoRASelect; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index a97a0887a5..2dc739a737 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -1,8 +1,8 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; -import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas'; +import { LoRAModelParam } from 'features/parameters/types/parameterSchemas'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; -export type Lora = LoRAModelParam & { +export type LoRA = LoRAModelParam & { weight: number; }; @@ -11,7 +11,7 @@ export const defaultLoRAConfig = { }; export type LoraState = { - loras: Record; + loras: Record; }; export const intialLoraState: LoraState = { @@ -24,7 +24,7 @@ export const loraSlice = createSlice({ reducers: { loraAdded: (state, action: PayloadAction) => { const { model_name, id, base_model } = action.payload; - state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; + state.loras[id] = { model_name, base_model, ...defaultLoRAConfig }; }, loraRemoved: (state, action: PayloadAction) => { const id = action.payload; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx index 271408b817..861f919b33 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx @@ -6,7 +6,7 @@ import { VaeModelInputFieldTemplate, VaeModelInputFieldValue, } from 'features/nodes/types/types'; -import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { forEach, isString } from 'lodash-es'; import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -46,7 +46,7 @@ const LoRAModelInputFieldComponent = ( data.push({ value: id, label: model.model_name, - group: BASE_MODEL_NAME_MAP[model.base_model], + group: MODEL_TYPE_MAP[model.base_model], }); }); @@ -88,8 +88,7 @@ const LoRAModelInputFieldComponent = ( { - const { id, name, weight } = lora; - const loraField = modelIdToLoRAModelField(id); - const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( - '.', - '_' - )}`; + const { model_name, base_model, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, - lora: loraField, + lora, weight, }; // add the lora to the metadata accumulator if (metadataAccumulator) { - metadataAccumulator.loras.push({ lora: loraField, weight }); + metadataAccumulator.loras.push({ + lora: { model_name, base_model }, + weight, + }); } // add to graph diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts index d76fec093c..8574dc4e46 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -1,7 +1,6 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; import { MetadataAccumulatorInvocation } from 'services/api/types'; -import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; import { IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_LATENTS, @@ -19,7 +18,6 @@ export const addVAEToGraph = ( graph: NonNullableGraph ): void => { const { vae } = state.generation; - const vae_model = modelIdToVAEModelField(vae?.id || ''); const isAutoVae = !vae; const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as @@ -30,7 +28,7 @@ export const addVAEToGraph = ( graph.nodes[VAE_LOADER] = { type: 'vae_loader', id: VAE_LOADER, - vae_model, + vae_model: vae, }; } @@ -74,6 +72,6 @@ export const addVAEToGraph = ( } if (vae && metadataAccumulator) { - metadataAccumulator.vae = vae_model; + metadataAccumulator.vae = vae; } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index ea68ac50fb..6963cf16b8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,12 +1,12 @@ import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; +import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam'; +import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; +import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; -import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; -import { modelIdToMainModelField } from '../modelIdToMainModelField'; -import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; /** * We need to do special handling for some fields @@ -29,19 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => { if (field.type === 'model') { if (field.value) { - return modelIdToMainModelField(field.value); + return modelIdToMainModelParam(field.value); } } if (field.type === 'vae_model') { if (field.value) { - return modelIdToVAEModelField(field.value); + return modelIdToVAEModelParam(field.value); } } if (field.type === 'lora_model') { if (field.value) { - return modelIdToLoRAModelField(field.value); + return modelIdToLoRAModelParam(field.value); } } diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts deleted file mode 100644 index 052b58484b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { BaseModelType, LoRAModelField } from 'services/api/types'; - -export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => { - const [base_model, model_type, model_name] = loraId.split('/'); - - const field: LoRAModelField = { - base_model: base_model as BaseModelType, - model_name, - }; - - return field; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts deleted file mode 100644 index 6bb0f776b2..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { BaseModelType, MainModelField } from 'services/api/types'; - -/** - * Crudely converts a model id to a main model field - * TODO: Make better - */ -export const modelIdToMainModelField = (modelId: string): MainModelField => { - const [base_model, model_type, model_name] = modelId.split('/'); - - const field: MainModelField = { - base_model: base_model as BaseModelType, - model_name, - }; - - return field; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts deleted file mode 100644 index 0cb608a936..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { BaseModelType, VAEModelField } from 'services/api/types'; - -/** - * Crudely converts a model id to a main model field - * TODO: Make better - */ -export const modelIdToVAEModelField = (modelId: string): VAEModelField => { - const [base_model, model_type, model_name] = modelId.split('/'); - - const field: VAEModelField = { - base_model: base_model as BaseModelType, - model_name, - }; - - return field; -}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx index 9054afcca2..74418de1d3 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx @@ -1,8 +1,8 @@ import { Box, Flex } from '@chakra-ui/react'; -import ModelSelect from 'features/system/components/ModelSelect'; -import VAESelect from 'features/system/components/VAESelect'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { memo } from 'react'; -import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; +import ParamMainModelSelect from '../MainModel/ParamMainModelSelect'; +import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect'; import ParamScheduler from './ParamScheduler'; const ParamModelandVAEandScheduler = () => { @@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => { return ( - + {isVaeEnabled && ( - + )} diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index 8818dcba9b..be8db632bc 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; +import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx similarity index 72% rename from invokeai/frontend/web/src/features/system/components/ModelSelect.tsx rename to invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx index bc3da20b06..dbe732fc55 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx @@ -8,27 +8,23 @@ import { SelectItem } from '@mantine/core'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField'; import { modelSelected } from 'features/parameters/store/actions'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { forEach } from 'lodash-es'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; -export const MODEL_TYPE_MAP = { - 'sd-1': 'Stable Diffusion 1.x', - 'sd-2': 'Stable Diffusion 2.x', -}; - const selector = createSelector( stateSelector, - (state) => ({ currentModel: state.generation.model }), + (state) => ({ model: state.generation.model }), defaultSelectorOptions ); -const ModelSelect = () => { +const ParamMainModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { currentModel } = useAppSelector(selector); + const { model } = useAppSelector(selector); const { data: mainModels, isLoading } = useGetMainModelsQuery(); @@ -54,12 +50,13 @@ const ModelSelect = () => { return data; }, [mainModels]); + // grab the full model entity from the RTK Query cache + // TODO: maybe we should just store the full model entity in state? const selectedModel = useMemo( () => - mainModels?.entities[ - `${currentModel?.base_model}/main/${currentModel?.model_name}` - ], - [mainModels?.entities, currentModel] + mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ?? + null, + [mainModels?.entities, model] ); const handleChangeModel = useCallback( @@ -68,8 +65,13 @@ const ModelSelect = () => { return; } - const modelField = modelIdToMainModelField(v); - dispatch(modelSelected(modelField)); + const newModel = modelIdToMainModelParam(v); + + if (!newModel) { + return; + } + + dispatch(modelSelected(newModel)); }, [dispatch] ); @@ -95,4 +97,4 @@ const ModelSelect = () => { ); }; -export default memo(ModelSelect); +export default memo(ParamMainModelSelect); diff --git a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx similarity index 54% rename from invokeai/frontend/web/src/features/system/components/VAESelect.tsx rename to invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx index bed1b72123..d1e040e181 100644 --- a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx @@ -1,4 +1,4 @@ -import { memo, useCallback, useEffect, useMemo } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -8,26 +8,30 @@ import { SelectItem } from '@mantine/core'; import { forEach } from 'lodash-es'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { vaeSelected } from 'features/parameters/store/generationSlice'; -import { zVaeModel } from 'features/parameters/store/parameterZodSchemas'; -import { MODEL_TYPE_MAP } from './ModelSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; -const VAESelect = () => { +const selector = createSelector( + stateSelector, + ({ generation }) => { + const { model, vae } = generation; + return { model, vae }; + }, + defaultSelectorOptions +); + +const ParamVAEModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); + const { model, vae } = useAppSelector(selector); const { data: vaeModels } = useGetVaeModelsQuery(); - const currentMainModel = useAppSelector( - (state: RootState) => state.generation.model - ); - - const selectedVae = useAppSelector( - (state: RootState) => state.generation.vae - ); - const data = useMemo(() => { if (!vaeModels) { return []; @@ -41,30 +45,32 @@ const VAESelect = () => { }, ]; - forEach(vaeModels.entities, (model, id) => { - if (!model) { + forEach(vaeModels.entities, (vae, id) => { + if (!vae) { return; } - const disabled = currentMainModel?.base_model !== model.base_model; + const disabled = model?.base_model !== vae.base_model; data.push({ value: id, - label: model.model_name, - group: MODEL_TYPE_MAP[model.base_model], + label: vae.model_name, + group: MODEL_TYPE_MAP[vae.base_model], disabled, tooltip: disabled - ? `Incompatible base model: ${model.base_model}` + ? `Incompatible base model: ${vae.base_model}` : undefined, }); }); return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); - }, [vaeModels, currentMainModel?.base_model]); + }, [vaeModels, model?.base_model]); + // grab the full model entity from the RTK Query cache const selectedVaeModel = useMemo( - () => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null), - [vaeModels?.entities, selectedVae] + () => + vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null, + [vaeModels?.entities, vae] ); const handleChangeModel = useCallback( @@ -74,32 +80,23 @@ const VAESelect = () => { return; } - const [base_model, type, name] = v.split('/'); + const newVaeModel = modelIdToVAEModelParam(v); - const model = zVaeModel.parse({ - id: v, - name, - base_model, - }); + if (!newVaeModel) { + return; + } - dispatch(vaeSelected(model)); + dispatch(vaeSelected(newVaeModel)); }, [dispatch] ); - useEffect(() => { - if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) { - return; - } - dispatch(vaeSelected(null)); - }, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]); - return ( { ); }; -export default memo(VAESelect); +export default memo(ParamVAEModelSelect); diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 9e4f5aeff0..6329d9d677 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -28,7 +28,7 @@ import { isValidSteps, isValidStrength, isValidWidth, -} from '../store/parameterZodSchemas'; +} from '../types/parameterSchemas'; export const useRecallParameters = () => { const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index dff277ae7e..c5ec7930a4 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -13,6 +13,7 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip'; import { CfgScaleParam, HeightParam, + MainModelParam, NegativePromptParam, PositivePromptParam, SchedulerParam, @@ -22,7 +23,7 @@ import { VaeModelParam, WidthParam, zMainModel, -} from './parameterZodSchemas'; +} from '../types/parameterSchemas'; export interface GenerationState { cfgScale: CfgScaleParam; @@ -226,18 +227,19 @@ export const generationSlice = createSlice({ const { image_name, width, height } = action.payload; state.initialImage = { imageName: image_name, width, height }; }, - modelChanged: (state, action: PayloadAction) => { - if (!action.payload) { - state.model = null; - } + modelChanged: (state, action: PayloadAction) => { + state.model = action.payload; - state.model = zMainModel.parse(action.payload); + if (state.model === null) { + return; + } // Clamp ClipSkip Based On Selected Model const { maxClip } = clipSkipMap[state.model.base_model]; state.clipSkip = clamp(state.clipSkip, 0, maxClip); }, vaeSelected: (state, action: PayloadAction) => { + // null is a valid VAE! state.vae = action.payload; }, setClipSkip: (state, action: PayloadAction) => { @@ -253,11 +255,15 @@ export const generationSlice = createSlice({ if (defaultModel && !state.model) { const [base_model, model_type, model_name] = defaultModel.split('/'); - state.model = zMainModel.parse({ - id: defaultModel, - name: model_name, + + const result = zMainModel.safeParse({ + model_name, base_model, }); + + if (result.success) { + state.model = result.data; + } } }); builder.addCase(setShouldShowAdvancedOptions, (state, action) => { diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts new file mode 100644 index 0000000000..56f808738d --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -0,0 +1,4 @@ +export const MODEL_TYPE_MAP = { + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts similarity index 98% rename from invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts rename to invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 16fbf0e155..aa2c60f3a8 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -135,7 +135,7 @@ export type BaseModelParam = z.infer; * TODO: Make this a dynamically generated enum? */ export const zMainModel = z.object({ - model_name: z.string(), + model_name: z.string().min(1), base_model: zBaseModel, }); @@ -152,8 +152,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam => * Zod schema for VAE parameter */ export const zVaeModel = z.object({ - id: z.string(), - name: z.string(), + model_name: z.string().min(1), base_model: zBaseModel, }); /** @@ -169,8 +168,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam => * Zod schema for LoRA */ export const zLoRAModel = z.object({ - id: z.string(), - model_name: z.string(), + model_name: z.string().min(1), base_model: zBaseModel, }); /** diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts new file mode 100644 index 0000000000..2ea7cacb5d --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts @@ -0,0 +1,18 @@ +import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; + +export const modelIdToLoRAModelParam = ( + loraId: string +): LoRAModelParam | undefined => { + const [base_model, model_type, model_name] = loraId.split('/'); + + const result = zLoRAModel.safeParse({ + base_model, + model_name, + }); + + if (!result.success) { + return; + } + + return result.data; +}; diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts new file mode 100644 index 0000000000..b73d3c5f0d --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts @@ -0,0 +1,21 @@ +import { + MainModelParam, + zMainModel, +} from 'features/parameters/types/parameterSchemas'; + +export const modelIdToMainModelParam = ( + modelId: string +): MainModelParam | undefined => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const result = zMainModel.safeParse({ + base_model, + model_name, + }); + + if (!result.success) { + return; + } + + return result.data; +}; diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts new file mode 100644 index 0000000000..49856531d6 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts @@ -0,0 +1,18 @@ +import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; + +export const modelIdToVAEModelParam = ( + modelId: string +): VaeModelParam | undefined => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const result = zVaeModel.safeParse({ + base_model, + model_name, + }); + + if (!result.success) { + return; + } + + return result.data; +}; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 26c11604e1..959559548e 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -2,7 +2,7 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; +import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice'; import { map } from 'lodash-es'; import { useCallback } from 'react'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 586be4566e..c101b68d45 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { S } from 'services/api/types'; import ModelConvert from './ModelConvert'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index f0ed12d361..e5b6fd625f 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -10,7 +10,7 @@ import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { S } from 'services/api/types'; type DiffusersModel = diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 4f38f84fe2..ccce10f5c4 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,7 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; +import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index e574f0ab79..4c72bd6239 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,4 +1,4 @@ -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; +import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; export type AddNewModelType = 'ckpt' | 'diffusers' | null; diff --git a/invokeai/frontend/web/src/index.ts b/invokeai/frontend/web/src/index.ts index e70e756ed9..add4999b6d 100644 --- a/invokeai/frontend/web/src/index.ts +++ b/invokeai/frontend/web/src/index.ts @@ -2,8 +2,8 @@ export { default as InvokeAIUI } from './app/components/InvokeAIUI'; export type { PartialAppConfig } from './app/types/invokeai'; export { default as IAIIconButton } from './common/components/IAIIconButton'; export { default as IAIPopover } from './common/components/IAIPopover'; +export { default as ParamMainModelSelect } from './features/parameters/components/Parameters/MainModel/ParamMainModelSelect'; +export { default as ColorModeButton } from './features/system/components/ColorModeButton'; export { default as InvokeAiLogoComponent } from './features/system/components/InvokeAILogoComponent'; -export { default as ModelSelect } from './features/system/components/ModelSelect'; export { default as SettingsModal } from './features/system/components/SettingsModal/SettingsModal'; export { default as StatusIndicator } from './features/system/components/StatusIndicator'; -export { default as ColorModeButton } from './features/system/components/ColorModeButton';