diff --git a/invokeai/frontend/web/.storybook/ReduxInit.tsx b/invokeai/frontend/web/.storybook/ReduxInit.tsx index 55d0132242..7d3f8e0d2b 100644 --- a/invokeai/frontend/web/.storybook/ReduxInit.tsx +++ b/invokeai/frontend/web/.storybook/ReduxInit.tsx @@ -10,13 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => { const dispatch = useAppDispatch(); useGlobalModifiersInit(); useEffect(() => { - dispatch( - modelChanged({ - model_name: 'test_model', - base_model: 'sd-1', - model_type: 'main', - }) - ); + dispatch(modelChanged({ key: 'test_model', base: 'sd-1' })); }, []); return props.children; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index e1e13fadbe..d1cb692c98 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = () => { let graph; - if (model && model.base_model === 'sdxl') { + if (model && model.base === 'sdxl') { if (action.payload.tabName === 'txt2img') { graph = buildLinearSDXLTextToImageGraph(state); } else { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 7638c5522a..35e2ad5f9b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -30,8 +30,8 @@ export const addModelSelectedListener = () => { const newModel = result.data; - const newBaseModel = newModel.base_model; - const didBaseModelChange = state.generation.model?.base_model !== newBaseModel; + const newBaseModel = newModel.base; + const didBaseModelChange = state.generation.model?.base !== newBaseModel; if (didBaseModelChange) { // we may need to reset some incompatible submodels @@ -39,7 +39,7 @@ export const addModelSelectedListener = () => { // handle incompatible loras forEach(state.lora.loras, (lora, id) => { - if (lora.base_model !== newBaseModel) { + if (lora.base !== newBaseModel) { dispatch(loraRemoved(id)); modelsCleared += 1; } @@ -47,14 +47,14 @@ export const addModelSelectedListener = () => { // handle incompatible vae const { vae } = state.generation; - if (vae && vae.base_model !== newBaseModel) { + if (vae && vae.base !== newBaseModel) { dispatch(vaeSelected(null)); modelsCleared += 1; } // handle incompatible controlnets selectControlAdapterAll(state.controlAdapters).forEach((ca) => { - if (ca.model?.base_model !== newBaseModel) { + if (ca.model?.base !== newBaseModel) { dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false })); modelsCleared += 1; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 0ffe88cd07..366644fa68 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -34,14 +34,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentModelAvailable = currentModel - ? models.some( - (m) => - m.model_name === currentModel.model_name && - m.base_model === currentModel.base_model && - m.model_type === currentModel.model_type - ) - : false; + const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; if (isCurrentModelAvailable) { return; @@ -74,14 +67,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentModelAvailable = currentModel - ? models.some( - (m) => - m.model_name === currentModel.model_name && - m.base_model === currentModel.base_model && - m.model_type === currentModel.model_type - ) - : false; + const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; if (!isCurrentModelAvailable) { dispatch(refinerModelChanged(null)); @@ -103,10 +89,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentVAEAvailable = some( - action.payload.entities, - (m) => m?.model_name === currentVae?.model_name && m?.base_model === currentVae?.base_model - ); + const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key); if (isCurrentVAEAvailable) { return; @@ -140,10 +123,7 @@ export const addModelsLoadedListener = () => { const loras = getState().lora.loras; forEach(loras, (lora, id) => { - const isLoRAAvailable = some( - action.payload.entities, - (m) => m?.model_name === lora?.model_name && m?.base_model === lora?.base_model - ); + const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.key); if (isLoRAAvailable) { return; @@ -161,10 +141,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`); selectAllControlNets(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; @@ -182,10 +159,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`); selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; @@ -203,10 +177,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`); selectAllIPAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index eb55db79ca..875ce1f1c4 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -5,10 +5,10 @@ import type { GroupBase } from 'chakra-react-select'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; -type UseGroupedModelComboboxArg = { +type UseGroupedModelComboboxArg = { modelEntities: EntityState | undefined; selectedModel?: Pick | null; onChange: (value: T | null) => void; @@ -24,7 +24,7 @@ type UseGroupedModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useGroupedModelCombobox = ( +export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index baa704e75c..b31efed970 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -105,7 +105,7 @@ const selector = createMemoizedSelector( number: i + 1, }) ); - } else if (ca.model.base_model !== model?.base_model) { + } else if (ca.model.base !== model?.base) { // This should never happen, just a sanity check reasons.push( i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 880b316379..07e6aeb34c 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -3,10 +3,10 @@ import type { EntityState } from '@reduxjs/toolkit'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; -type UseModelComboboxArg = { +type UseModelComboboxArg = { modelEntities: EntityState | undefined; selectedModel?: Pick | null; onChange: (value: T | null) => void; @@ -23,9 +23,7 @@ type UseModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useModelCombobox = ( - arg: UseModelComboboxArg -): UseModelComboboxReturn => { +export const useModelCombobox = (arg: UseModelComboboxArg): UseModelComboboxReturn => { const { t } = useTranslation(); const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg; const options = useMemo(() => { diff --git a/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts new file mode 100644 index 0000000000..07ea98a274 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts @@ -0,0 +1,88 @@ +import type { Item } from '@invoke-ai/ui-library'; +import type { EntityState } from '@reduxjs/toolkit'; +import { EMPTY_ARRAY } from 'app/store/util'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; +import { filter } from 'lodash-es'; +import { useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import type { AnyModelConfig } from 'services/api/types'; + +type UseModelCustomSelectArg = { + data: EntityState | undefined; + isLoading: boolean; + selectedModel?: ModelIdentifierWithBase | null; + onChange: (value: T | null) => void; + modelFilter?: (model: T) => boolean; + isModelDisabled?: (model: T) => boolean; +}; + +type UseModelCustomSelectReturn = { + selectedItem: Item | null; + items: Item[]; + onChange: (item: Item | null) => void; + placeholder: string; +}; + +const modelFilterDefault = () => true; +const isModelDisabledDefault = () => false; + +export const useModelCustomSelect = ({ + data, + isLoading, + selectedModel, + onChange, + modelFilter = modelFilterDefault, + isModelDisabled = isModelDisabledDefault, +}: UseModelCustomSelectArg): UseModelCustomSelectReturn => { + const { t } = useTranslation(); + + const items: Item[] = useMemo( + () => + data + ? filter(data.entities, modelFilter).map((m) => ({ + label: m.name, + value: m.key, + description: m.description, + group: MODEL_TYPE_SHORT_MAP[m.base], + isDisabled: isModelDisabled(m), + })) + : EMPTY_ARRAY, + [data, isModelDisabled, modelFilter] + ); + + const _onChange = useCallback( + (item: Item | null) => { + if (!item || !data) { + return; + } + const model = data.entities[item.value]; + if (!model) { + return; + } + onChange(model); + }, + [data, onChange] + ); + + const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]); + + const placeholder = useMemo(() => { + if (isLoading) { + return t('common.loading'); + } + + if (items.length === 0) { + return t('models.noModelsAvailable'); + } + + return t('models.selectModel'); + }, [isLoading, items, t]); + + return { + items, + onChange: _onChange, + selectedItem, + placeholder, + }; +}; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index cd734d3f00..f50d52c1bf 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -626,7 +626,7 @@ export const canvasSlice = createSlice({ }, extraReducers: (builder) => { builder.addCase(modelChanged, (state, action) => { - if (action.meta.previousModel?.base_model === action.payload?.base_model) { + if (action.meta.previousModel?.base === action.payload?.base) { // The base model hasn't changed, we don't need to optimize the size return; } diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index 13851b143c..696bf47b2a 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -1,49 +1,37 @@ -import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { CustomSelect, FormControl } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; -import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities'; +import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; -import type { - ControlNetModelConfigEntity, - IPAdapterModelConfigEntity, - T2IAdapterModelConfigEntity, -} from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; type ParamControlAdapterModelProps = { id: string; }; -const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); - const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const isEnabled = useControlAdapterIsEnabled(id); const controlAdapterType = useControlAdapterType(id); const model = useControlAdapterModel(id); const dispatch = useAppDispatch(); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); - const mainModel = useAppSelector(selectMainModel); - const { t } = useTranslation(); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); - const models = useControlAdapterModelEntities(controlAdapterType); + const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const _onChange = useCallback( - (model: ControlNetModelConfigEntity | IPAdapterModelConfigEntity | T2IAdapterModelConfigEntity | null) => { + (model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { if (!model) { return; } dispatch( controlAdapterModelChanged({ id, - model: pick(model, 'base_model', 'model_name'), + model: pick(model, 'base', 'key'), }) ); }, @@ -55,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { [controlAdapterType, model] ); - const getIsDisabled = useCallback( - (model: AnyModelConfig): boolean => { - const isCompatible = currentBaseModel === model.base_model; - const hasMainModel = Boolean(currentBaseModel); - return !hasMainModel || !isCompatible; - }, - [currentBaseModel] - ); - - const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: models, - onChange: _onChange, + const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ + data, + isLoading, selectedModel, - getIsDisabled, + onChange: _onChange, + modelFilter: (model) => model.base === currentBaseModel, }); return ( - - - - - + + + ); }; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts index 51b36968d2..7fd1088767 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts @@ -6,14 +6,14 @@ import { useCallback, useMemo } from 'react'; import { useControlAdapterModels } from './useControlAdapterModels'; export const useAddControlAdapter = (type: ControlAdapterType) => { - const baseModel = useAppSelector((s) => s.generation.model?.base_model); + const baseModel = useAppSelector((s) => s.generation.model?.base); const dispatch = useAppDispatch(); const models = useControlAdapterModels(type); const firstModel = useMemo(() => { // prefer to use a model that matches the base model - const firstCompatibleModel = models.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]; + const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0]; if (firstCompatibleModel) { return firstCompatibleModel; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts deleted file mode 100644 index 0c8baaacc2..0000000000 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts +++ /dev/null @@ -1,23 +0,0 @@ -import type { ControlAdapterType } from 'features/controlAdapters/store/types'; -import { - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetT2IAdapterModelsQuery, -} from 'services/api/endpoints/models'; - -export const useControlAdapterModelEntities = (type?: ControlAdapterType) => { - const { data: controlNetModelsData } = useGetControlNetModelsQuery(); - const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery(); - const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery(); - - if (type === 'controlnet') { - return controlNetModelsData; - } - if (type === 't2i_adapter') { - return t2iAdapterModelsData; - } - if (type === 'ip_adapter') { - return ipAdapterModelsData; - } - return; -}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts new file mode 100644 index 0000000000..1d092497af --- /dev/null +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts @@ -0,0 +1,26 @@ +import type { ControlAdapterType } from 'features/controlAdapters/store/types'; +import { + useGetControlNetModelsQuery, + useGetIPAdapterModelsQuery, + useGetT2IAdapterModelsQuery, +} from 'services/api/endpoints/models'; + +export const useControlAdapterModelQuery = (type: ControlAdapterType) => { + const controlNetModelsQuery = useGetControlNetModelsQuery(); + const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery(); + const ipAdapterModelsQuery = useGetIPAdapterModelsQuery(); + + if (type === 'controlnet') { + return controlNetModelsQuery; + } + if (type === 't2i_adapter') { + return t2iAdapterModelsQuery; + } + if (type === 'ip_adapter') { + return ipAdapterModelsQuery; + } + + // Assert that the end of the function is not reachable. + const exhaustiveCheck: never = type; + return exhaustiveCheck; +}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts index 4e15dc9e64..fe818f3287 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts @@ -5,14 +5,16 @@ import { selectControlAdaptersSlice, } from 'features/controlAdapters/store/controlAdaptersSlice'; import { useMemo } from 'react'; +import { assert } from 'tsafe'; export const useControlAdapterType = (id: string) => { const selector = useMemo( () => - createMemoizedSelector( - selectControlAdaptersSlice, - (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type - ), + createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => { + const type = selectControlAdapterById(controlAdapters, id)?.type; + assert(type !== undefined, `Control adapter with id ${id} not found`); + return type; + }), [id] ); diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index 49b07f16a1..fce94ad019 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -236,7 +236,8 @@ export const controlAdaptersSlice = createSlice({ let processorType: ControlAdapterProcessorType | undefined = undefined; for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - if (model.model_name.includes(modelSubstring)) { + // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType + if (model.key.includes(modelSubstring)) { processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } @@ -359,7 +360,8 @@ export const controlAdaptersSlice = createSlice({ let processorType: ControlAdapterProcessorType | undefined = undefined; for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - if (cn.model?.model_name.includes(modelSubstring)) { + // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType + if (cn.model?.key.includes(modelSubstring)) { processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } diff --git a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx index ffe9d63360..426ddd21e2 100644 --- a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx +++ b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx @@ -6,18 +6,18 @@ import type { EmbeddingSelectProps } from 'features/embedding/types'; import { t } from 'i18next'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { TextualInversionModelConfigEntity } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; +import type { TextualInversionConfig } from 'services/api/types'; const noOptionsMessage = () => t('embedding.noMatchingEmbedding'); export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => { const { t } = useTranslation(); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = useCallback( - (embedding: TextualInversionModelConfigEntity): boolean => { + (embedding: TextualInversionConfig): boolean => { const isCompatible = currentBaseModel === embedding.base_model; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; @@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const { data, isLoading } = useGetTextualInversionModelsQuery(); const _onChange = useCallback( - (embedding: TextualInversionModelConfigEntity | null) => { + (embedding: TextualInversionConfig | null) => { if (!embedding) { return; } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index e9a1461186..5907ba0700 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -208,8 +208,8 @@ const ImageMetadataActions = (props: Props) => { {metadata.seed !== undefined && metadata.seed !== null && ( )} - {metadata.model !== undefined && metadata.model !== null && metadata.model.model_name && ( - + {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( + )} {metadata.width && ( @@ -222,7 +222,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.steps && ( @@ -269,7 +269,7 @@ const ImageMetadataActions = (props: Props) => { ); @@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => { ))} @@ -287,7 +287,7 @@ const ImageMetadataActions = (props: Props) => { ))} @@ -295,7 +295,7 @@ const ImageMetadataActions = (props: Props) => { ))} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index caedde875a..a194fb1361 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -43,7 +43,7 @@ export const LoRACard = memo((props: LoRACardProps) => { - {lora.model_name} + {lora.key} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx index 9f37454d16..7bcd537805 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx @@ -18,7 +18,7 @@ export const LoRAList = memo(() => { return ( {lorasArray.map((lora) => ( - + ))} ); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 30ef99d2f7..910f7087df 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -6,7 +6,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -18,7 +18,7 @@ const LoRASelect = () => { const addedLoRAs = useAppSelector(selectAddedLoRAs); const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); - const getIsDisabled = (lora: LoRAModelConfigEntity): boolean => { + const getIsDisabled = (lora: LoRAConfig): boolean => { const isCompatible = currentBaseModel === lora.base_model; const isAdded = Boolean(addedLoRAs[lora.id]); const hasMainModel = Boolean(currentBaseModel); @@ -26,7 +26,7 @@ const LoRASelect = () => { }; const _onChange = useCallback( - (lora: LoRAModelConfigEntity | null) => { + (lora: LoRAConfig | null) => { if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index ab1b140a7c..dd455e12c3 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,10 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/types'; export type LoRA = ParameterLoRAModel & { - id: string; weight: number; isEnabled?: boolean; }; @@ -29,40 +28,40 @@ export const loraSlice = createSlice({ name: 'lora', initialState: initialLoraState, reducers: { - loraAdded: (state, action: PayloadAction) => { - const { model_name, id, base_model } = action.payload; - state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; + loraAdded: (state, action: PayloadAction) => { + const { key, base } = action.payload; + state.loras[key] = { key, base, ...defaultLoRAConfig }; }, - loraRecalled: (state, action: PayloadAction) => { - const { model_name, id, base_model, weight } = action.payload; - state.loras[id] = { id, model_name, base_model, weight, isEnabled: true }; + loraRecalled: (state, action: PayloadAction) => { + const { key, base, weight } = action.payload; + state.loras[key] = { key, base, weight, isEnabled: true }; }, loraRemoved: (state, action: PayloadAction) => { - const id = action.payload; - delete state.loras[id]; + const key = action.payload; + delete state.loras[key]; }, lorasCleared: (state) => { state.loras = {}; }, - loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { - const { id, weight } = action.payload; - const lora = state.loras[id]; + loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => { + const { key, weight } = action.payload; + const lora = state.loras[key]; if (!lora) { return; } lora.weight = weight; }, loraWeightReset: (state, action: PayloadAction) => { - const id = action.payload; - const lora = state.loras[id]; + const key = action.payload; + const lora = state.loras[key]; if (!lora) { return; } lora.weight = defaultLoRAConfig.weight; }, - loraIsEnabledChanged: (state, action: PayloadAction>) => { - const { id, isEnabled } = action.payload; - const lora = state.loras[id]; + loraIsEnabledChanged: (state, action: PayloadAction>) => { + const { key, isEnabled } = action.payload; + const lora = state.loras[key]; if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx index 6b9abdbfec..15149b339b 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx @@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; -import type { - DiffusersModelConfigEntity, - LoRAModelConfigEntity, - MainModelConfigEntity, -} from 'services/api/endpoints/models'; +import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; @@ -38,7 +34,7 @@ const ModelManagerPanel = () => { }; type ModelEditProps = { - model: MainModelConfigEntity | LoRAModelConfigEntity | undefined; + model: MainModelConfig | LoRAConfig | undefined; }; const ModelEdit = (props: ModelEditProps) => { @@ -50,7 +46,7 @@ const ModelEdit = (props: ModelEditProps) => { } if (model?.model_format === 'diffusers') { - return ; + return ; } if (model?.model_type === 'lora') { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index f4d271187d..43707308e0 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -21,14 +21,14 @@ import { memo, useCallback, useEffect, useState } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models'; +import type { CheckpointModelConfig } from 'services/api/endpoints/models'; import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { CheckpointModelConfig } from 'services/api/types'; import ModelConvert from './ModelConvert'; type CheckpointModelEditProps = { - model: CheckpointModelConfigEntity; + model: CheckpointModelConfig; }; const CheckpointModelEdit = (props: CheckpointModelEditProps) => { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 4670f32157..bf6349234f 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -9,12 +9,12 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models'; +import type { DiffusersModelConfig } from 'services/api/endpoints/models'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { DiffusersModelConfig } from 'services/api/types'; type DiffusersModelEditProps = { - model: DiffusersModelConfigEntity; + model: DiffusersModelConfig; }; const DiffusersModelEdit = (props: DiffusersModelEditProps) => { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index 2baf735bee..75151cd001 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,12 +8,12 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; -import type { LoRAModelConfig } from 'services/api/types'; +import type { LoRAConfig } from 'services/api/types'; type LoRAModelEditProps = { - model: LoRAModelConfigEntity; + model: LoRAConfig; }; const LoRAModelEdit = (props: LoRAModelEditProps) => { @@ -30,7 +30,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { control, formState: { errors }, reset, - } = useForm({ + } = useForm({ defaultValues: { model_name: model.model_name ? model.model_name : '', base_model: model.base_model, @@ -42,7 +42,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { mode: 'onChange', }); - const onSubmit = useCallback>( + const onSubmit = useCallback>( (values) => { const responseBody = { base_model: model.base_model, @@ -53,7 +53,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { updateLoRAModel(responseBody) .unwrap() .then((payload) => { - reset(payload as LoRAModelConfig, { keepDefaultValues: true }); + reset(payload as LoRAConfig, { keepDefaultValues: true }); dispatch( addToast( makeToast({ @@ -106,7 +106,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base_model" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx index 94db3d20c3..dd74bb0c23 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -5,7 +5,7 @@ import type { ChangeEvent, PropsWithChildren } from 'react'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; -import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; @@ -127,7 +127,7 @@ const ModelList = (props: ModelListProps) => { export default memo(ModelList); -const modelsFilter = ( +const modelsFilter = ( data: EntityState | undefined, model_type: ModelType, model_format: ModelFormat | undefined, @@ -163,7 +163,7 @@ StyledModelContainer.displayName = 'StyledModelContainer'; type ModelListWrapperProps = { title: string; - modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]; + modelList: MainModelConfig[] | LoRAConfig[]; selected: ModelListProps; }; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index fdd13e09f5..835499d25a 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -15,11 +15,11 @@ import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; -import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models'; type ModelListItemProps = { - model: MainModelConfigEntity | LoRAModelConfigEntity; + model: MainModelConfig | LoRAConfig; isSelected: boolean; setSelectedModelId: (v: string | undefined) => void; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 22024f3d3c..53d800e7b6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { ControlNetModelConfigEntity } from 'services/api/endpoints/models'; +import type { ControlNetConfig } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { const { data, isLoading } = useGetControlNetModelsQuery(); const _onChange = useCallback( - (value: ControlNetModelConfigEntity | null) => { + (value: ControlNetConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 2cde347247..3f195ceb32 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { IPAdapterModelConfigEntity } from 'services/api/endpoints/models'; +import type { IPAdapterConfig } from 'services/api/endpoints/models'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const IPAdapterModelFieldInputComponent = ( const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const _onChange = useCallback( - (value: IPAdapterModelConfigEntity | null) => { + (value: IPAdapterConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index 96208d68d4..eeb07fa08e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -16,7 +16,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetLoRAModelsQuery(); const _onChange = useCallback( - (value: LoRAModelConfigEntity | null) => { + (value: LoRAConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index 64c6970cae..7ddde08816 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { NON_SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const MainModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index 98901af38b..9b5a1138d4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -9,7 +9,7 @@ import type { } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -21,7 +21,7 @@ const RefinerModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index f5bc7ac3e4..cf353619e8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 9baf0d2d61..8402c56343 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { T2IAdapterModelConfigEntity } from 'services/api/endpoints/models'; +import type { T2IAdapterConfig } from 'services/api/endpoints/models'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const T2IAdapterModelFieldInputComponent = ( const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const _onChange = useCallback( - (value: T2IAdapterModelConfigEntity | null) => { + (value: T2IAdapterConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 070178f32a..af09f2d8f2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { VaeModelConfigEntity } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const VAEModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetVaeModelsQuery(); const _onChange = useCallback( - (value: VaeModelConfigEntity | null) => { + (value: VAEConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index ef579fce8c..891bd29bc8 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -67,11 +67,13 @@ export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ key: z.string().min(1), }); +export const zModelFieldBase = zModelIdentifier; +export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; export type ModelType = z.infer; export type ModelIdentifier = z.infer; - -export const zMainModelField = zModelIdentifier; +export type ModelIdentifierWithBase = z.infer; +export const zMainModelField = zModelFieldBase; export type MainModelField = z.infer; export const zSDXLRefinerModelField = zModelIdentifier; @@ -91,23 +93,23 @@ export const zSubModelType = z.enum([ ]); export type SubModelType = z.infer; -export const zVAEModelField = zModelIdentifier; +export const zVAEModelField = zModelFieldBase; export const zModelInfo = zModelIdentifier.extend({ submodel_type: zSubModelType.nullish(), }); export type ModelInfo = z.infer; -export const zLoRAModelField = zModelIdentifier; +export const zLoRAModelField = zModelFieldBase; export type LoRAModelField = z.infer; -export const zControlNetModelField = zModelIdentifier; +export const zControlNetModelField = zModelFieldBase; export type ControlNetModelField = z.infer; -export const zIPAdapterModelField = zModelIdentifier; +export const zIPAdapterModelField = zModelFieldBase; export type IPAdapterModelField = z.infer; -export const zT2IAdapterModelField = zModelIdentifier; +export const zT2IAdapterModelField = zModelFieldBase; export type T2IAdapterModelField = z.infer; export const zLoraInfo = zModelInfo.extend({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts index d862b0986e..1853d3722c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validControlNets = selectValidControlNets(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); // const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts index 3a79b78c6e..b51ac1bd52 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); if (validIPAdapters.length) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts index 3ed71b7529..95bba9b441 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts @@ -28,6 +28,7 @@ export const addLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ + // TODO(MM2): check base model const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); const loraCount = size(enabledLoRAs); @@ -48,19 +49,19 @@ export const addLoRAsToGraph = ( const loraMetadata: CoreMetadataInvocation['loras'] = []; enabledLoRAs.forEach((lora) => { - const { model_name, base_model, weight } = lora; - const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + const { key, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { model_name, base_model }, + lora: { key }, weight, }; loraMetadata.push({ - lora: { model_name, base_model }, + lora: { key }, weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts index 9553568922..7874b059c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts @@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ + // TODO(MM2): check base model const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); const loraCount = size(enabledLoRAs); @@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = ( let currentLoraIndex = 0; enabledLoRAs.forEach((lora) => { - const { model_name, base_model, weight } = lora; - const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + const { key, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; const loraLoaderNode: SDXLLoraLoaderInvocation = { type: 'sdxl_lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { model_name, base_model }, + lora: { key }, weight, }; loraMetadata.push( zLoRAMetadataItem.parse({ - lora: { model_name, base_model }, + lora: { key }, weight, }) ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts index d35f72a2b4..84002337d7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); if (validT2IAdapters.length) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts index 2b64f4898b..4ce2e4d673 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts @@ -19,7 +19,7 @@ export const buildCanvasGraph = ( let graph: NonNullableGraph; if (generationMode === 'txt2img') { - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLTextToImageGraph(state); } else { graph = buildCanvasTextToImageGraph(state); @@ -28,7 +28,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage); } else { graph = buildCanvasImageToImageGraph(state, canvasInitImage); @@ -37,7 +37,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage || !canvasMaskImage) { throw new Error('Missing canvas init and mask images'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage); } else { graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); @@ -46,7 +46,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage); } else { graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts index d0e331fb46..9fcc6afaa0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts @@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, }); } - if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') { + if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') { if (graph.nodes[POSITIVE_CONDITIONING]) { firstBatchDatumList.push({ node_path: POSITIVE_CONDITIONING, diff --git a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx index 621ed56ef6..c23d541613 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx @@ -29,17 +29,17 @@ const ParamClipSkip = () => { if (!model) { return CLIP_SKIP_MAP['sd-1'].maxClip; } - return CLIP_SKIP_MAP[model.base_model].maxClip; + return CLIP_SKIP_MAP[model.base].maxClip; }, [model]); const sliderMarks = useMemo(() => { if (!model) { return CLIP_SKIP_MAP['sd-1'].markers; } - return CLIP_SKIP_MAP[model.base_model].markers; + return CLIP_SKIP_MAP[model.base].markers; }, [model]); - if (model?.base_model === 'sdxl') { + if (model?.base === 'sdxl') { return null; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index ae81f78fd1..a1852bfafe 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next'; export const ParamPositivePrompt = memo(() => { const dispatch = useAppDispatch(); const prompt = useAppSelector((s) => s.generation.positivePrompt); - const baseModel = useAppSelector((s) => s.generation.model)?.base_model; + const baseModel = useAppSelector((s) => s.generation.model)?.base; const textareaRef = useRef(null); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index 13e03962b4..84b112eb5c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -1,58 +1,45 @@ -import { Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; +import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { modelSelected } from 'features/parameters/store/actions'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; -import { pick } from 'lodash-es'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; -import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/types'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); const ParamMainModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const model = useAppSelector(selectModel); + const selectedModel = useAppSelector(selectModel); const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS); - const tooltipLabel = useMemo(() => { - if (!data || !model) { - return; - } - return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description; - }, [data, model]); + const _onChange = useCallback( - (model: MainModelConfigEntity | null) => { + (model: MainModelConfig | null) => { if (!model) { return; } - dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type']))); + dispatch(modelSelected({ key: model.key, base: model.base })); }, [dispatch] ); - const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, - onChange: _onChange, - selectedModel: model, + + const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ + data, isLoading, + selectedModel, + onChange: _onChange, }); return ( - - - {t('modelManager.model')} - - - + + {t('modelManager.model')} + + ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index f290378aa8..cc0164153d 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -7,7 +7,7 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { VaeModelConfigEntity } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { @@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => { const { model, vae } = useAppSelector(selector); const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( - (vae: VaeModelConfigEntity): boolean => { + (vae: VAEConfig): boolean => { const isCompatible = model?.base_model === vae.base_model; const hasMainModel = Boolean(model?.base_model); return !hasMainModel || !isCompatible; @@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => { [model?.base_model] ); const _onChange = useCallback( - (vae: VaeModelConfigEntity | null) => { + (vae: VAEConfig | null) => { dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null)); }, [dispatch] diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 5a9fa6c66d..c8b17816bb 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -464,17 +464,15 @@ export const useRecallParameters = () => { return { lora: null, error: 'Invalid LoRA model' }; } - const { base_model, model_name } = loraMetadataItem.lora; + const { lora } = loraMetadataItem; - const matchingLoRA = loraModels - ? loraModelsAdapterSelectors.selectById(loraModels, `${base_model}/lora/${model_name}`) - : undefined; + const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined; if (!matchingLoRA) { return { lora: null, error: 'LoRA model is not installed' }; } - const isCompatibleBaseModel = matchingLoRA?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -520,17 +518,14 @@ export const useRecallParameters = () => { controlnetMetadataItem; const matchingControlNetModel = controlNetModels - ? controlNetModelsAdapterSelectors.selectById( - controlNetModels, - `${control_model.base_model}/controlnet/${control_model.model_name}` - ) + ? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key) : undefined; if (!matchingControlNetModel) { return { controlnet: null, error: 'ControlNet model is not installed' }; } - const isCompatibleBaseModel = matchingControlNetModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -597,17 +592,14 @@ export const useRecallParameters = () => { t2iAdapterMetadataItem; const matchingT2IAdapterModel = t2iAdapterModels - ? t2iAdapterModelsAdapterSelectors.selectById( - t2iAdapterModels, - `${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}` - ) + ? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key) : undefined; if (!matchingT2IAdapterModel) { return { controlnet: null, error: 'ControlNet model is not installed' }; } - const isCompatibleBaseModel = matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -672,17 +664,14 @@ export const useRecallParameters = () => { const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; const matchingIPAdapterModel = ipAdapterModels - ? ipAdapterModelsAdapterSelectors.selectById( - ipAdapterModels, - `${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}` - ) + ? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key) : undefined; if (!matchingIPAdapterModel) { return { ipAdapter: null, error: 'IP Adapter model is not installed' }; } - const isCompatibleBaseModel = matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index 0d4dda0e87..f7bf127c05 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -1,6 +1,7 @@ import { createAction } from '@reduxjs/toolkit'; -import type { ImageDTO, MainModelField } from 'services/api/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; +import type { ImageDTO } from 'services/api/types'; export const initialImageSelected = createAction('generation/initialImageSelected'); -export const modelSelected = createAction('generation/modelSelected'); +export const modelSelected = createAction('generation/modelSelected'); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index df98943cd3..1666a34d6a 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -158,15 +158,15 @@ export const generationSlice = createSlice({ // Clamp ClipSkip Based On Selected Model // TODO(psyche): remove this special handling when https://github.com/invoke-ai/InvokeAI/issues/4583 is resolved // WIP PR here: https://github.com/invoke-ai/InvokeAI/pull/4624 - if (newModel.base_model === 'sdxl') { + if (newModel.base === 'sdxl') { // We don't support clip skip for SDXL yet - it's not in the graphs state.clipSkip = 0; } else { - const { maxClip } = CLIP_SKIP_MAP[newModel.base_model]; + const { maxClip } = CLIP_SKIP_MAP[newModel.base]; state.clipSkip = clamp(state.clipSkip, 0, maxClip); } - if (action.meta.previousModel?.base_model === newModel.base_model) { + if (action.meta.previousModel?.base === newModel.base) { // The base model hasn't changed, we don't need to optimize the size return; } diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index a74807a959..bd78759b39 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = { */ export const MODEL_TYPE_SHORT_MAP = { any: 'Any', - 'sd-1': 'SD1', - 'sd-2': 'SD2', + 'sd-1': 'SD1.X', + 'sd-2': 'SD2.X', sdxl: 'SDXL', 'sdxl-refiner': 'SDXLR', }; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 7a5efe9dcf..abd8ee2810 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,5 +1,6 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { + zBaseModel, zControlNetModelField, zIPAdapterModelField, zLoRAModelField, @@ -104,48 +105,48 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati // #endregion // #region Model -export const zParameterModel = zMainModelField; +export const zParameterModel = zMainModelField.extend({ base: zBaseModel }); export type ParameterModel = z.infer; export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success; // #endregion // #region SDXL Refiner Model -export const zParameterSDXLRefinerModel = zSDXLRefinerModelField; +export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel }); export type ParameterSDXLRefinerModel = z.infer; export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel => zParameterSDXLRefinerModel.safeParse(val).success; // #endregion // #region VAE Model -export const zParameterVAEModel = zVAEModelField; +export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel }); export type ParameterVAEModel = z.infer; export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => zParameterVAEModel.safeParse(val).success; // #endregion // #region LoRA Model -export const zParameterLoRAModel = zLoRAModelField; +export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel }); export type ParameterLoRAModel = z.infer; export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => zParameterLoRAModel.safeParse(val).success; // #endregion // #region ControlNet Model -export const zParameterControlNetModel = zControlNetModelField; +export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel }); export type ParameterControlNetModel = z.infer; export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel => zParameterControlNetModel.safeParse(val).success; // #endregion // #region IP Adapter Model -export const zParameterIPAdapterModel = zIPAdapterModelField; +export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel }); export type ParameterIPAdapterModel = z.infer; export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel => zParameterIPAdapterModel.safeParse(val).success; // #endregion // #region T2I Adapter Model -export const zParameterT2IAdapterModel = zT2IAdapterModelField; +export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel }); export type ParameterT2IAdapterModel = z.infer; export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel => zParameterT2IAdapterModel.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts index 1c550eb8a4..92b4f18272 100644 --- a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts +++ b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts @@ -1,12 +1,12 @@ -import type { ModelIdentifier } from 'features/nodes/types/common'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; /** * Gets the optimal dimension for a givel model, based on the model's base_model * @param model The model identifier * @returns The optimal dimension for the model */ -export const getOptimalDimension = (model?: ModelIdentifier | null): number => - model?.base_model === 'sdxl' ? 1024 : 512; +export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number => + model?.base === 'sdxl' ? 1024 : 512; const MIN_AREA_FACTOR = 0.8; const MAX_AREA_FACTOR = 1.2; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index c19dc10fb5..1bca514df1 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -6,12 +6,12 @@ import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSl import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel); -const optionsFilter = (model: MainModelConfigEntity) => model.base_model === 'sdxl-refiner'; +const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner'; const ParamSDXLRefinerModelSelect = () => { const dispatch = useAppDispatch(); @@ -19,7 +19,7 @@ const ParamSDXLRefinerModelSelect = () => { const { t } = useTranslation(); const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const _onChange = useCallback( - (model: MainModelConfigEntity | null) => { + (model: MainModelConfig | null) => { if (!model) { dispatch(refinerModelChanged(null)); return; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index bceee915cd..fc8c54576c 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -24,7 +24,8 @@ const formLabelProps2: FormLabelProps = { const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => { const badges: (string | number)[] = []; if (generation.vae) { - let vaeBadge = generation.vae.model_name; + // TODO(MM2): Fetch the vae name + let vaeBadge = generation.vae.key; if (generation.vaePrecision === 'fp16') { vaeBadge += ` ${generation.vaePrecision}`; } diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 077875a8a7..cda7dcf6e9 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -35,9 +35,10 @@ const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationS const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : []; const accordionBadges: (string | number)[] = []; + // TODO(MM2): fetch model name if (generation.model) { - accordionBadges.push(generation.model.model_name); - accordionBadges.push(generation.model.base_model); + accordionBadges.push(generation.model.key); + accordionBadges.push(generation.model.base); } return { loraTabBadges, accordionBadges }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx index 2f778fe717..8f876850e8 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx @@ -56,7 +56,7 @@ const selector = createMemoizedSelector( if (hrfEnabled) { badges.push('HiRes Fix'); } - return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' }; + return { badges, activeTabName, isSDXL: model?.base === 'sdxl' }; } ); diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx index d52b2d9000..a74d132bd6 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx @@ -22,7 +22,7 @@ const overlayScrollbarsStyles: CSSProperties = { const ParametersPanel = () => { const activeTabName = useAppSelector(activeTabNameSelector); - const isSDXL = useAppSelector((s) => s.generation.model?.base_model === 'sdxl'); + const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl'); return ( diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index c0916f568e..a7efaafcc8 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -3,27 +3,35 @@ import type { OpenAPIV3_1 } from 'openapi-types'; import type { paths } from 'services/api/schema'; import type { AppConfig, AppDependencyVersions, AppVersion } from 'services/api/types'; -import { api } from '..'; +import { api, buildV1Url } from '..'; + +/** + * Builds an endpoint URL for the app router + * @example + * buildAppInfoUrl('some-path') + * // '/api/v1/app/some-path' + */ +const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`); export const appInfoApi = api.injectEndpoints({ endpoints: (build) => ({ getAppVersion: build.query({ query: () => ({ - url: `app/version`, + url: buildAppInfoUrl('version'), method: 'GET', }), providesTags: ['FetchOnReconnect'], }), getAppDeps: build.query({ query: () => ({ - url: `app/app_deps`, + url: buildAppInfoUrl('app_deps'), method: 'GET', }), providesTags: ['FetchOnReconnect'], }), getAppConfig: build.query({ query: () => ({ - url: `app/config`, + url: buildAppInfoUrl('config'), method: 'GET', }), providesTags: ['FetchOnReconnect'], @@ -33,28 +41,28 @@ export const appInfoApi = api.injectEndpoints({ void >({ query: () => ({ - url: `app/invocation_cache/status`, + url: buildAppInfoUrl('invocation_cache/status'), method: 'GET', }), providesTags: ['InvocationCacheStatus', 'FetchOnReconnect'], }), clearInvocationCache: build.mutation({ query: () => ({ - url: `app/invocation_cache`, + url: buildAppInfoUrl('invocation_cache'), method: 'DELETE', }), invalidatesTags: ['InvocationCacheStatus'], }), enableInvocationCache: build.mutation({ query: () => ({ - url: `app/invocation_cache/enable`, + url: buildAppInfoUrl('invocation_cache/enable'), method: 'PUT', }), invalidatesTags: ['InvocationCacheStatus'], }), disableInvocationCache: build.mutation({ query: () => ({ - url: `app/invocation_cache/disable`, + url: buildAppInfoUrl('invocation_cache/disable'), method: 'PUT', }), invalidatesTags: ['InvocationCacheStatus'], diff --git a/invokeai/frontend/web/src/services/api/endpoints/boards.ts b/invokeai/frontend/web/src/services/api/endpoints/boards.ts index 6977a2bd53..8efda86737 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/boards.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/boards.ts @@ -9,7 +9,15 @@ import type { import { getListImagesUrl } from 'services/api/util'; import type { ApiTagDescription } from '..'; -import { api, LIST_TAG } from '..'; +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the boards router + * @example + * buildBoardsUrl('some-path') + * // '/api/v1/boards/some-path' + */ +export const buildBoardsUrl = (path: string = '') => buildV1Url(`boards/${path}`); export const boardsApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -17,7 +25,7 @@ export const boardsApi = api.injectEndpoints({ * Boards Queries */ listBoards: build.query({ - query: (arg) => ({ url: 'boards/', params: arg }), + query: (arg) => ({ url: buildBoardsUrl(), params: arg }), providesTags: (result) => { // any list of boards const tags: ApiTagDescription[] = [{ type: 'Board', id: LIST_TAG }, 'FetchOnReconnect']; @@ -38,7 +46,7 @@ export const boardsApi = api.injectEndpoints({ listAllBoards: build.query, void>({ query: () => ({ - url: 'boards/', + url: buildBoardsUrl(), params: { all: true }, }), providesTags: (result) => { @@ -61,7 +69,7 @@ export const boardsApi = api.injectEndpoints({ listAllImageNamesForBoard: build.query, string>({ query: (board_id) => ({ - url: `boards/${board_id}/image_names`, + url: buildBoardsUrl(`${board_id}/image_names`), }), providesTags: (result, error, arg) => [{ type: 'ImageNameList', id: arg }, 'FetchOnReconnect'], keepUnusedDataFor: 0, @@ -107,7 +115,7 @@ export const boardsApi = api.injectEndpoints({ createBoard: build.mutation({ query: (board_name) => ({ - url: `boards/`, + url: buildBoardsUrl(), method: 'POST', params: { board_name }, }), @@ -116,7 +124,7 @@ export const boardsApi = api.injectEndpoints({ updateBoard: build.mutation({ query: ({ board_id, changes }) => ({ - url: `boards/${board_id}`, + url: buildBoardsUrl(board_id), method: 'PATCH', body: changes, }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 181c8d23fc..49eb28390f 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -26,8 +26,24 @@ import { } from 'services/api/util'; import type { ApiTagDescription } from '..'; -import { api, LIST_TAG } from '..'; -import { boardsApi } from './boards'; +import { api, buildV1Url, LIST_TAG } from '..'; +import { boardsApi, buildBoardsUrl } from './boards'; + +/** + * Builds an endpoint URL for the images router + * @example + * buildImagesUrl('some-path') + * // '/api/v1/images/some-path' + */ +const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`); + +/** + * Builds an endpoint URL for the board_images router + * @example + * buildBoardImagesUrl('some-path') + * // '/api/v1/board_images/some-path' + */ +const buildBoardImagesUrl = (path: string = '') => buildV1Url(`board_images/${path}`); export const imagesApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -90,20 +106,20 @@ export const imagesApi = api.injectEndpoints({ keepUnusedDataFor: 86400, }), getIntermediatesCount: build.query({ - query: () => ({ url: 'images/intermediates' }), + query: () => ({ url: buildImagesUrl('intermediates') }), providesTags: ['IntermediatesCount', 'FetchOnReconnect'], }), clearIntermediates: build.mutation({ - query: () => ({ url: `images/intermediates`, method: 'DELETE' }), + query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }), invalidatesTags: ['IntermediatesCount'], }), getImageDTO: build.query({ - query: (image_name) => ({ url: `images/i/${image_name}` }), + query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }), providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }], keepUnusedDataFor: 86400, // 24 hours }), getImageMetadata: build.query({ - query: (image_name) => ({ url: `images/i/${image_name}/metadata` }), + query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }), providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }], transformResponse: ( response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json'] @@ -130,7 +146,7 @@ export const imagesApi = api.injectEndpoints({ }), deleteImage: build.mutation({ query: ({ image_name }) => ({ - url: `images/i/${image_name}`, + url: buildImagesUrl(`i/${image_name}`), method: 'DELETE', }), async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) { @@ -185,7 +201,7 @@ export const imagesApi = api.injectEndpoints({ query: ({ imageDTOs }) => { const image_names = imageDTOs.map((imageDTO) => imageDTO.image_name); return { - url: `images/delete`, + url: buildImagesUrl('delete'), method: 'POST', body: { image_names, @@ -258,7 +274,7 @@ export const imagesApi = api.injectEndpoints({ */ changeImageIsIntermediate: build.mutation({ query: ({ imageDTO, is_intermediate }) => ({ - url: `images/i/${imageDTO.image_name}`, + url: buildImagesUrl(`i/${imageDTO.image_name}`), method: 'PATCH', body: { is_intermediate }, }), @@ -380,7 +396,7 @@ export const imagesApi = api.injectEndpoints({ */ changeImageSessionId: build.mutation({ query: ({ imageDTO, session_id }) => ({ - url: `images/i/${imageDTO.image_name}`, + url: buildImagesUrl(`i/${imageDTO.image_name}`), method: 'PATCH', body: { session_id }, }), @@ -417,7 +433,7 @@ export const imagesApi = api.injectEndpoints({ { imageDTOs: ImageDTO[] } >({ query: ({ imageDTOs: images }) => ({ - url: `images/star`, + url: buildImagesUrl('star'), method: 'POST', body: { image_names: images.map((img) => img.image_name) }, }), @@ -511,7 +527,7 @@ export const imagesApi = api.injectEndpoints({ { imageDTOs: ImageDTO[] } >({ query: ({ imageDTOs: images }) => ({ - url: `images/unstar`, + url: buildImagesUrl('unstar'), method: 'POST', body: { image_names: images.map((img) => img.image_name) }, }), @@ -611,7 +627,7 @@ export const imagesApi = api.injectEndpoints({ const formData = new FormData(); formData.append('file', file); return { - url: `images/upload`, + url: buildImagesUrl('upload'), method: 'POST', body: formData, params: { @@ -674,7 +690,7 @@ export const imagesApi = api.injectEndpoints({ }), deleteBoard: build.mutation({ - query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }), + query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }), invalidatesTags: () => [ { type: 'Board', id: LIST_TAG }, // invalidate the 'No Board' cache @@ -764,7 +780,7 @@ export const imagesApi = api.injectEndpoints({ deleteBoardAndImages: build.mutation({ query: (board_id) => ({ - url: `boards/${board_id}`, + url: buildBoardsUrl(board_id), method: 'DELETE', params: { include_images: true }, }), @@ -840,7 +856,7 @@ export const imagesApi = api.injectEndpoints({ query: ({ board_id, imageDTO }) => { const { image_name } = imageDTO; return { - url: `board_images/`, + url: buildBoardImagesUrl(), method: 'POST', body: { board_id, image_name }, }; @@ -961,7 +977,7 @@ export const imagesApi = api.injectEndpoints({ query: ({ imageDTO }) => { const { image_name } = imageDTO; return { - url: `board_images/`, + url: buildBoardImagesUrl(), method: 'DELETE', body: { image_name }, }; @@ -1080,7 +1096,7 @@ export const imagesApi = api.injectEndpoints({ } >({ query: ({ board_id, imageDTOs }) => ({ - url: `board_images/batch`, + url: buildBoardImagesUrl('batch'), method: 'POST', body: { image_names: imageDTOs.map((i) => i.image_name), @@ -1197,7 +1213,7 @@ export const imagesApi = api.injectEndpoints({ } >({ query: ({ imageDTOs }) => ({ - url: `board_images/batch/delete`, + url: buildBoardImagesUrl('batch/delete'), method: 'POST', body: { image_names: imageDTOs.map((i) => i.image_name), @@ -1321,7 +1337,7 @@ export const imagesApi = api.injectEndpoints({ components['schemas']['Body_download_images_from_list'] >({ query: ({ image_names, board_id }) => ({ - url: `images/download`, + url: buildImagesUrl('download'), method: 'POST', body: { image_names, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c11c8b45e5..9a7f108056 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,63 +1,28 @@ -import type { EntityState } from '@reduxjs/toolkit'; +import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; -import { cloneDeep } from 'lodash-es'; import queryString from 'query-string'; import type { operations, paths } from 'services/api/schema'; import type { AnyModelConfig, BaseModelType, - CheckpointModelConfig, - ControlNetModelConfig, - DiffusersModelConfig, + ControlNetConfig, ImportModelConfig, - IPAdapterModelConfig, - LoRAModelConfig, + IPAdapterConfig, + LoRAConfig, MainModelConfig, MergeModelConfig, ModelType, - T2IAdapterModelConfig, - TextualInversionModelConfig, - VaeModelConfig, + T2IAdapterConfig, + TextualInversionConfig, + VAEConfig, } from 'services/api/types'; -import type { ApiTagDescription } from '..'; -import { api, LIST_TAG } from '..'; +import type { ApiTagDescription, tagTypes } from '..'; +import { api, buildV2Url, LIST_TAG } from '..'; -export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string }; -export type CheckpointModelConfigEntity = CheckpointModelConfig & { - id: string; -}; -export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity; - -export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; - -export type ControlNetModelConfigEntity = ControlNetModelConfig & { - id: string; -}; - -export type IPAdapterModelConfigEntity = IPAdapterModelConfig & { - id: string; -}; - -export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & { - id: string; -}; - -export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { - id: string; -}; - -export type VaeModelConfigEntity = VaeModelConfig & { id: string }; - -export type AnyModelConfigEntity = - | MainModelConfigEntity - | LoRAModelConfigEntity - | ControlNetModelConfigEntity - | IPAdapterModelConfigEntity - | T2IAdapterModelConfigEntity - | TextualInversionModelConfigEntity - | VaeModelConfigEntity; +/* eslint-disable @typescript-eslint/no-explicit-any */ +export const getModelId = (input: any): any => input; type UpdateMainModelArg = { base_model: BaseModelType; @@ -68,11 +33,13 @@ type UpdateMainModelArg = { type UpdateLoRAModelArg = { base_model: BaseModelType; model_name: string; - body: LoRAModelConfig; + body: LoRAConfig; }; type UpdateMainModelResponse = - paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; + paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; + +type ListModelsArg = NonNullable; type UpdateLoRAModelResponse = UpdateMainModelResponse; @@ -128,91 +95,95 @@ type CheckpointConfigsResponse = type SearchFolderArg = operations['search_for_models']['parameters']['query']; -export const mainModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const mainModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const loraModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const loraModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const controlNetModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const controlNetModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const ipAdapterModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const ipAdapterModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const t2iAdapterModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const t2iAdapterModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const textualInversionModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const textualInversionModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors( undefined, getSelectorsOptions ); -export const vaeModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const vaeModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const getModelId = ({ - base_model, - model_type, - model_name, -}: Pick) => `${base_model}/${model_type}/${model_name}`; +const buildProvidesTags = + (tagType: (typeof tagTypes)[number]) => + (result: EntityState | undefined) => { + const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; -const createModelEntities = (models: AnyModelConfig[]): T[] => { - const entityArray: T[] = []; - models.forEach((model) => { - const entity = { - ...cloneDeep(model), - id: getModelId(model), - } as T; - entityArray.push(entity); - }); - return entityArray; -}; + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: tagType, + id, + })) + ); + } + + return tags; + }; + +const buildTransformResponse = + (adapter: EntityAdapter) => + (response: { models: T[] }) => { + return adapter.setAll(adapter.getInitialState(), response.models); + }; + +/** + * Builds an endpoint URL for the models router + * @example + * buildModelsUrl('some-path') + * // '/api/v1/models/some-path' + */ +const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - getMainModels: build.query, BaseModelType[]>({ + getMainModels: build.query, BaseModelType[]>({ query: (base_models) => { - const params = { + const params: ListModelsArg = { model_type: 'main', base_models, }; const query = queryString.stringify(params, { arrayFormat: 'none' }); - return `models/?${query}`; - }, - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'MainModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: MainModelConfig[] }) => { - const entities = createModelEntities(response.models); - return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities); + return buildModelsUrl(`?${query}`); }, + providesTags: buildProvidesTags('MainModel'), + transformResponse: buildTransformResponse(mainModelsAdapter), }), updateMainModels: build.mutation({ query: ({ base_model, model_name, body }) => { return { - url: `models/${base_model}/main/${model_name}`, + url: buildModelsUrl(`${base_model}/main/${model_name}`), method: 'PATCH', body: body, }; @@ -222,7 +193,7 @@ export const modelsApi = api.injectEndpoints({ importMainModels: build.mutation({ query: ({ body }) => { return { - url: `models/import`, + url: buildModelsUrl('import'), method: 'POST', body: body, }; @@ -232,7 +203,7 @@ export const modelsApi = api.injectEndpoints({ addMainModels: build.mutation({ query: ({ body }) => { return { - url: `models/add`, + url: buildModelsUrl('add'), method: 'POST', body: body, }; @@ -242,7 +213,7 @@ export const modelsApi = api.injectEndpoints({ deleteMainModels: build.mutation({ query: ({ base_model, model_name, model_type }) => { return { - url: `models/${base_model}/${model_type}/${model_name}`, + url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`), method: 'DELETE', }; }, @@ -251,7 +222,7 @@ export const modelsApi = api.injectEndpoints({ convertMainModels: build.mutation({ query: ({ base_model, model_name, convert_dest_directory }) => { return { - url: `models/convert/${base_model}/main/${model_name}`, + url: buildModelsUrl(`convert/${base_model}/main/${model_name}`), method: 'PUT', params: { convert_dest_directory }, }; @@ -261,7 +232,7 @@ export const modelsApi = api.injectEndpoints({ mergeMainModels: build.mutation({ query: ({ base_model, body }) => { return { - url: `models/merge/${base_model}`, + url: buildModelsUrl(`merge/${base_model}`), method: 'PUT', body: body, }; @@ -271,37 +242,21 @@ export const modelsApi = api.injectEndpoints({ syncModels: build.mutation({ query: () => { return { - url: `models/sync`, + url: buildModelsUrl('sync'), method: 'POST', }; }, invalidatesTags: ['Model'], }), - getLoRAModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'lora' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'LoRAModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: LoRAModelConfig[] }) => { - const entities = createModelEntities(response.models); - return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities); - }, + getLoRAModels: build.query, void>({ + query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), + providesTags: buildProvidesTags('LoRAModel'), + transformResponse: buildTransformResponse(loraModelsAdapter), }), updateLoRAModels: build.mutation({ query: ({ base_model, model_name, body }) => { return { - url: `models/${base_model}/lora/${model_name}`, + url: buildModelsUrl(`${base_model}/lora/${model_name}`), method: 'PATCH', body: body, }; @@ -311,129 +266,49 @@ export const modelsApi = api.injectEndpoints({ deleteLoRAModels: build.mutation({ query: ({ base_model, model_name }) => { return { - url: `models/${base_model}/lora/${model_name}`, + url: buildModelsUrl(`${base_model}/lora/${model_name}`), method: 'DELETE', }; }, invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], }), - getControlNetModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'ControlNetModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: ControlNetModelConfig[] }) => { - const entities = createModelEntities(response.models); - return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities); - }, + getControlNetModels: build.query, void>({ + query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), + providesTags: buildProvidesTags('ControlNetModel'), + transformResponse: buildTransformResponse(controlNetModelsAdapter), }), - getIPAdapterModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'IPAdapterModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: IPAdapterModelConfig[] }) => { - const entities = createModelEntities(response.models); - return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities); - }, + getIPAdapterModels: build.query, void>({ + query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), + providesTags: buildProvidesTags('IPAdapterModel'), + transformResponse: buildTransformResponse(ipAdapterModelsAdapter), }), - getT2IAdapterModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'T2IAdapterModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: T2IAdapterModelConfig[] }) => { - const entities = createModelEntities(response.models); - return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities); - }, + getT2IAdapterModels: build.query, void>({ + query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), + providesTags: buildProvidesTags('T2IAdapterModel'), + transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), }), - getVaeModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'vae' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'VaeModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: VaeModelConfig[] }) => { - const entities = createModelEntities(response.models); - return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities); - }, + getVaeModels: build.query, void>({ + query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), + providesTags: buildProvidesTags('VaeModel'), + transformResponse: buildTransformResponse(vaeModelsAdapter), }), - getTextualInversionModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'TextualInversionModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: TextualInversionModelConfig[] }) => { - const entities = createModelEntities(response.models); - return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities); - }, + getTextualInversionModels: build.query, void>({ + query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), + providesTags: buildProvidesTags('TextualInversionModel'), + transformResponse: buildTransformResponse(textualInversionModelsAdapter), }), getModelsInFolder: build.query({ query: (arg) => { const folderQueryStr = queryString.stringify(arg, {}); return { - url: `/models/search?${folderQueryStr}`, + url: buildModelsUrl(`search?${folderQueryStr}`), }; }, }), getCheckpointConfigs: build.query({ query: () => { return { - url: `/models/ckpt_confs`, + url: buildModelsUrl(`ckpt_confs`), }; }, }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/queue.ts b/invokeai/frontend/web/src/services/api/endpoints/queue.ts index 6c0798a936..385aa8ad12 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/queue.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/queue.ts @@ -7,7 +7,15 @@ import queryString from 'query-string'; import type { components, paths } from 'services/api/schema'; import type { ApiTagDescription } from '..'; -import { api } from '..'; +import { api, buildV1Url } from '..'; + +/** + * Builds an endpoint URL for the queue router + * @example + * buildQueueUrl('some-path') + * // '/api/v1/queue/queue_id/some-path' + */ +const buildQueueUrl = (path: string = '') => buildV1Url(`queue/${$queueId.get()}/${path}`); const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list']['get']['parameters']['query']) => { const query = queryArgs @@ -17,10 +25,10 @@ const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list'] : undefined; if (query) { - return `queue/${$queueId.get()}/list?${query}`; + return buildQueueUrl(`list?${query}`); } - return `queue/${$queueId.get()}/list`; + return buildQueueUrl('list'); }; export type SessionQueueItemStatus = NonNullable< @@ -58,7 +66,7 @@ export const queueApi = api.injectEndpoints({ paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'] >({ query: (arg) => ({ - url: `queue/${$queueId.get()}/enqueue_batch`, + url: buildQueueUrl('enqueue_batch'), body: arg, method: 'POST', }), @@ -78,7 +86,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/processor/resume`, + url: buildQueueUrl('processor/resume'), method: 'PUT', }), invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'], @@ -88,7 +96,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/processor/pause`, + url: buildQueueUrl('processor/pause'), method: 'PUT', }), invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'], @@ -98,7 +106,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/prune`, + url: buildQueueUrl('prune'), method: 'PUT', }), invalidatesTags: ['SessionQueueStatus', 'BatchStatus'], @@ -117,7 +125,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/clear`, + url: buildQueueUrl('clear'), method: 'PUT', }), invalidatesTags: [ @@ -142,7 +150,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/current`, + url: buildQueueUrl('current'), method: 'GET', }), providesTags: (result) => { @@ -158,7 +166,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/next`, + url: buildQueueUrl('next'), method: 'GET', }), providesTags: (result) => { @@ -174,7 +182,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/status`, + url: buildQueueUrl('status'), method: 'GET', }), providesTags: ['SessionQueueStatus', 'FetchOnReconnect'], @@ -184,7 +192,7 @@ export const queueApi = api.injectEndpoints({ { batch_id: string } >({ query: ({ batch_id }) => ({ - url: `queue/${$queueId.get()}/b/${batch_id}/status`, + url: buildQueueUrl(`/b/${batch_id}/status`), method: 'GET', }), providesTags: (result) => { @@ -200,7 +208,7 @@ export const queueApi = api.injectEndpoints({ number >({ query: (item_id) => ({ - url: `queue/${$queueId.get()}/i/${item_id}`, + url: buildQueueUrl(`i/${item_id}`), method: 'GET', }), providesTags: (result) => { @@ -216,7 +224,7 @@ export const queueApi = api.injectEndpoints({ number >({ query: (item_id) => ({ - url: `queue/${$queueId.get()}/i/${item_id}/cancel`, + url: buildQueueUrl(`i/${item_id}/cancel`), method: 'PUT', }), onQueryStarted: async (item_id, { dispatch, queryFulfilled }) => { @@ -253,7 +261,7 @@ export const queueApi = api.injectEndpoints({ paths['/api/v1/queue/{queue_id}/cancel_by_batch_ids']['put']['requestBody']['content']['application/json'] >({ query: (body) => ({ - url: `queue/${$queueId.get()}/cancel_by_batch_ids`, + url: buildQueueUrl('cancel_by_batch_ids'), method: 'PUT', body, }), @@ -279,7 +287,7 @@ export const queueApi = api.injectEndpoints({ method: 'GET', }), serializeQueryArgs: () => { - return `queue/${$queueId.get()}/list`; + return buildQueueUrl('list'); }, transformResponse: (response: components['schemas']['CursorPaginatedResults_SessionQueueItemDTO_']) => queueItemsAdapter.addMany( diff --git a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts index c08ee62dc9..309dd2dc79 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts @@ -1,6 +1,14 @@ import type { components } from 'services/api/schema'; -import { api } from '..'; +import { api, buildV1Url } from '..'; + +/** + * Builds an endpoint URL for the utilities router + * @example + * buildUtilitiesUrl('some-path') + * // '/api/v1/utilities/some-path' + */ +const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`); export const utilitiesApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -9,7 +17,7 @@ export const utilitiesApi = api.injectEndpoints({ { prompt: string; max_prompts: number } >({ query: (arg) => ({ - url: 'utilities/dynamicprompts', + url: buildUtilitiesUrl('dynamicprompts'), body: arg, method: 'POST', }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts index c382f7e111..0280e2ebc4 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts @@ -1,6 +1,14 @@ import type { paths } from 'services/api/schema'; -import { api, LIST_TAG } from '..'; +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the workflows router + * @example + * buildWorkflowsUrl('some-path') + * // '/api/v1/workflows/some-path' + */ +const buildWorkflowsUrl = (path: string = '') => buildV1Url(`workflows/${path}`); export const workflowsApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -8,7 +16,7 @@ export const workflowsApi = api.injectEndpoints({ paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'], string >({ - query: (workflow_id) => `workflows/i/${workflow_id}`, + query: (workflow_id) => buildWorkflowsUrl(`i/${workflow_id}`), providesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }, 'FetchOnReconnect'], onQueryStarted: async (arg, api) => { const { dispatch, queryFulfilled } = api; @@ -22,7 +30,7 @@ export const workflowsApi = api.injectEndpoints({ }), deleteWorkflow: build.mutation({ query: (workflow_id) => ({ - url: `workflows/i/${workflow_id}`, + url: buildWorkflowsUrl(`i/${workflow_id}`), method: 'DELETE', }), invalidatesTags: (result, error, workflow_id) => [ @@ -36,7 +44,7 @@ export const workflowsApi = api.injectEndpoints({ paths['/api/v1/workflows/']['post']['requestBody']['content']['application/json']['workflow'] >({ query: (workflow) => ({ - url: 'workflows/', + url: buildWorkflowsUrl(), method: 'POST', body: { workflow }, }), @@ -50,7 +58,7 @@ export const workflowsApi = api.injectEndpoints({ paths['/api/v1/workflows/i/{workflow_id}']['patch']['requestBody']['content']['application/json']['workflow'] >({ query: (workflow) => ({ - url: `workflows/i/${workflow.id}`, + url: buildWorkflowsUrl(`i/${workflow.id}`), method: 'PATCH', body: { workflow }, }), @@ -65,7 +73,7 @@ export const workflowsApi = api.injectEndpoints({ NonNullable >({ query: (params) => ({ - url: 'workflows/', + url: buildWorkflowsUrl(), params, }), providesTags: ['FetchOnReconnect', { type: 'Workflow', id: LIST_TAG }], diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 8cb9aa8618..1f567d7905 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -54,7 +54,7 @@ const dynamicBaseQuery: BaseQueryFn { if (authToken) { headers.set('Authorization', `Bearer ${authToken}`); @@ -108,3 +108,6 @@ function getCircularReplacer() { return value; }; } + +export const buildV1Url = (path: string): string => `api/v1/${path}`; +export const buildV2Url = (path: string): string => `api/v2/${path}`; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 3393e74d48..40fc262be2 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -4212,7 +4212,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"]; + [key: string]: components["schemas"]["ControlNetInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"]; }; /** * Edges @@ -4249,7 +4249,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["CLIPOutput"]; + [key: string]: components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["String2Output"] | components["schemas"]["IntegerOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["IterateInvocationOutput"]; }; /** * Errors @@ -11119,17 +11119,11 @@ export type components = { */ VaeModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; - /** - * CLIPVisionModelFormat - * @description An enumeration. - * @enum {string} - */ - CLIPVisionModelFormat: "diffusers"; + T2IAdapterModelFormat: "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -11142,12 +11136,6 @@ export type components = { * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionOnnxModelFormat * @description An enumeration. @@ -11155,17 +11143,29 @@ export type components = { */ StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * T2IAdapterModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + ControlNetModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * LoRAModelFormat * @description An enumeration. * @enum {string} */ LoRAModelFormat: "lycoris" | "diffusers"; + /** + * CLIPVisionModelFormat + * @description An enumeration. + * @enum {string} + */ + CLIPVisionModelFormat: "diffusers"; /** * IPAdapterModelFormat * @description An enumeration. diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index f9a1decf65..4ae2f9b594 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -2,6 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; +import type { SetRequired } from 'type-fest'; export type S = components['schemas']; @@ -54,40 +55,36 @@ export type LoRAModelFormat = S['LoRAModelFormat']; export type ControlNetModelField = S['ControlNetModelField']; export type IPAdapterModelField = S['IPAdapterModelField']; export type T2IAdapterModelField = S['T2IAdapterModelField']; -export type ModelsList = S['invokeai__app__api__routers__models__ModelsList']; export type ControlField = S['ControlField']; export type IPAdapterField = S['IPAdapterField']; // Model Configs -export type LoRAModelConfig = S['LoRAModelConfig']; -export type VaeModelConfig = S['VaeModelConfig']; -export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig']; -export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig']; -export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig; -export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig']; -export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; -export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig']; -export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig; -export type TextualInversionModelConfig = S['TextualInversionModelConfig']; -export type DiffusersModelConfig = - | S['StableDiffusion1ModelDiffusersConfig'] - | S['StableDiffusion2ModelDiffusersConfig'] - | S['StableDiffusionXLModelDiffusersConfig']; -export type CheckpointModelConfig = - | S['StableDiffusion1ModelCheckpointConfig'] - | S['StableDiffusion2ModelCheckpointConfig'] - | S['StableDiffusionXLModelCheckpointConfig']; + +// TODO(MM2): Can we make key required in the pydantic model? +type KeyRequired = SetRequired; +export type LoRAConfig = KeyRequired; +// TODO(MM2): Can we rename this from Vae -> VAE +export type VAEConfig = KeyRequired | KeyRequired; +export type ControlNetConfig = + | KeyRequired + | KeyRequired; +export type IPAdapterConfig = KeyRequired; +// TODO(MM2): Can we rename this to T2IAdapterConfig +export type T2IAdapterConfig = KeyRequired; +export type TextualInversionConfig = KeyRequired; +export type DiffusersModelConfig = KeyRequired; +export type CheckpointModelConfig = KeyRequired; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = - | LoRAModelConfig - | VaeModelConfig - | ControlNetModelConfig - | IPAdapterModelConfig - | T2IAdapterModelConfig - | TextualInversionModelConfig + | LoRAConfig + | VAEConfig + | ControlNetConfig + | IPAdapterConfig + | T2IAdapterConfig + | TextualInversionConfig | MainModelConfig; -export type MergeModelConfig = S['Body_merge_models']; +export type MergeModelConfig = S['Body_merge']; export type ImportModelConfig = S['Body_import_model']; // Graphs diff --git a/invokeai/frontend/web/src/services/api/util.ts b/invokeai/frontend/web/src/services/api/util.ts index f7f36f4630..a7a5d6451e 100644 --- a/invokeai/frontend/web/src/services/api/util.ts +++ b/invokeai/frontend/web/src/services/api/util.ts @@ -3,6 +3,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { dateComparator } from 'common/util/dateComparator'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types'; import queryString from 'query-string'; +import { buildV1Url } from 'services/api'; import type { ImageCache, ImageDTO, ListImagesArgs } from './types'; @@ -79,4 +80,4 @@ export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelector // Helper to create the url for the listImages endpoint. Also we use it to create the cache key. export const getListImagesUrl = (queryArgs: ListImagesArgs) => - `images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`; + buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`); diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index f4dbae7123..32e3e1f64f 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -76,9 +76,9 @@ export default defineConfig(({ mode }) => { changeOrigin: true, }, // proxy nodes api - '/api/v1': { - target: 'http://127.0.0.1:9090/api/v1', - rewrite: (path) => path.replace(/^\/api\/v1/, ''), + '/api/': { + target: 'http://127.0.0.1:9090/api/', + rewrite: (path) => path.replace(/^\/api/, ''), changeOrigin: true, }, },