diff --git a/invokeai/frontend/web/src/app/store/nanostores/store.ts b/invokeai/frontend/web/src/app/store/nanostores/store.ts index aee0f0e6ef..f4cd001c96 100644 --- a/invokeai/frontend/web/src/app/store/nanostores/store.ts +++ b/invokeai/frontend/web/src/app/store/nanostores/store.ts @@ -8,4 +8,26 @@ declare global { } } +/** + * Raised when the redux store is unable to be retrieved. + */ +export class ReduxStoreNotInitialized extends Error { + /** + * Create ReduxStoreNotInitialized + * @param {String} message + */ + constructor(message = 'Redux store not initialized') { + super(message); + this.name = this.constructor.name; + } +} + export const $store = atom> | undefined>(); + +export const getStore = () => { + const store = $store.get(); + if (!store) { + throw new ReduxStoreNotInitialized(); + } + return store; +}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index 696bf47b2a..75372c350d 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -8,7 +8,7 @@ import { useControlAdapterType } from 'features/controlAdapters/hooks/useControl import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; -import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; +import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; type ParamControlAdapterModelProps = { id: string; @@ -24,7 +24,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const _onChange = useCallback( - (model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { + (model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { if (!model) { return; } diff --git a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx index fd05edc466..a5ad358fa0 100644 --- a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx +++ b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx @@ -7,7 +7,7 @@ import { t } from 'i18next'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; -import type { TextualInversionConfig } from 'services/api/types'; +import type { TextualInversionModelConfig } from 'services/api/types'; const noOptionsMessage = () => t('embedding.noMatchingEmbedding'); @@ -17,7 +17,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = useCallback( - (embedding: TextualInversionConfig): boolean => { + (embedding: TextualInversionModelConfig): boolean => { const isCompatible = currentBaseModel === embedding.base; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; @@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const { data, isLoading } = useGetTextualInversionModelsQuery(); const _onChange = useCallback( - (embedding: TextualInversionConfig | null) => { + (embedding: TextualInversionModelConfig | null) => { if (!embedding) { return; } diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index b58751ca5e..e7d40c5eaf 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -8,7 +8,7 @@ import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -19,7 +19,7 @@ const LoRASelect = () => { const addedLoRAs = useAppSelector(selectAddedLoRAs); const currentBaseModel = useAppSelector((s) => s.generation.model?.base); - const getIsDisabled = (lora: LoRAConfig): boolean => { + const getIsDisabled = (lora: LoRAModelConfig): boolean => { const isCompatible = currentBaseModel === lora.base; const isAdded = Boolean(addedLoRAs[lora.key]); const hasMainModel = Boolean(currentBaseModel); @@ -27,7 +27,7 @@ const LoRASelect = () => { }; const _onChange = useCallback( - (lora: LoRAConfig | null) => { + (lora: LoRAModelConfig | null) => { if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index dd455e12c3..377406b3e5 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; export type LoRA = ParameterLoRAModel & { weight: number; @@ -28,13 +28,12 @@ export const loraSlice = createSlice({ name: 'lora', initialState: initialLoraState, reducers: { - loraAdded: (state, action: PayloadAction) => { + loraAdded: (state, action: PayloadAction) => { const { key, base } = action.payload; state.loras[key] = { key, base, ...defaultLoRAConfig }; }, - loraRecalled: (state, action: PayloadAction) => { - const { key, base, weight } = action.payload; - state.loras[key] = { key, base, weight, isEnabled: true }; + loraRecalled: (state, action: PayloadAction) => { + state.loras[action.payload.key] = action.payload; }, loraRemoved: (state, action: PayloadAction) => { const key = action.payload; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index 75151cd001..1a8f235aaf 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,12 +8,11 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { LoRAConfig } from 'services/api/endpoints/models'; import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; type LoRAModelEditProps = { - model: LoRAConfig; + model: LoRAModelConfig; }; const LoRAModelEdit = (props: LoRAModelEditProps) => { @@ -30,7 +29,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { control, formState: { errors }, reset, - } = useForm({ + } = useForm({ defaultValues: { model_name: model.model_name ? model.model_name : '', base_model: model.base_model, @@ -42,7 +41,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { mode: 'onChange', }); - const onSubmit = useCallback>( + const onSubmit = useCallback>( (values) => { const responseBody = { base_model: model.base_model, @@ -53,7 +52,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { updateLoRAModel(responseBody) .unwrap() .then((payload) => { - reset(payload as LoRAConfig, { keepDefaultValues: true }); + reset(payload as LoRAModelConfig, { keepDefaultValues: true }); dispatch( addToast( makeToast({ @@ -106,7 +105,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base_model" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 1951ec60d3..29a1f93dd5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTempla import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; -import type { ControlNetConfig } from 'services/api/types'; +import type { ControlNetModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { const { data, isLoading } = useGetControlNetModelsQuery(); const _onChange = useCallback( - (value: ControlNetConfig | null) => { + (value: ControlNetModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 137f751fca..d4f0ae3de1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; -import type { IPAdapterConfig } from 'services/api/types'; +import type { IPAdapterModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const IPAdapterModelFieldInputComponent = ( const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const _onChange = useCallback( - (value: IPAdapterConfig | null) => { + (value: IPAdapterModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index 5f6318de9e..9fd223e694 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'f import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetLoRAModelsQuery(); const _onChange = useCallback( - (value: LoRAConfig | null) => { + (value: LoRAModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 9115f22c14..a38356a0b8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTempla import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; -import type { T2IAdapterConfig } from 'services/api/types'; +import type { T2IAdapterModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -19,7 +19,7 @@ const T2IAdapterModelFieldInputComponent = ( const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const _onChange = useCallback( - (value: T2IAdapterConfig | null) => { + (value: T2IAdapterModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 87272f48b9..272f7f5b35 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -7,7 +7,7 @@ import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'fea import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -import type { VAEConfig } from 'services/api/types'; +import type { VAEModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const VAEModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetVaeModelsQuery(); const _onChange = useCallback( - (value: VAEConfig | null) => { + (value: VAEModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index 4b9f2764bf..4a630fa9ce 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -8,7 +8,7 @@ import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -import type { VAEConfig } from 'services/api/types'; +import type { VAEModelConfig } from 'services/api/types'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const { model, vae } = generation; @@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => { const { model, vae } = useAppSelector(selector); const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( - (vae: VAEConfig): boolean => { + (vae: VAEModelConfig): boolean => { const isCompatible = model?.base === vae.base; const hasMainModel = Boolean(model?.base); return !hasMainModel || !isCompatible; @@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => { [model?.base] ); const _onChange = useCallback( - (vae: VAEConfig | null) => { + (vae: VAEModelConfig | null) => { dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null)); }, [dispatch] diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 0d464cd9b9..0929fc1dc3 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,17 +1,9 @@ import { useAppToaster } from 'app/components/Toaster'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice'; -import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; -import { - initialControlNet, - initialIPAdapter, - initialT2IAdapter, -} from 'features/controlAdapters/util/buildControlAdapter'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; -import type { ModelIdentifier } from 'features/nodes/types/common'; import { isModelIdentifier } from 'features/nodes/types/common'; import type { ControlNetMetadataItem, @@ -56,6 +48,14 @@ import { isParameterStrength, isParameterWidth, } from 'features/parameters/types/parameterSchemas'; +import { + prepareControlNetMetadataItem, + prepareIPAdapterMetadataItem, + prepareLoRAMetadataItem, + prepareMainModelMetadataItem, + prepareT2IAdapterMetadataItem, + prepareVAEMetadataItem, +} from 'features/parameters/util/modelMetadataHelpers'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -70,23 +70,7 @@ import { import { isNil } from 'lodash-es'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { ALL_BASE_MODELS } from 'services/api/constants'; -import { - controlNetModelsAdapterSelectors, - ipAdapterModelsAdapterSelectors, - loraModelsAdapterSelectors, - mainModelsAdapterSelectors, - t2iAdapterModelsAdapterSelectors, - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetLoRAModelsQuery, - useGetMainModelsQuery, - useGetT2IAdapterModelsQuery, - useGetVaeModelsQuery, - vaeModelsAdapterSelectors, -} from 'services/api/endpoints/models'; import type { ImageDTO } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); @@ -140,9 +124,6 @@ export const useRecallParameters = () => { [t, toaster] ); - /** - * Recall both prompts with toast - */ const recallBothPrompts = useCallback( (positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => { if ( @@ -175,9 +156,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall positive prompt with toast - */ const recallPositivePrompt = useCallback( (positivePrompt: unknown) => { if (!isParameterPositivePrompt(positivePrompt)) { @@ -190,9 +168,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall negative prompt with toast - */ const recallNegativePrompt = useCallback( (negativePrompt: unknown) => { if (!isParameterNegativePrompt(negativePrompt)) { @@ -205,9 +180,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall SDXL Positive Style Prompt with toast - */ const recallSDXLPositiveStylePrompt = useCallback( (positiveStylePrompt: unknown) => { if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { @@ -220,9 +192,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall SDXL Negative Style Prompt with toast - */ const recallSDXLNegativeStylePrompt = useCallback( (negativeStylePrompt: unknown) => { if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { @@ -235,9 +204,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall seed with toast - */ const recallSeed = useCallback( (seed: unknown) => { if (!isParameterSeed(seed)) { @@ -250,9 +216,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall CFG scale with toast - */ const recallCfgScale = useCallback( (cfgScale: unknown) => { if (!isParameterCFGScale(cfgScale)) { @@ -265,9 +228,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall CFG rescale multiplier with toast - */ const recallCfgRescaleMultiplier = useCallback( (cfgRescaleMultiplier: unknown) => { if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) { @@ -280,9 +240,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall scheduler with toast - */ const recallScheduler = useCallback( (scheduler: unknown) => { if (!isParameterScheduler(scheduler)) { @@ -295,9 +252,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall steps with toast - */ const recallSteps = useCallback( (steps: unknown) => { if (!isParameterSteps(steps)) { @@ -310,9 +264,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall width with toast - */ const recallWidth = useCallback( (width: unknown) => { if (!isParameterWidth(width)) { @@ -325,9 +276,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall height with toast - */ const recallHeight = useCallback( (height: unknown) => { if (!isParameterHeight(height)) { @@ -340,9 +288,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall width and height with toast - */ const recallWidthAndHeight = useCallback( (width: unknown, height: unknown) => { if (!isParameterWidth(width)) { @@ -360,9 +305,6 @@ export const useRecallParameters = () => { [dispatch, allParameterSetToast, allParameterNotSetToast] ); - /** - * Recall strength with toast - */ const recallStrength = useCallback( (strength: unknown) => { if (!isParameterStrength(strength)) { @@ -375,9 +317,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall high resolution enabled with toast - */ const recallHrfEnabled = useCallback( (hrfEnabled: unknown) => { if (!isParameterHRFEnabled(hrfEnabled)) { @@ -390,9 +329,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall high resolution strength with toast - */ const recallHrfStrength = useCallback( (hrfStrength: unknown) => { if (!isParameterStrength(hrfStrength)) { @@ -405,9 +341,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall high resolution method with toast - */ const recallHrfMethod = useCallback( (hrfMethod: unknown) => { if (!isParameterHRFMethod(hrfMethod)) { @@ -420,358 +353,95 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS); - - const prepareMainModelMetadataItem = useCallback( - (model: ModelIdentifier) => { - const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined; - - if (!matchingModel) { - return { model: null, error: 'Model is not installed' }; - } - - return { model: matchingModel, error: null }; - }, - [mainModels] - ); - - /** - * Recall model with toast - */ const recallModel = useCallback( - (model: unknown) => { - if (!isModelIdentifier(model)) { - parameterNotSetToast(); + async (modelMetadataItem: unknown) => { + try { + const model = await prepareMainModelMetadataItem(modelMetadataItem); + dispatch(modelSelected(model)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - const result = prepareMainModelMetadataItem(model); - - if (!result.model) { - parameterNotSetToast(result.error); - return; - } - - dispatch(modelSelected(result.model)); - parameterSetToast(); }, - [prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + [dispatch, parameterSetToast, parameterNotSetToast] ); - const { data: vaeModels } = useGetVaeModelsQuery(); - - const prepareVAEMetadataItem = useCallback( - (vae: ModelIdentifier, newModel?: ParameterModel) => { - const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined; - if (!matchingModel) { - return { vae: null, error: 'VAE model is not installed' }; - } - const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - vae: null, - error: 'VAE incompatible with currently-selected model', - }; - } - - return { vae: matchingModel, error: null }; - }, - [model, vaeModels] - ); - - /** - * Recall vae model - */ const recallVaeModel = useCallback( - (vae: unknown) => { - if (!isModelIdentifier(vae) && !isNil(vae)) { - parameterNotSetToast(); - return; - } - - if (isNil(vae)) { + async (vaeMetadataItem: unknown) => { + if (isNil(vaeMetadataItem)) { dispatch(vaeSelected(null)); parameterSetToast(); return; } - - const result = prepareVAEMetadataItem(vae); - - if (!result.vae) { - parameterNotSetToast(result.error); + try { + const vae = await prepareVAEMetadataItem(vaeMetadataItem); + dispatch(vaeSelected(vae)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(vaeSelected(result.vae)); - parameterSetToast(); }, - [prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall LoRA with toast - */ - - const { data: loraModels } = useGetLoRAModelsQuery(undefined); - - const prepareLoRAMetadataItem = useCallback( - (loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(loraMetadataItem.lora)) { - return { lora: null, error: 'Invalid LoRA model' }; - } - - const { lora } = loraMetadataItem; - - const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined; - - if (!matchingLoRA) { - return { lora: null, error: 'LoRA model is not installed' }; - } - - const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - lora: null, - error: 'LoRA incompatible with currently-selected model', - }; - } - - return { lora: matchingLoRA, error: null }; - }, - [loraModels, model] + [dispatch, parameterSetToast, parameterNotSetToast] ); const recallLoRA = useCallback( - (loraMetadataItem: LoRAMetadataItem) => { - const result = prepareLoRAMetadataItem(loraMetadataItem); - - if (!result.lora) { - parameterNotSetToast(result.error); + async (loraMetadataItem: LoRAMetadataItem) => { + try { + const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base); + dispatch(loraRecalled(lora)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })); - - parameterSetToast(); }, - [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall ControlNet with toast - */ - - const { data: controlNetModels } = useGetControlNetModelsQuery(undefined); - - const prepareControlNetMetadataItem = useCallback( - (controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(controlnetMetadataItem.control_model)) { - return { controlnet: null, error: 'Invalid ControlNet model' }; - } - - const { image, control_model, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } = - controlnetMetadataItem; - - const matchingControlNetModel = controlNetModels - ? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key) - : undefined; - - if (!matchingControlNetModel) { - return { controlnet: null, error: 'ControlNet model is not installed' }; - } - - const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - controlnet: null, - error: 'ControlNet incompatible with currently-selected model', - }; - } - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const controlnet: ControlNetConfig = { - type: 'controlnet', - isEnabled: true, - model: matchingControlNetModel, - weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, - beginStepPct: begin_step_percent || initialControlNet.beginStepPct, - endStepPct: end_step_percent || initialControlNet.endStepPct, - controlMode: control_mode || initialControlNet.controlMode, - resizeMode: resize_mode || initialControlNet.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return { controlnet, error: null }; - }, - [controlNetModels, model] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); const recallControlNet = useCallback( - (controlnetMetadataItem: ControlNetMetadataItem) => { - const result = prepareControlNetMetadataItem(controlnetMetadataItem); - - if (!result.controlnet) { - parameterNotSetToast(result.error); + async (controlnetMetadataItem: ControlNetMetadataItem) => { + try { + const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base); + dispatch(controlAdapterRecalled(controlNetConfig)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(controlAdapterRecalled(result.controlnet)); - - parameterSetToast(); }, - [prepareControlNetMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall T2I Adapter with toast - */ - - const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined); - - const prepareT2IAdapterMetadataItem = useCallback( - (t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) { - return { controlnet: null, error: 'Invalid ControlNet model' }; - } - - const { image, t2i_adapter_model, weight, begin_step_percent, end_step_percent, resize_mode } = - t2iAdapterMetadataItem; - - const matchingT2IAdapterModel = t2iAdapterModels - ? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key) - : undefined; - - if (!matchingT2IAdapterModel) { - return { controlnet: null, error: 'ControlNet model is not installed' }; - } - - const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - t2iAdapter: null, - error: 'ControlNet incompatible with currently-selected model', - }; - } - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const t2iAdapter: T2IAdapterConfig = { - type: 't2i_adapter', - isEnabled: true, - model: matchingT2IAdapterModel, - weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, - beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct, - endStepPct: end_step_percent || initialT2IAdapter.endStepPct, - resizeMode: resize_mode || initialT2IAdapter.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return { t2iAdapter, error: null }; - }, - [model, t2iAdapterModels] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); const recallT2IAdapter = useCallback( - (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { - const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem); - - if (!result.t2iAdapter) { - parameterNotSetToast(result.error); + async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { + try { + const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base); + dispatch(controlAdapterRecalled(t2iAdapterConfig)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(controlAdapterRecalled(result.t2iAdapter)); - - parameterSetToast(); }, - [prepareT2IAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall IP Adapter with toast - */ - - const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined); - - const prepareIPAdapterMetadataItem = useCallback( - (ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) { - return { ipAdapter: null, error: 'Invalid IP Adapter model' }; - } - - const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; - - const matchingIPAdapterModel = ipAdapterModels - ? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key) - : undefined; - - if (!matchingIPAdapterModel) { - return { ipAdapter: null, error: 'IP Adapter model is not installed' }; - } - - const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - ipAdapter: null, - error: 'IP Adapter incompatible with currently-selected model', - }; - } - - const ipAdapter: IPAdapterConfig = { - id: uuidv4(), - type: 'ip_adapter', - isEnabled: true, - controlImage: image?.image_name ?? null, - model: matchingIPAdapterModel, - weight: weight ?? initialIPAdapter.weight, - beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, - endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, - }; - - return { ipAdapter, error: null }; - }, - [ipAdapterModels, model] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); const recallIPAdapter = useCallback( - (ipAdapterMetadataItem: IPAdapterMetadataItem) => { - const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem); - - if (!result.ipAdapter) { - parameterNotSetToast(result.error); + async (ipAdapterMetadataItem: IPAdapterMetadataItem) => { + try { + const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base); + dispatch(controlAdapterRecalled(ipAdapterConfig)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(controlAdapterRecalled(result.ipAdapter)); - - parameterSetToast(); }, - [prepareIPAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); - /* - * Sets image as initial image with toast - */ const sendToImageToImage = useCallback( (image: ImageDTO) => { dispatch(initialImageSelected(image)); @@ -780,7 +450,7 @@ export const useRecallParameters = () => { ); const recallAllParameters = useCallback( - (metadata: CoreMetadata | undefined) => { + async (metadata: CoreMetadata | undefined) => { if (!metadata) { allParameterNotSetToast(); return; @@ -820,10 +490,12 @@ export const useRecallParameters = () => { let newModel: ParameterModel | undefined = undefined; if (isModelIdentifier(model)) { - const result = prepareMainModelMetadataItem(model); - if (result.model) { - dispatch(modelSelected(result.model)); - newModel = result.model; + try { + const _model = await prepareMainModelMetadataItem(model); + dispatch(modelSelected(_model)); + newModel = _model; + } catch { + return; } } @@ -850,9 +522,11 @@ export const useRecallParameters = () => { if (isNil(vae)) { dispatch(vaeSelected(null)); } else { - const result = prepareVAEMetadataItem(vae, newModel); - if (result.vae) { - dispatch(vaeSelected(result.vae)); + try { + const _vae = await prepareVAEMetadataItem(vae, newModel?.base); + dispatch(vaeSelected(_vae)); + } catch { + return; } } } @@ -926,48 +600,46 @@ export const useRecallParameters = () => { } dispatch(lorasCleared()); - loras?.forEach((lora) => { - const result = prepareLoRAMetadataItem(lora, newModel); - if (result.lora) { - dispatch(loraRecalled({ ...result.lora, weight: lora.weight })); + loras?.forEach(async (loraMetadataItem) => { + try { + const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base); + dispatch(loraRecalled(lora)); + } catch { + return; } }); dispatch(controlAdaptersReset()); - controlnets?.forEach((controlnet) => { - const result = prepareControlNetMetadataItem(controlnet, newModel); - if (result.controlnet) { - dispatch(controlAdapterRecalled(result.controlnet)); + controlnets?.forEach(async (controlNetMetadataItem) => { + try { + const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base); + dispatch(controlAdapterRecalled(controlNet)); + } catch { + return; } }); - ipAdapters?.forEach((ipAdapter) => { - const result = prepareIPAdapterMetadataItem(ipAdapter, newModel); - if (result.ipAdapter) { - dispatch(controlAdapterRecalled(result.ipAdapter)); + ipAdapters?.forEach(async (ipAdapterMetadataItem) => { + try { + const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base); + dispatch(controlAdapterRecalled(ipAdapter)); + } catch { + return; } }); - t2iAdapters?.forEach((t2iAdapter) => { - const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel); - if (result.t2iAdapter) { - dispatch(controlAdapterRecalled(result.t2iAdapter)); + t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => { + try { + const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base); + dispatch(controlAdapterRecalled(t2iAdapter)); + } catch { + return; } }); allParameterSetToast(); }, - [ - dispatch, - allParameterSetToast, - allParameterNotSetToast, - prepareMainModelMetadataItem, - prepareVAEMetadataItem, - prepareLoRAMetadataItem, - prepareControlNetMetadataItem, - prepareIPAdapterMetadataItem, - prepareT2IAdapterMetadataItem, - ] + [dispatch, allParameterSetToast, allParameterNotSetToast] ); return { diff --git a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts new file mode 100644 index 0000000000..c7d25fed8b --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts @@ -0,0 +1,113 @@ +import { getStore } from 'app/store/nanostores/store'; +import { isModelIdentifier } from 'features/nodes/types/common'; +import { modelsApi } from 'services/api/endpoints/models'; +import type { AnyModelConfig, BaseModelType } from 'services/api/types'; +import { + isControlNetModelConfig, + isIPAdapterModelConfig, + isLoRAModelConfig, + isNonRefinerMainModelConfig, + isRefinerMainModelModelConfig, + isT2IAdapterModelConfig, + isTextualInversionModelConfig, + isVAEModelConfig, +} from 'services/api/types'; + +/** + * Raised when a model config is unable to be fetched. + */ +export class ModelConfigNotFoundError extends Error { + /** + * Create ModelConfigNotFoundError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Raised when a fetched model config is of an unexpected type. + */ +export class InvalidModelConfigError extends Error { + /** + * Create InvalidModelConfigError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +export const fetchModelConfig = async (key: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key)); + req.unsubscribe(); + return await req.unwrap(); + } catch { + throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`); + } +}; + +export const fetchModelConfigWithTypeGuard = async ( + key: string, + typeGuard: (config: AnyModelConfig) => config is T +) => { + const modelConfig = await fetchModelConfig(key); + if (!typeGuard(modelConfig)) { + throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`); + } + return modelConfig; +}; + +export const fetchMainModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); +}; + +export const fetchRefinerModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); +}; + +export const fetchVAEModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isVAEModelConfig); +}; + +export const fetchLoRAModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); +}; + +export const fetchControlNetModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig); +}; + +export const fetchIPAdapterModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig); +}; + +export const fetchT2IAdapterModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig); +}; + +export const fetchTextualInversionModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig); +}; + +export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => { + return sourceBase === targetBase; +}; + +export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => { + if (targetBase && !isBaseCompatible(sourceBase, targetBase)) { + throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`); + } +}; + +export const getModelKey = (modelIdentifier: unknown, message?: string): string => { + if (!isModelIdentifier(modelIdentifier)) { + throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); + } + return modelIdentifier.key; +}; diff --git a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts new file mode 100644 index 0000000000..722073366f --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts @@ -0,0 +1,150 @@ +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { + initialControlNet, + initialIPAdapter, + initialT2IAdapter, +} from 'features/controlAdapters/util/buildControlAdapter'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { zModelIdentifierWithBase } from 'features/nodes/types/common'; +import type { + ControlNetMetadataItem, + IPAdapterMetadataItem, + LoRAMetadataItem, + T2IAdapterMetadataItem, +} from 'features/nodes/types/metadata'; +import { + fetchControlNetModel, + fetchIPAdapterModel, + fetchLoRAModel, + fetchMainModel, + fetchRefinerModel, + fetchT2IAdapterModel, + fetchVAEModel, + getModelKey, + raiseIfBaseIncompatible, +} from 'features/parameters/util/modelFetchingHelpers'; +import type { BaseModelType } from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; + +export const prepareMainModelMetadataItem = async (model: unknown): Promise => { + const key = getModelKey(model); + const mainModel = await fetchMainModel(key); + return zModelIdentifierWithBase.parse(mainModel); +}; + +export const prepareRefinerMetadataItem = async (model: unknown): Promise => { + const key = getModelKey(model); + const refinerModel = await fetchRefinerModel(key); + return zModelIdentifierWithBase.parse(refinerModel); +}; + +export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise => { + const key = getModelKey(vae); + const vaeModel = await fetchVAEModel(key); + raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model'); + return zModelIdentifierWithBase.parse(vaeModel); +}; + +export const prepareLoRAMetadataItem = async ( + loraMetadataItem: LoRAMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(loraMetadataItem.lora); + const loraModel = await fetchLoRAModel(key); + raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model'); + return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true }; +}; + +export const prepareControlNetMetadataItem = async ( + controlnetMetadataItem: ControlNetMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(controlnetMetadataItem.control_model); + const controlNetModel = await fetchControlNetModel(key); + raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model'); + + const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } = + controlnetMetadataItem; + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const controlnet: ControlNetConfig = { + type: 'controlnet', + isEnabled: true, + model: zModelIdentifierWithBase.parse(controlNetModel), + weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, + beginStepPct: begin_step_percent || initialControlNet.beginStepPct, + endStepPct: end_step_percent || initialControlNet.endStepPct, + controlMode: control_mode || initialControlNet.controlMode, + resizeMode: resize_mode || initialControlNet.resizeMode, + controlImage: image?.image_name || null, + processedControlImage: image?.image_name || null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return controlnet; +}; + +export const prepareT2IAdapterMetadataItem = async ( + t2iAdapterMetadataItem: T2IAdapterMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model); + const t2iAdapterModel = await fetchT2IAdapterModel(key); + raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); + + const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem; + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const t2iAdapter: T2IAdapterConfig = { + type: 't2i_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(t2iAdapterModel), + weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, + beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct, + endStepPct: end_step_percent || initialT2IAdapter.endStepPct, + resizeMode: resize_mode || initialT2IAdapter.resizeMode, + controlImage: image?.image_name || null, + processedControlImage: image?.image_name || null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return t2iAdapter; +}; + +export const prepareIPAdapterMetadataItem = async ( + ipAdapterMetadataItem: IPAdapterMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model); + const ipAdapterModel = await fetchIPAdapterModel(key); + raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); + + const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; + + const ipAdapter: IPAdapterConfig = { + id: uuidv4(), + type: 'ip_adapter', + isEnabled: true, + controlImage: image?.image_name ?? null, + model: zModelIdentifierWithBase.parse(ipAdapterModel), + weight: weight ?? initialIPAdapter.weight, + beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, + }; + + return ipAdapter; +}; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 666e0c707d..2bd1a0a246 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema'; import type { AnyModelConfig, BaseModelType, - ControlNetConfig, + ControlNetModelConfig, ImportModelConfig, - IPAdapterConfig, - LoRAConfig, + IPAdapterModelConfig, + LoRAModelConfig, MainModelConfig, MergeModelConfig, ModelType, - T2IAdapterConfig, - TextualInversionConfig, - VAEConfig, + T2IAdapterModelConfig, + TextualInversionModelConfig, + VAEModelConfig, } from 'services/api/types'; import type { ApiTagDescription, tagTypes } from '..'; @@ -30,7 +30,7 @@ type UpdateMainModelArg = { type UpdateLoRAModelArg = { base_model: BaseModelType; model_name: string; - body: LoRAConfig; + body: LoRAModelConfig; }; type UpdateMainModelResponse = @@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const loraModelsAdapter = createEntityAdapter({ +export const loraModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const controlNetModelsAdapter = createEntityAdapter({ +export const controlNetModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const ipAdapterModelsAdapter = createEntityAdapter({ +export const ipAdapterModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const t2iAdapterModelsAdapter = createEntityAdapter({ +export const t2iAdapterModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const textualInversionModelsAdapter = createEntityAdapter({ +export const textualInversionModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap undefined, getSelectorsOptions ); -export const vaeModelsAdapter = createEntityAdapter({ +export const vaeModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -162,6 +162,8 @@ const buildTransformResponse = */ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); +// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ getMainModels: build.query, BaseModelType[]>({ @@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), - getLoRAModels: build.query, void>({ + getLoRAModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), - providesTags: buildProvidesTags('LoRAModel'), - transformResponse: buildTransformResponse(loraModelsAdapter), + providesTags: buildProvidesTags('LoRAModel'), + transformResponse: buildTransformResponse(loraModelsAdapter), }), updateLoRAModels: build.mutation({ query: ({ base_model, model_name, body }) => { @@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], }), - getControlNetModels: build.query, void>({ + getControlNetModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), - providesTags: buildProvidesTags('ControlNetModel'), - transformResponse: buildTransformResponse(controlNetModelsAdapter), + providesTags: buildProvidesTags('ControlNetModel'), + transformResponse: buildTransformResponse(controlNetModelsAdapter), }), - getIPAdapterModels: build.query, void>({ + getIPAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), - providesTags: buildProvidesTags('IPAdapterModel'), - transformResponse: buildTransformResponse(ipAdapterModelsAdapter), + providesTags: buildProvidesTags('IPAdapterModel'), + transformResponse: buildTransformResponse(ipAdapterModelsAdapter), }), - getT2IAdapterModels: build.query, void>({ + getT2IAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), - providesTags: buildProvidesTags('T2IAdapterModel'), - transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), + providesTags: buildProvidesTags('T2IAdapterModel'), + transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), }), - getVaeModels: build.query, void>({ + getVaeModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), - providesTags: buildProvidesTags('VaeModel'), - transformResponse: buildTransformResponse(vaeModelsAdapter), + providesTags: buildProvidesTags('VaeModel'), + transformResponse: buildTransformResponse(vaeModelsAdapter), }), - getTextualInversionModels: build.query, void>({ + getTextualInversionModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), - providesTags: buildProvidesTags('TextualInversionModel'), - transformResponse: buildTransformResponse(textualInversionModelsAdapter), + providesTags: buildProvidesTags('TextualInversionModel'), + transformResponse: buildTransformResponse(textualInversionModelsAdapter), }), getModelsInFolder: build.query({ query: (arg) => {