From dab939f7d17a139a6652c0c525af61d5a900ef4c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:56:02 +1100 Subject: [PATCH] feat(ui): update model identifier to be key (wip) - Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet. - Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure... --- .../frontend/web/.storybook/ReduxInit.tsx | 8 +- .../listeners/enqueueRequestedLinear.ts | 2 +- .../listeners/modelSelected.ts | 10 +- .../listeners/modelsLoaded.ts | 43 +-- .../common/hooks/useGroupedModelCombobox.ts | 6 +- .../src/common/hooks/useIsReadyToEnqueue.ts | 2 +- .../web/src/common/hooks/useModelCombobox.ts | 6 +- .../src/features/canvas/store/canvasSlice.ts | 2 +- .../parameters/ParamControlAdapterModel.tsx | 17 +- .../hooks/useAddControlAdapter.ts | 4 +- .../store/controlAdaptersSlice.ts | 6 +- .../features/embedding/EmbeddingSelect.tsx | 8 +- .../ImageMetadataActions.tsx | 14 +- .../src/features/lora/components/LoRACard.tsx | 2 +- .../src/features/lora/components/LoRAList.tsx | 2 +- .../features/lora/components/LoRASelect.tsx | 6 +- .../web/src/features/lora/store/loraSlice.ts | 35 ++- .../subpanels/ModelManagerPanel.tsx | 10 +- .../ModelManagerPanel/CheckpointModelEdit.tsx | 4 +- .../ModelManagerPanel/DiffusersModelEdit.tsx | 4 +- .../ModelManagerPanel/LoRAModelEdit.tsx | 14 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 6 +- .../ModelManagerPanel/ModelListItem.tsx | 4 +- .../ControlNetModelFieldInputComponent.tsx | 4 +- .../IPAdapterModelFieldInputComponent.tsx | 4 +- .../inputs/LoRAModelFieldInputComponent.tsx | 4 +- .../inputs/MainModelFieldInputComponent.tsx | 4 +- .../RefinerModelFieldInputComponent.tsx | 4 +- .../SDXLMainModelFieldInputComponent.tsx | 4 +- .../T2IAdapterModelFieldInputComponent.tsx | 4 +- .../inputs/VAEModelFieldInputComponent.tsx | 4 +- .../web/src/features/nodes/types/common.ts | 16 +- .../util/graph/addControlNetToLinearGraph.ts | 2 +- .../util/graph/addIPAdapterToLinearGraph.ts | 2 +- .../nodes/util/graph/addLoRAsToGraph.ts | 9 +- .../nodes/util/graph/addSDXLLoRAstoGraph.ts | 9 +- .../util/graph/addT2IAdapterToLinearGraph.ts | 2 +- .../nodes/util/graph/buildCanvasGraph.ts | 8 +- .../util/graph/buildLinearBatchConfig.ts | 2 +- .../components/Advanced/ParamClipSkip.tsx | 6 +- .../components/Core/ParamPositivePrompt.tsx | 2 +- .../MainModel/ParamMainModelSelect.tsx | 4 +- .../VAEModel/ParamVAEModelSelect.tsx | 6 +- .../parameters/hooks/useRecallParameters.ts | 29 +- .../parameters/store/generationSlice.ts | 6 +- .../parameters/types/parameterSchemas.ts | 15 +- .../parameters/util/optimalDimension.ts | 6 +- .../ParamSDXLRefinerModelSelect.tsx | 6 +- .../AdvancedSettingsAccordion.tsx | 3 +- .../GenerationSettingsAccordion.tsx | 5 +- .../ImageSettingsAccordion.tsx | 2 +- .../ui/components/ParametersPanel.tsx | 2 +- .../web/src/services/api/endpoints/models.ts | 284 +++++------------- .../frontend/web/src/services/api/types.ts | 47 ++- 54 files changed, 267 insertions(+), 453 deletions(-) diff --git a/invokeai/frontend/web/.storybook/ReduxInit.tsx b/invokeai/frontend/web/.storybook/ReduxInit.tsx index 55d0132242..7d3f8e0d2b 100644 --- a/invokeai/frontend/web/.storybook/ReduxInit.tsx +++ b/invokeai/frontend/web/.storybook/ReduxInit.tsx @@ -10,13 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => { const dispatch = useAppDispatch(); useGlobalModifiersInit(); useEffect(() => { - dispatch( - modelChanged({ - model_name: 'test_model', - base_model: 'sd-1', - model_type: 'main', - }) - ); + dispatch(modelChanged({ key: 'test_model', base: 'sd-1' })); }, []); return props.children; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index e1e13fadbe..d1cb692c98 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = () => { let graph; - if (model && model.base_model === 'sdxl') { + if (model && model.base === 'sdxl') { if (action.payload.tabName === 'txt2img') { graph = buildLinearSDXLTextToImageGraph(state); } else { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 7638c5522a..35e2ad5f9b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -30,8 +30,8 @@ export const addModelSelectedListener = () => { const newModel = result.data; - const newBaseModel = newModel.base_model; - const didBaseModelChange = state.generation.model?.base_model !== newBaseModel; + const newBaseModel = newModel.base; + const didBaseModelChange = state.generation.model?.base !== newBaseModel; if (didBaseModelChange) { // we may need to reset some incompatible submodels @@ -39,7 +39,7 @@ export const addModelSelectedListener = () => { // handle incompatible loras forEach(state.lora.loras, (lora, id) => { - if (lora.base_model !== newBaseModel) { + if (lora.base !== newBaseModel) { dispatch(loraRemoved(id)); modelsCleared += 1; } @@ -47,14 +47,14 @@ export const addModelSelectedListener = () => { // handle incompatible vae const { vae } = state.generation; - if (vae && vae.base_model !== newBaseModel) { + if (vae && vae.base !== newBaseModel) { dispatch(vaeSelected(null)); modelsCleared += 1; } // handle incompatible controlnets selectControlAdapterAll(state.controlAdapters).forEach((ca) => { - if (ca.model?.base_model !== newBaseModel) { + if (ca.model?.base !== newBaseModel) { dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false })); modelsCleared += 1; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 0ffe88cd07..366644fa68 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -34,14 +34,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentModelAvailable = currentModel - ? models.some( - (m) => - m.model_name === currentModel.model_name && - m.base_model === currentModel.base_model && - m.model_type === currentModel.model_type - ) - : false; + const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; if (isCurrentModelAvailable) { return; @@ -74,14 +67,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentModelAvailable = currentModel - ? models.some( - (m) => - m.model_name === currentModel.model_name && - m.base_model === currentModel.base_model && - m.model_type === currentModel.model_type - ) - : false; + const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; if (!isCurrentModelAvailable) { dispatch(refinerModelChanged(null)); @@ -103,10 +89,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentVAEAvailable = some( - action.payload.entities, - (m) => m?.model_name === currentVae?.model_name && m?.base_model === currentVae?.base_model - ); + const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key); if (isCurrentVAEAvailable) { return; @@ -140,10 +123,7 @@ export const addModelsLoadedListener = () => { const loras = getState().lora.loras; forEach(loras, (lora, id) => { - const isLoRAAvailable = some( - action.payload.entities, - (m) => m?.model_name === lora?.model_name && m?.base_model === lora?.base_model - ); + const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.key); if (isLoRAAvailable) { return; @@ -161,10 +141,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`); selectAllControlNets(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; @@ -182,10 +159,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`); selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; @@ -203,10 +177,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`); selectAllIPAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index eb55db79ca..875ce1f1c4 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -5,10 +5,10 @@ import type { GroupBase } from 'chakra-react-select'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; -type UseGroupedModelComboboxArg = { +type UseGroupedModelComboboxArg = { modelEntities: EntityState | undefined; selectedModel?: Pick | null; onChange: (value: T | null) => void; @@ -24,7 +24,7 @@ type UseGroupedModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useGroupedModelCombobox = ( +export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index baa704e75c..b31efed970 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -105,7 +105,7 @@ const selector = createMemoizedSelector( number: i + 1, }) ); - } else if (ca.model.base_model !== model?.base_model) { + } else if (ca.model.base !== model?.base) { // This should never happen, just a sanity check reasons.push( i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 880b316379..341fed1e47 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -3,10 +3,10 @@ import type { EntityState } from '@reduxjs/toolkit'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; -type UseModelComboboxArg = { +type UseModelComboboxArg = { modelEntities: EntityState | undefined; selectedModel?: Pick | null; onChange: (value: T | null) => void; @@ -23,7 +23,7 @@ type UseModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useModelCombobox = ( +export const useModelCombobox = ( arg: UseModelComboboxArg ): UseModelComboboxReturn => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index cd734d3f00..f50d52c1bf 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -626,7 +626,7 @@ export const canvasSlice = createSlice({ }, extraReducers: (builder) => { builder.addCase(modelChanged, (state, action) => { - if (action.meta.previousModel?.base_model === action.payload?.base_model) { + if (action.meta.previousModel?.base === action.payload?.base) { // The base model hasn't changed, we don't need to optimize the size return; } diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index 13851b143c..a320238445 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -11,12 +11,7 @@ import { selectGenerationSlice } from 'features/parameters/store/generationSlice import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { - ControlNetModelConfigEntity, - IPAdapterModelConfigEntity, - T2IAdapterModelConfigEntity, -} from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; +import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; type ParamControlAdapterModelProps = { id: string; @@ -29,21 +24,21 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const controlAdapterType = useControlAdapterType(id); const model = useControlAdapterModel(id); const dispatch = useAppDispatch(); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const mainModel = useAppSelector(selectMainModel); const { t } = useTranslation(); const models = useControlAdapterModelEntities(controlAdapterType); const _onChange = useCallback( - (model: ControlNetModelConfigEntity | IPAdapterModelConfigEntity | T2IAdapterModelConfigEntity | null) => { + (model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { if (!model) { return; } dispatch( controlAdapterModelChanged({ id, - model: pick(model, 'base_model', 'model_name'), + model: pick(model, 'base', 'key'), }) ); }, @@ -57,7 +52,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const getIsDisabled = useCallback( (model: AnyModelConfig): boolean => { - const isCompatible = currentBaseModel === model.base_model; + const isCompatible = currentBaseModel === model.base; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; }, @@ -73,7 +68,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { return ( - + { - const baseModel = useAppSelector((s) => s.generation.model?.base_model); + const baseModel = useAppSelector((s) => s.generation.model?.base); const dispatch = useAppDispatch(); const models = useControlAdapterModels(type); const firstModel = useMemo(() => { // prefer to use a model that matches the base model - const firstCompatibleModel = models.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]; + const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0]; if (firstCompatibleModel) { return firstCompatibleModel; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index 49b07f16a1..fce94ad019 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -236,7 +236,8 @@ export const controlAdaptersSlice = createSlice({ let processorType: ControlAdapterProcessorType | undefined = undefined; for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - if (model.model_name.includes(modelSubstring)) { + // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType + if (model.key.includes(modelSubstring)) { processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } @@ -359,7 +360,8 @@ export const controlAdaptersSlice = createSlice({ let processorType: ControlAdapterProcessorType | undefined = undefined; for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - if (cn.model?.model_name.includes(modelSubstring)) { + // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType + if (cn.model?.key.includes(modelSubstring)) { processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } diff --git a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx index ffe9d63360..426ddd21e2 100644 --- a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx +++ b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx @@ -6,18 +6,18 @@ import type { EmbeddingSelectProps } from 'features/embedding/types'; import { t } from 'i18next'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { TextualInversionModelConfigEntity } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; +import type { TextualInversionConfig } from 'services/api/types'; const noOptionsMessage = () => t('embedding.noMatchingEmbedding'); export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => { const { t } = useTranslation(); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = useCallback( - (embedding: TextualInversionModelConfigEntity): boolean => { + (embedding: TextualInversionConfig): boolean => { const isCompatible = currentBaseModel === embedding.base_model; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; @@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const { data, isLoading } = useGetTextualInversionModelsQuery(); const _onChange = useCallback( - (embedding: TextualInversionModelConfigEntity | null) => { + (embedding: TextualInversionConfig | null) => { if (!embedding) { return; } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index e9a1461186..5907ba0700 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -208,8 +208,8 @@ const ImageMetadataActions = (props: Props) => { {metadata.seed !== undefined && metadata.seed !== null && ( )} - {metadata.model !== undefined && metadata.model !== null && metadata.model.model_name && ( - + {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( + )} {metadata.width && ( @@ -222,7 +222,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.steps && ( @@ -269,7 +269,7 @@ const ImageMetadataActions = (props: Props) => { ); @@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => { ))} @@ -287,7 +287,7 @@ const ImageMetadataActions = (props: Props) => { ))} @@ -295,7 +295,7 @@ const ImageMetadataActions = (props: Props) => { ))} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 28bd8afe95..81e0027b2d 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -44,7 +44,7 @@ export const LoRACard = memo((props: LoRACardProps) => { - {lora.model_name} + {lora.key} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx index 9f37454d16..7bcd537805 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx @@ -18,7 +18,7 @@ export const LoRAList = memo(() => { return ( {lorasArray.map((lora) => ( - + ))} ); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index ed70a4d44a..069c557aef 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -7,7 +7,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -19,7 +19,7 @@ const LoRASelect = () => { const addedLoRAs = useAppSelector(selectAddedLoRAs); const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); - const getIsDisabled = (lora: LoRAModelConfigEntity): boolean => { + const getIsDisabled = (lora: LoRAConfig): boolean => { const isCompatible = currentBaseModel === lora.base_model; const isAdded = Boolean(addedLoRAs[lora.id]); const hasMainModel = Boolean(currentBaseModel); @@ -27,7 +27,7 @@ const LoRASelect = () => { }; const _onChange = useCallback( - (lora: LoRAModelConfigEntity | null) => { + (lora: LoRAConfig | null) => { if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index ab1b140a7c..dd455e12c3 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,10 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/types'; export type LoRA = ParameterLoRAModel & { - id: string; weight: number; isEnabled?: boolean; }; @@ -29,40 +28,40 @@ export const loraSlice = createSlice({ name: 'lora', initialState: initialLoraState, reducers: { - loraAdded: (state, action: PayloadAction) => { - const { model_name, id, base_model } = action.payload; - state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; + loraAdded: (state, action: PayloadAction) => { + const { key, base } = action.payload; + state.loras[key] = { key, base, ...defaultLoRAConfig }; }, - loraRecalled: (state, action: PayloadAction) => { - const { model_name, id, base_model, weight } = action.payload; - state.loras[id] = { id, model_name, base_model, weight, isEnabled: true }; + loraRecalled: (state, action: PayloadAction) => { + const { key, base, weight } = action.payload; + state.loras[key] = { key, base, weight, isEnabled: true }; }, loraRemoved: (state, action: PayloadAction) => { - const id = action.payload; - delete state.loras[id]; + const key = action.payload; + delete state.loras[key]; }, lorasCleared: (state) => { state.loras = {}; }, - loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { - const { id, weight } = action.payload; - const lora = state.loras[id]; + loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => { + const { key, weight } = action.payload; + const lora = state.loras[key]; if (!lora) { return; } lora.weight = weight; }, loraWeightReset: (state, action: PayloadAction) => { - const id = action.payload; - const lora = state.loras[id]; + const key = action.payload; + const lora = state.loras[key]; if (!lora) { return; } lora.weight = defaultLoRAConfig.weight; }, - loraIsEnabledChanged: (state, action: PayloadAction>) => { - const { id, isEnabled } = action.payload; - const lora = state.loras[id]; + loraIsEnabledChanged: (state, action: PayloadAction>) => { + const { key, isEnabled } = action.payload; + const lora = state.loras[key]; if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx index 6b9abdbfec..7501151ba4 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx @@ -3,9 +3,9 @@ import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; import type { - DiffusersModelConfigEntity, - LoRAModelConfigEntity, - MainModelConfigEntity, + DiffusersModelConfig, + LoRAConfig, + MainModelConfig, } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; @@ -38,7 +38,7 @@ const ModelManagerPanel = () => { }; type ModelEditProps = { - model: MainModelConfigEntity | LoRAModelConfigEntity | undefined; + model: MainModelConfig | LoRAConfig | undefined; }; const ModelEdit = (props: ModelEditProps) => { @@ -50,7 +50,7 @@ const ModelEdit = (props: ModelEditProps) => { } if (model?.model_format === 'diffusers') { - return ; + return ; } if (model?.model_type === 'lora') { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index f4d271187d..43707308e0 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -21,14 +21,14 @@ import { memo, useCallback, useEffect, useState } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models'; +import type { CheckpointModelConfig } from 'services/api/endpoints/models'; import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { CheckpointModelConfig } from 'services/api/types'; import ModelConvert from './ModelConvert'; type CheckpointModelEditProps = { - model: CheckpointModelConfigEntity; + model: CheckpointModelConfig; }; const CheckpointModelEdit = (props: CheckpointModelEditProps) => { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 4670f32157..bf6349234f 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -9,12 +9,12 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models'; +import type { DiffusersModelConfig } from 'services/api/endpoints/models'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { DiffusersModelConfig } from 'services/api/types'; type DiffusersModelEditProps = { - model: DiffusersModelConfigEntity; + model: DiffusersModelConfig; }; const DiffusersModelEdit = (props: DiffusersModelEditProps) => { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index 2baf735bee..75151cd001 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,12 +8,12 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; -import type { LoRAModelConfig } from 'services/api/types'; +import type { LoRAConfig } from 'services/api/types'; type LoRAModelEditProps = { - model: LoRAModelConfigEntity; + model: LoRAConfig; }; const LoRAModelEdit = (props: LoRAModelEditProps) => { @@ -30,7 +30,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { control, formState: { errors }, reset, - } = useForm({ + } = useForm({ defaultValues: { model_name: model.model_name ? model.model_name : '', base_model: model.base_model, @@ -42,7 +42,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { mode: 'onChange', }); - const onSubmit = useCallback>( + const onSubmit = useCallback>( (values) => { const responseBody = { base_model: model.base_model, @@ -53,7 +53,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { updateLoRAModel(responseBody) .unwrap() .then((payload) => { - reset(payload as LoRAModelConfig, { keepDefaultValues: true }); + reset(payload as LoRAConfig, { keepDefaultValues: true }); dispatch( addToast( makeToast({ @@ -106,7 +106,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base_model" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx index 94db3d20c3..dd74bb0c23 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -5,7 +5,7 @@ import type { ChangeEvent, PropsWithChildren } from 'react'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; -import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; @@ -127,7 +127,7 @@ const ModelList = (props: ModelListProps) => { export default memo(ModelList); -const modelsFilter = ( +const modelsFilter = ( data: EntityState | undefined, model_type: ModelType, model_format: ModelFormat | undefined, @@ -163,7 +163,7 @@ StyledModelContainer.displayName = 'StyledModelContainer'; type ModelListWrapperProps = { title: string; - modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]; + modelList: MainModelConfig[] | LoRAConfig[]; selected: ModelListProps; }; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index fdd13e09f5..835499d25a 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -15,11 +15,11 @@ import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; -import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models'; type ModelListItemProps = { - model: MainModelConfigEntity | LoRAModelConfigEntity; + model: MainModelConfig | LoRAConfig; isSelected: boolean; setSelectedModelId: (v: string | undefined) => void; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 22024f3d3c..53d800e7b6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { ControlNetModelConfigEntity } from 'services/api/endpoints/models'; +import type { ControlNetConfig } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { const { data, isLoading } = useGetControlNetModelsQuery(); const _onChange = useCallback( - (value: ControlNetModelConfigEntity | null) => { + (value: ControlNetConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 2cde347247..3f195ceb32 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { IPAdapterModelConfigEntity } from 'services/api/endpoints/models'; +import type { IPAdapterConfig } from 'services/api/endpoints/models'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const IPAdapterModelFieldInputComponent = ( const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const _onChange = useCallback( - (value: IPAdapterModelConfigEntity | null) => { + (value: IPAdapterConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index 96208d68d4..eeb07fa08e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -16,7 +16,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetLoRAModelsQuery(); const _onChange = useCallback( - (value: LoRAModelConfigEntity | null) => { + (value: LoRAConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index 64c6970cae..7ddde08816 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { NON_SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const MainModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index 98901af38b..9b5a1138d4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -9,7 +9,7 @@ import type { } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -21,7 +21,7 @@ const RefinerModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index f5bc7ac3e4..cf353619e8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 9baf0d2d61..8402c56343 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { T2IAdapterModelConfigEntity } from 'services/api/endpoints/models'; +import type { T2IAdapterConfig } from 'services/api/endpoints/models'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const T2IAdapterModelFieldInputComponent = ( const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const _onChange = useCallback( - (value: T2IAdapterModelConfigEntity | null) => { + (value: T2IAdapterConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 070178f32a..af09f2d8f2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { VaeModelConfigEntity } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const VAEModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetVaeModelsQuery(); const _onChange = useCallback( - (value: VaeModelConfigEntity | null) => { + (value: VAEConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index ef579fce8c..891bd29bc8 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -67,11 +67,13 @@ export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ key: z.string().min(1), }); +export const zModelFieldBase = zModelIdentifier; +export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; export type ModelType = z.infer; export type ModelIdentifier = z.infer; - -export const zMainModelField = zModelIdentifier; +export type ModelIdentifierWithBase = z.infer; +export const zMainModelField = zModelFieldBase; export type MainModelField = z.infer; export const zSDXLRefinerModelField = zModelIdentifier; @@ -91,23 +93,23 @@ export const zSubModelType = z.enum([ ]); export type SubModelType = z.infer; -export const zVAEModelField = zModelIdentifier; +export const zVAEModelField = zModelFieldBase; export const zModelInfo = zModelIdentifier.extend({ submodel_type: zSubModelType.nullish(), }); export type ModelInfo = z.infer; -export const zLoRAModelField = zModelIdentifier; +export const zLoRAModelField = zModelFieldBase; export type LoRAModelField = z.infer; -export const zControlNetModelField = zModelIdentifier; +export const zControlNetModelField = zModelFieldBase; export type ControlNetModelField = z.infer; -export const zIPAdapterModelField = zModelIdentifier; +export const zIPAdapterModelField = zModelFieldBase; export type IPAdapterModelField = z.infer; -export const zT2IAdapterModelField = zModelIdentifier; +export const zT2IAdapterModelField = zModelFieldBase; export type T2IAdapterModelField = z.infer; export const zLoraInfo = zModelInfo.extend({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts index d862b0986e..1853d3722c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validControlNets = selectValidControlNets(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); // const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts index 3a79b78c6e..b51ac1bd52 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); if (validIPAdapters.length) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts index 3ed71b7529..95bba9b441 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts @@ -28,6 +28,7 @@ export const addLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ + // TODO(MM2): check base model const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); const loraCount = size(enabledLoRAs); @@ -48,19 +49,19 @@ export const addLoRAsToGraph = ( const loraMetadata: CoreMetadataInvocation['loras'] = []; enabledLoRAs.forEach((lora) => { - const { model_name, base_model, weight } = lora; - const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + const { key, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { model_name, base_model }, + lora: { key }, weight, }; loraMetadata.push({ - lora: { model_name, base_model }, + lora: { key }, weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts index 9553568922..7874b059c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts @@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ + // TODO(MM2): check base model const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); const loraCount = size(enabledLoRAs); @@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = ( let currentLoraIndex = 0; enabledLoRAs.forEach((lora) => { - const { model_name, base_model, weight } = lora; - const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + const { key, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; const loraLoaderNode: SDXLLoraLoaderInvocation = { type: 'sdxl_lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { model_name, base_model }, + lora: { key }, weight, }; loraMetadata.push( zLoRAMetadataItem.parse({ - lora: { model_name, base_model }, + lora: { key }, weight, }) ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts index d35f72a2b4..84002337d7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); if (validT2IAdapters.length) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts index 2b64f4898b..4ce2e4d673 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts @@ -19,7 +19,7 @@ export const buildCanvasGraph = ( let graph: NonNullableGraph; if (generationMode === 'txt2img') { - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLTextToImageGraph(state); } else { graph = buildCanvasTextToImageGraph(state); @@ -28,7 +28,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage); } else { graph = buildCanvasImageToImageGraph(state, canvasInitImage); @@ -37,7 +37,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage || !canvasMaskImage) { throw new Error('Missing canvas init and mask images'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage); } else { graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); @@ -46,7 +46,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage); } else { graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts index d0e331fb46..9fcc6afaa0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts @@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, }); } - if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') { + if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') { if (graph.nodes[POSITIVE_CONDITIONING]) { firstBatchDatumList.push({ node_path: POSITIVE_CONDITIONING, diff --git a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx index 621ed56ef6..c23d541613 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx @@ -29,17 +29,17 @@ const ParamClipSkip = () => { if (!model) { return CLIP_SKIP_MAP['sd-1'].maxClip; } - return CLIP_SKIP_MAP[model.base_model].maxClip; + return CLIP_SKIP_MAP[model.base].maxClip; }, [model]); const sliderMarks = useMemo(() => { if (!model) { return CLIP_SKIP_MAP['sd-1'].markers; } - return CLIP_SKIP_MAP[model.base_model].markers; + return CLIP_SKIP_MAP[model.base].markers; }, [model]); - if (model?.base_model === 'sdxl') { + if (model?.base === 'sdxl') { return null; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index ae81f78fd1..a1852bfafe 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next'; export const ParamPositivePrompt = memo(() => { const dispatch = useAppDispatch(); const prompt = useAppSelector((s) => s.generation.positivePrompt); - const baseModel = useAppSelector((s) => s.generation.model)?.base_model; + const baseModel = useAppSelector((s) => s.generation.model)?.base; const textareaRef = useRef(null); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index c6c77b5fe9..18f780bdee 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -9,7 +9,7 @@ import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); @@ -26,7 +26,7 @@ const ParamMainModelSelect = () => { return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description; }, [data, model]); const _onChange = useCallback( - (model: MainModelConfigEntity | null) => { + (model: MainModelConfig | null) => { if (!model) { return; } diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index f290378aa8..cc0164153d 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -7,7 +7,7 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { VaeModelConfigEntity } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { @@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => { const { model, vae } = useAppSelector(selector); const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( - (vae: VaeModelConfigEntity): boolean => { + (vae: VAEConfig): boolean => { const isCompatible = model?.base_model === vae.base_model; const hasMainModel = Boolean(model?.base_model); return !hasMainModel || !isCompatible; @@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => { [model?.base_model] ); const _onChange = useCallback( - (vae: VaeModelConfigEntity | null) => { + (vae: VAEConfig | null) => { dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null)); }, [dispatch] diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 5a9fa6c66d..c8b17816bb 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -464,17 +464,15 @@ export const useRecallParameters = () => { return { lora: null, error: 'Invalid LoRA model' }; } - const { base_model, model_name } = loraMetadataItem.lora; + const { lora } = loraMetadataItem; - const matchingLoRA = loraModels - ? loraModelsAdapterSelectors.selectById(loraModels, `${base_model}/lora/${model_name}`) - : undefined; + const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined; if (!matchingLoRA) { return { lora: null, error: 'LoRA model is not installed' }; } - const isCompatibleBaseModel = matchingLoRA?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -520,17 +518,14 @@ export const useRecallParameters = () => { controlnetMetadataItem; const matchingControlNetModel = controlNetModels - ? controlNetModelsAdapterSelectors.selectById( - controlNetModels, - `${control_model.base_model}/controlnet/${control_model.model_name}` - ) + ? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key) : undefined; if (!matchingControlNetModel) { return { controlnet: null, error: 'ControlNet model is not installed' }; } - const isCompatibleBaseModel = matchingControlNetModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -597,17 +592,14 @@ export const useRecallParameters = () => { t2iAdapterMetadataItem; const matchingT2IAdapterModel = t2iAdapterModels - ? t2iAdapterModelsAdapterSelectors.selectById( - t2iAdapterModels, - `${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}` - ) + ? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key) : undefined; if (!matchingT2IAdapterModel) { return { controlnet: null, error: 'ControlNet model is not installed' }; } - const isCompatibleBaseModel = matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -672,17 +664,14 @@ export const useRecallParameters = () => { const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; const matchingIPAdapterModel = ipAdapterModels - ? ipAdapterModelsAdapterSelectors.selectById( - ipAdapterModels, - `${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}` - ) + ? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key) : undefined; if (!matchingIPAdapterModel) { return { ipAdapter: null, error: 'IP Adapter model is not installed' }; } - const isCompatibleBaseModel = matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index df98943cd3..1666a34d6a 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -158,15 +158,15 @@ export const generationSlice = createSlice({ // Clamp ClipSkip Based On Selected Model // TODO(psyche): remove this special handling when https://github.com/invoke-ai/InvokeAI/issues/4583 is resolved // WIP PR here: https://github.com/invoke-ai/InvokeAI/pull/4624 - if (newModel.base_model === 'sdxl') { + if (newModel.base === 'sdxl') { // We don't support clip skip for SDXL yet - it's not in the graphs state.clipSkip = 0; } else { - const { maxClip } = CLIP_SKIP_MAP[newModel.base_model]; + const { maxClip } = CLIP_SKIP_MAP[newModel.base]; state.clipSkip = clamp(state.clipSkip, 0, maxClip); } - if (action.meta.previousModel?.base_model === newModel.base_model) { + if (action.meta.previousModel?.base === newModel.base) { // The base model hasn't changed, we don't need to optimize the size return; } diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 7a5efe9dcf..abd8ee2810 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,5 +1,6 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { + zBaseModel, zControlNetModelField, zIPAdapterModelField, zLoRAModelField, @@ -104,48 +105,48 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati // #endregion // #region Model -export const zParameterModel = zMainModelField; +export const zParameterModel = zMainModelField.extend({ base: zBaseModel }); export type ParameterModel = z.infer; export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success; // #endregion // #region SDXL Refiner Model -export const zParameterSDXLRefinerModel = zSDXLRefinerModelField; +export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel }); export type ParameterSDXLRefinerModel = z.infer; export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel => zParameterSDXLRefinerModel.safeParse(val).success; // #endregion // #region VAE Model -export const zParameterVAEModel = zVAEModelField; +export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel }); export type ParameterVAEModel = z.infer; export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => zParameterVAEModel.safeParse(val).success; // #endregion // #region LoRA Model -export const zParameterLoRAModel = zLoRAModelField; +export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel }); export type ParameterLoRAModel = z.infer; export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => zParameterLoRAModel.safeParse(val).success; // #endregion // #region ControlNet Model -export const zParameterControlNetModel = zControlNetModelField; +export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel }); export type ParameterControlNetModel = z.infer; export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel => zParameterControlNetModel.safeParse(val).success; // #endregion // #region IP Adapter Model -export const zParameterIPAdapterModel = zIPAdapterModelField; +export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel }); export type ParameterIPAdapterModel = z.infer; export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel => zParameterIPAdapterModel.safeParse(val).success; // #endregion // #region T2I Adapter Model -export const zParameterT2IAdapterModel = zT2IAdapterModelField; +export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel }); export type ParameterT2IAdapterModel = z.infer; export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel => zParameterT2IAdapterModel.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts index 1c550eb8a4..92b4f18272 100644 --- a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts +++ b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts @@ -1,12 +1,12 @@ -import type { ModelIdentifier } from 'features/nodes/types/common'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; /** * Gets the optimal dimension for a givel model, based on the model's base_model * @param model The model identifier * @returns The optimal dimension for the model */ -export const getOptimalDimension = (model?: ModelIdentifier | null): number => - model?.base_model === 'sdxl' ? 1024 : 512; +export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number => + model?.base === 'sdxl' ? 1024 : 512; const MIN_AREA_FACTOR = 0.8; const MAX_AREA_FACTOR = 1.2; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index 5559ec76b7..4c54251557 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -7,12 +7,12 @@ import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSl import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel); -const optionsFilter = (model: MainModelConfigEntity) => model.base_model === 'sdxl-refiner'; +const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner'; const ParamSDXLRefinerModelSelect = () => { const dispatch = useAppDispatch(); @@ -20,7 +20,7 @@ const ParamSDXLRefinerModelSelect = () => { const { t } = useTranslation(); const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const _onChange = useCallback( - (model: MainModelConfigEntity | null) => { + (model: MainModelConfig | null) => { if (!model) { dispatch(refinerModelChanged(null)); return; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index bceee915cd..fc8c54576c 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -24,7 +24,8 @@ const formLabelProps2: FormLabelProps = { const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => { const badges: (string | number)[] = []; if (generation.vae) { - let vaeBadge = generation.vae.model_name; + // TODO(MM2): Fetch the vae name + let vaeBadge = generation.vae.key; if (generation.vaePrecision === 'fp16') { vaeBadge += ` ${generation.vaePrecision}`; } diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 077875a8a7..cda7dcf6e9 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -35,9 +35,10 @@ const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationS const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : []; const accordionBadges: (string | number)[] = []; + // TODO(MM2): fetch model name if (generation.model) { - accordionBadges.push(generation.model.model_name); - accordionBadges.push(generation.model.base_model); + accordionBadges.push(generation.model.key); + accordionBadges.push(generation.model.base); } return { loraTabBadges, accordionBadges }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx index 2f778fe717..8f876850e8 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx @@ -56,7 +56,7 @@ const selector = createMemoizedSelector( if (hrfEnabled) { badges.push('HiRes Fix'); } - return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' }; + return { badges, activeTabName, isSDXL: model?.base === 'sdxl' }; } ); diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx index d52b2d9000..a74d132bd6 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx @@ -22,7 +22,7 @@ const overlayScrollbarsStyles: CSSProperties = { const ParametersPanel = () => { const activeTabName = useAppSelector(activeTabNameSelector); - const isSDXL = useAppSelector((s) => s.generation.model?.base_model === 'sdxl'); + const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl'); return ( diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c11c8b45e5..97e221454d 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,64 +1,26 @@ -import type { EntityState } from '@reduxjs/toolkit'; +import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; -import { cloneDeep } from 'lodash-es'; import queryString from 'query-string'; import type { operations, paths } from 'services/api/schema'; import type { AnyModelConfig, BaseModelType, - CheckpointModelConfig, - ControlNetModelConfig, - DiffusersModelConfig, + ControlNetConfig, ImportModelConfig, - IPAdapterModelConfig, - LoRAModelConfig, + IPAdapterConfig, + LoRAConfig, MainModelConfig, MergeModelConfig, ModelType, - T2IAdapterModelConfig, - TextualInversionModelConfig, - VaeModelConfig, + T2IAdapterConfig, + TextualInversionConfig, + VAEConfig, } from 'services/api/types'; -import type { ApiTagDescription } from '..'; +import type { ApiTagDescription, tagTypes } from '..'; import { api, LIST_TAG } from '..'; -export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string }; -export type CheckpointModelConfigEntity = CheckpointModelConfig & { - id: string; -}; -export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity; - -export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; - -export type ControlNetModelConfigEntity = ControlNetModelConfig & { - id: string; -}; - -export type IPAdapterModelConfigEntity = IPAdapterModelConfig & { - id: string; -}; - -export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & { - id: string; -}; - -export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { - id: string; -}; - -export type VaeModelConfigEntity = VaeModelConfig & { id: string }; - -export type AnyModelConfigEntity = - | MainModelConfigEntity - | LoRAModelConfigEntity - | ControlNetModelConfigEntity - | IPAdapterModelConfigEntity - | T2IAdapterModelConfigEntity - | TextualInversionModelConfigEntity - | VaeModelConfigEntity; - type UpdateMainModelArg = { base_model: BaseModelType; model_name: string; @@ -68,11 +30,11 @@ type UpdateMainModelArg = { type UpdateLoRAModelArg = { base_model: BaseModelType; model_name: string; - body: LoRAModelConfig; + body: LoRAConfig; }; type UpdateMainModelResponse = - paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; + paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateLoRAModelResponse = UpdateMainModelResponse; @@ -128,59 +90,71 @@ type CheckpointConfigsResponse = type SearchFolderArg = operations['search_for_models']['parameters']['query']; -export const mainModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const mainModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const loraModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const loraModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const controlNetModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const controlNetModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const ipAdapterModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const ipAdapterModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const t2iAdapterModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const t2iAdapterModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const textualInversionModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const textualInversionModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors( undefined, getSelectorsOptions ); -export const vaeModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const vaeModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const getModelId = ({ - base_model, - model_type, - model_name, -}: Pick) => `${base_model}/${model_type}/${model_name}`; +const buildProvidesTags = + (tagType: (typeof tagTypes)[number]) => + (result: EntityState | undefined) => { + const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; -const createModelEntities = (models: AnyModelConfig[]): T[] => { - const entityArray: T[] = []; - models.forEach((model) => { - const entity = { - ...cloneDeep(model), - id: getModelId(model), - } as T; - entityArray.push(entity); - }); - return entityArray; -}; + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: tagType, + id, + })) + ); + } + + return tags; + }; + +const buildTransformResponse = + (adapter: EntityAdapter) => + (response: { models: T[] }) => { + return adapter.setAll(adapter.getInitialState(), response.models); + }; export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - getMainModels: build.query, BaseModelType[]>({ + getMainModels: build.query, BaseModelType[]>({ query: (base_models) => { const params = { model_type: 'main', @@ -190,24 +164,8 @@ export const modelsApi = api.injectEndpoints({ const query = queryString.stringify(params, { arrayFormat: 'none' }); return `models/?${query}`; }, - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'MainModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: MainModelConfig[] }) => { - const entities = createModelEntities(response.models); - return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('MainModel'), + transformResponse: buildTransformResponse(mainModelsAdapter), }), updateMainModels: build.mutation({ query: ({ base_model, model_name, body }) => { @@ -277,26 +235,10 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), - getLoRAModels: build.query, void>({ + getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'LoRAModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: LoRAModelConfig[] }) => { - const entities = createModelEntities(response.models); - return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('LoRAModel'), + transformResponse: buildTransformResponse(loraModelsAdapter), }), updateLoRAModels: build.mutation({ query: ({ base_model, model_name, body }) => { @@ -317,110 +259,30 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], }), - getControlNetModels: build.query, void>({ + getControlNetModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'ControlNetModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: ControlNetModelConfig[] }) => { - const entities = createModelEntities(response.models); - return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('ControlNetModel'), + transformResponse: buildTransformResponse(controlNetModelsAdapter), }), - getIPAdapterModels: build.query, void>({ + getIPAdapterModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'IPAdapterModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: IPAdapterModelConfig[] }) => { - const entities = createModelEntities(response.models); - return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('IPAdapterModel'), + transformResponse: buildTransformResponse(ipAdapterModelsAdapter), }), - getT2IAdapterModels: build.query, void>({ + getT2IAdapterModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'T2IAdapterModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: T2IAdapterModelConfig[] }) => { - const entities = createModelEntities(response.models); - return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('T2IAdapterModel'), + transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), }), - getVaeModels: build.query, void>({ + getVaeModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'vae' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'VaeModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: VaeModelConfig[] }) => { - const entities = createModelEntities(response.models); - return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('VaeModel'), + transformResponse: buildTransformResponse(vaeModelsAdapter), }), - getTextualInversionModels: build.query, void>({ + getTextualInversionModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'TextualInversionModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: TextualInversionModelConfig[] }) => { - const entities = createModelEntities(response.models); - return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('TextualInversionModel'), + transformResponse: buildTransformResponse(textualInversionModelsAdapter), }), getModelsInFolder: build.query({ query: (arg) => { diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index f9a1decf65..7a02cc5568 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -2,6 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; +import type { SetRequired } from 'type-fest'; export type S = components['schemas']; @@ -54,40 +55,34 @@ export type LoRAModelFormat = S['LoRAModelFormat']; export type ControlNetModelField = S['ControlNetModelField']; export type IPAdapterModelField = S['IPAdapterModelField']; export type T2IAdapterModelField = S['T2IAdapterModelField']; -export type ModelsList = S['invokeai__app__api__routers__models__ModelsList']; export type ControlField = S['ControlField']; export type IPAdapterField = S['IPAdapterField']; // Model Configs -export type LoRAModelConfig = S['LoRAModelConfig']; -export type VaeModelConfig = S['VaeModelConfig']; -export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig']; -export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig']; -export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig; -export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig']; -export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; -export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig']; -export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig; -export type TextualInversionModelConfig = S['TextualInversionModelConfig']; -export type DiffusersModelConfig = - | S['StableDiffusion1ModelDiffusersConfig'] - | S['StableDiffusion2ModelDiffusersConfig'] - | S['StableDiffusionXLModelDiffusersConfig']; -export type CheckpointModelConfig = - | S['StableDiffusion1ModelCheckpointConfig'] - | S['StableDiffusion2ModelCheckpointConfig'] - | S['StableDiffusionXLModelCheckpointConfig']; + +// TODO(MM2): Can we make key required in the pydantic model? +type KeyRequired = SetRequired; +export type LoRAConfig = KeyRequired; +// TODO(MM2): Can we rename this from Vae -> VAE +export type VAEConfig = KeyRequired | KeyRequired; +export type ControlNetConfig = KeyRequired | KeyRequired; +export type IPAdapterConfig = KeyRequired; +// TODO(MM2): Can we rename this to T2IAdapterConfig +export type T2IAdapterConfig = KeyRequired; +export type TextualInversionConfig = KeyRequired; +export type DiffusersModelConfig = KeyRequired; +export type CheckpointModelConfig = KeyRequired; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = - | LoRAModelConfig - | VaeModelConfig - | ControlNetModelConfig - | IPAdapterModelConfig - | T2IAdapterModelConfig - | TextualInversionModelConfig + | LoRAConfig + | VAEConfig + | ControlNetConfig + | IPAdapterConfig + | T2IAdapterConfig + | TextualInversionConfig | MainModelConfig; -export type MergeModelConfig = S['Body_merge_models']; +export type MergeModelConfig = S['Body_merge']; export type ImportModelConfig = S['Body_import_model']; // Graphs