feat(ui): update model identifiers to use key (#5730)

## What type of PR is this? (check all applicable)

- [x] Refactor

## Description

- Update zod schemas & types to use key instead of name/base/type
- Use new `CustomSelect` component instead of `ComboBox` for main model
select and control adapter model selects (less jank, will switch to
ComboBox based on CustomSelect for v4 so you can search the select)

## QA Instructions, Screenshots, Recordings

If you hold your breath, you should be able to generate with a control
adapter.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved. Frontend tests not passing.

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
This commit is contained in:
Brandon 2024-02-16 11:17:35 -05:00 committed by GitHub
commit bc524026f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
70 changed files with 591 additions and 650 deletions

View File

@ -10,13 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(
modelChanged({
model_name: 'test_model',
base_model: 'sd-1',
model_type: 'main',
})
);
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
}, []);
return props.children;

View File

@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = () => {
let graph;
if (model && model.base_model === 'sdxl') {
if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = buildLinearSDXLTextToImageGraph(state);
} else {

View File

@ -30,8 +30,8 @@ export const addModelSelectedListener = () => {
const newModel = result.data;
const newBaseModel = newModel.base_model;
const didBaseModelChange = state.generation.model?.base_model !== newBaseModel;
const newBaseModel = newModel.base;
const didBaseModelChange = state.generation.model?.base !== newBaseModel;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@ -39,7 +39,7 @@ export const addModelSelectedListener = () => {
// handle incompatible loras
forEach(state.lora.loras, (lora, id) => {
if (lora.base_model !== newBaseModel) {
if (lora.base !== newBaseModel) {
dispatch(loraRemoved(id));
modelsCleared += 1;
}
@ -47,14 +47,14 @@ export const addModelSelectedListener = () => {
// handle incompatible vae
const { vae } = state.generation;
if (vae && vae.base_model !== newBaseModel) {
if (vae && vae.base !== newBaseModel) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base_model !== newBaseModel) {
if (ca.model?.base !== newBaseModel) {
dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
modelsCleared += 1;
}

View File

@ -34,14 +34,7 @@ export const addModelsLoadedListener = () => {
return;
}
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentModelAvailable) {
return;
@ -74,14 +67,7 @@ export const addModelsLoadedListener = () => {
return;
}
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (!isCurrentModelAvailable) {
dispatch(refinerModelChanged(null));
@ -103,10 +89,7 @@ export const addModelsLoadedListener = () => {
return;
}
const isCurrentVAEAvailable = some(
action.payload.entities,
(m) => m?.model_name === currentVae?.model_name && m?.base_model === currentVae?.base_model
);
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
if (isCurrentVAEAvailable) {
return;
@ -140,10 +123,7 @@ export const addModelsLoadedListener = () => {
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(
action.payload.entities,
(m) => m?.model_name === lora?.model_name && m?.base_model === lora?.base_model
);
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.key);
if (isLoRAAvailable) {
return;
@ -161,10 +141,7 @@ export const addModelsLoadedListener = () => {
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
);
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
@ -182,10 +159,7 @@ export const addModelsLoadedListener = () => {
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
);
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
@ -203,10 +177,7 @@ export const addModelsLoadedListener = () => {
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
);
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;

View File

@ -5,10 +5,10 @@ import type { GroupBase } from 'chakra-react-select';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
type UseGroupedModelComboboxArg<T extends AnyModelConfigEntity> = {
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
onChange: (value: T | null) => void;
@ -24,7 +24,7 @@ type UseGroupedModelComboboxReturn = {
noOptionsMessage: () => string;
};
export const useGroupedModelCombobox = <T extends AnyModelConfigEntity>(
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();

View File

@ -105,7 +105,7 @@ const selector = createMemoizedSelector(
number: i + 1,
})
);
} else if (ca.model.base_model !== model?.base_model) {
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push(
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {

View File

@ -3,10 +3,10 @@ import type { EntityState } from '@reduxjs/toolkit';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
type UseModelComboboxArg<T extends AnyModelConfigEntity> = {
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
onChange: (value: T | null) => void;
@ -23,9 +23,7 @@ type UseModelComboboxReturn = {
noOptionsMessage: () => string;
};
export const useModelCombobox = <T extends AnyModelConfigEntity>(
arg: UseModelComboboxArg<T>
): UseModelComboboxReturn => {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation();
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const options = useMemo<ComboboxOption[]>(() => {

View File

@ -0,0 +1,88 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/util';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
isLoading: boolean;
selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void;
modelFilter?: (model: T) => boolean;
isModelDisabled?: (model: T) => boolean;
};
type UseModelCustomSelectReturn = {
selectedItem: Item | null;
items: Item[];
onChange: (item: Item | null) => void;
placeholder: string;
};
const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false;
export const useModelCustomSelect = <T extends AnyModelConfig>({
data,
isLoading,
selectedModel,
onChange,
modelFilter = modelFilterDefault,
isModelDisabled = isModelDisabledDefault,
}: UseModelCustomSelectArg<T>): UseModelCustomSelectReturn => {
const { t } = useTranslation();
const items: Item[] = useMemo(
() =>
data
? filter(data.entities, modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
}))
: EMPTY_ARRAY,
[data, isModelDisabled, modelFilter]
);
const _onChange = useCallback(
(item: Item | null) => {
if (!item || !data) {
return;
}
const model = data.entities[item.value];
if (!model) {
return;
}
onChange(model);
},
[data, onChange]
);
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
const placeholder = useMemo(() => {
if (isLoading) {
return t('common.loading');
}
if (items.length === 0) {
return t('models.noModelsAvailable');
}
return t('models.selectModel');
}, [isLoading, items, t]);
return {
items,
onChange: _onChange,
selectedItem,
placeholder,
};
};

View File

@ -626,7 +626,7 @@ export const canvasSlice = createSlice({
},
extraReducers: (builder) => {
builder.addCase(modelChanged, (state, action) => {
if (action.meta.previousModel?.base_model === action.payload?.base_model) {
if (action.meta.previousModel?.base === action.payload?.base) {
// The base model hasn't changed, we don't need to optimize the size
return;
}

View File

@ -1,49 +1,37 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { CustomSelect, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities';
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type {
ControlNetModelConfigEntity,
IPAdapterModelConfigEntity,
T2IAdapterModelConfigEntity,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
type ParamControlAdapterModelProps = {
id: string;
};
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const models = useControlAdapterModelEntities(controlAdapterType);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const _onChange = useCallback(
(model: ControlNetModelConfigEntity | IPAdapterModelConfigEntity | T2IAdapterModelConfigEntity | null) => {
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
if (!model) {
return;
}
dispatch(
controlAdapterModelChanged({
id,
model: pick(model, 'base_model', 'model_name'),
model: pick(model, 'base', 'key'),
})
);
},
@ -55,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
[controlAdapterType, model]
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base_model;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: models,
onChange: _onChange,
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
isLoading,
selectedModel,
getIsDisabled,
onChange: _onChange,
modelFilter: (model) => model.base === currentBaseModel,
});
return (
<Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base_model !== model?.base_model}>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
</FormControl>
);
};

View File

@ -6,14 +6,14 @@ import { useCallback, useMemo } from 'react';
import { useControlAdapterModels } from './useControlAdapterModels';
export const useAddControlAdapter = (type: ControlAdapterType) => {
const baseModel = useAppSelector((s) => s.generation.model?.base_model);
const baseModel = useAppSelector((s) => s.generation.model?.base);
const dispatch = useAppDispatch();
const models = useControlAdapterModels(type);
const firstModel = useMemo(() => {
// prefer to use a model that matches the base model
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base_model === baseModel : true))[0];
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
if (firstCompatibleModel) {
return firstCompatibleModel;

View File

@ -1,23 +0,0 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelEntities = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsData;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsData;
}
if (type === 'ip_adapter') {
return ipAdapterModelsData;
}
return;
};

View File

@ -0,0 +1,26 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
const controlNetModelsQuery = useGetControlNetModelsQuery();
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsQuery;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsQuery;
}
if (type === 'ip_adapter') {
return ipAdapterModelsQuery;
}
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};

View File

@ -5,14 +5,16 @@ import {
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useControlAdapterType = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(
selectControlAdaptersSlice,
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type
),
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const type = selectControlAdapterById(controlAdapters, id)?.type;
assert(type !== undefined, `Control adapter with id ${id} not found`);
return type;
}),
[id]
);

View File

@ -236,7 +236,8 @@ export const controlAdaptersSlice = createSlice({
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (model.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
@ -359,7 +360,8 @@ export const controlAdaptersSlice = createSlice({
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (cn.model?.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}

View File

@ -6,18 +6,18 @@ import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { TextualInversionModelConfigEntity } from 'services/api/endpoints/models';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfigEntity): boolean => {
(embedding: TextualInversionConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base_model;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionModelConfigEntity | null) => {
(embedding: TextualInversionConfig | null) => {
if (!embedding) {
return;
}

View File

@ -208,8 +208,8 @@ const ImageMetadataActions = (props: Props) => {
{metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
)}
{metadata.model !== undefined && metadata.model !== null && metadata.model.model_name && (
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.model_name} onClick={handleRecallModel} />
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.key} onClick={handleRecallModel} />
)}
{metadata.width && (
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
@ -222,7 +222,7 @@ const ImageMetadataActions = (props: Props) => {
)}
<ImageMetadataItem
label={t('metadata.vae')}
value={metadata.vae?.model_name ?? 'Default'}
value={metadata.vae?.key ?? 'Default'}
onClick={handleRecallVaeModel}
/>
{metadata.steps && (
@ -269,7 +269,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="LoRA"
value={`${lora.lora.model_name} - ${lora.weight}`}
value={`${lora.lora.key} - ${lora.weight}`}
onClick={handleRecallLoRA.bind(null, lora)}
/>
);
@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
value={`${controlnet.control_model?.key} - ${controlnet.control_weight}`}
onClick={handleRecallControlNet.bind(null, controlnet)}
/>
))}
@ -287,7 +287,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
value={`${ipAdapter.ip_adapter_model?.key} - ${ipAdapter.weight}`}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
/>
))}
@ -295,7 +295,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
value={`${t2iAdapter.t2i_adapter_model?.key} - ${t2iAdapter.weight}`}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
/>
))}

View File

@ -43,7 +43,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
<CardHeader>
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{lora.model_name}
{lora.key}
</Text>
<Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />

View File

@ -18,7 +18,7 @@ export const LoRAList = memo(() => {
return (
<Flex flexWrap="wrap" gap={2}>
{lorasArray.map((lora) => (
<LoRACard key={lora.model_name} lora={lora} />
<LoRACard key={lora.key} lora={lora} />
))}
</Flex>
);

View File

@ -6,7 +6,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
@ -18,7 +18,7 @@ const LoRASelect = () => {
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const getIsDisabled = (lora: LoRAModelConfigEntity): boolean => {
const getIsDisabled = (lora: LoRAConfig): boolean => {
const isCompatible = currentBaseModel === lora.base_model;
const isAdded = Boolean(addedLoRAs[lora.id]);
const hasMainModel = Boolean(currentBaseModel);
@ -26,7 +26,7 @@ const LoRASelect = () => {
};
const _onChange = useCallback(
(lora: LoRAModelConfigEntity | null) => {
(lora: LoRAConfig | null) => {
if (!lora) {
return;
}

View File

@ -2,10 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
export type LoRA = ParameterLoRAModel & {
id: string;
weight: number;
isEnabled?: boolean;
};
@ -29,40 +28,40 @@ export const loraSlice = createSlice({
name: 'lora',
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
loraAdded: (state, action: PayloadAction<LoRAConfig>) => {
const { key, base } = action.payload;
state.loras[key] = { key, base, ...defaultLoRAConfig };
},
loraRecalled: (state, action: PayloadAction<LoRAModelConfigEntity & { weight: number }>) => {
const { model_name, id, base_model, weight } = action.payload;
state.loras[id] = { id, model_name, base_model, weight, isEnabled: true };
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => {
const { key, base, weight } = action.payload;
state.loras[key] = { key, base, weight, isEnabled: true };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;
delete state.loras[id];
const key = action.payload;
delete state.loras[key];
},
lorasCleared: (state) => {
state.loras = {};
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = state.loras[id];
loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => {
const { key, weight } = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}
lora.weight = weight;
},
loraWeightReset: (state, action: PayloadAction<string>) => {
const id = action.payload;
const lora = state.loras[id];
const key = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}
lora.weight = defaultLoRAConfig.weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
const { id, isEnabled } = action.payload;
const lora = state.loras[id];
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'key' | 'isEnabled'>>) => {
const { key, isEnabled } = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}

View File

@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type {
DiffusersModelConfigEntity,
LoRAModelConfigEntity,
MainModelConfigEntity,
} from 'services/api/endpoints/models';
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
@ -38,7 +34,7 @@ const ModelManagerPanel = () => {
};
type ModelEditProps = {
model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
model: MainModelConfig | LoRAConfig | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
@ -50,7 +46,7 @@ const ModelEdit = (props: ModelEditProps) => {
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfigEntity} />;
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfig} />;
}
if (model?.model_type === 'lora') {

View File

@ -21,14 +21,14 @@ import { memo, useCallback, useEffect, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/endpoints/models';
import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
type CheckpointModelEditProps = {
model: CheckpointModelConfigEntity;
model: CheckpointModelConfig;
};
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {

View File

@ -9,12 +9,12 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/endpoints/models';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = {
model: DiffusersModelConfigEntity;
model: DiffusersModelConfig;
};
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {

View File

@ -8,12 +8,12 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';
import type { LoRAConfig } from 'services/api/types';
type LoRAModelEditProps = {
model: LoRAModelConfigEntity;
model: LoRAConfig;
};
const LoRAModelEdit = (props: LoRAModelEditProps) => {
@ -30,7 +30,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
control,
formState: { errors },
reset,
} = useForm<LoRAModelConfig>({
} = useForm<LoRAConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
@ -42,7 +42,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
@ -53,7 +53,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
reset(payload as LoRAConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
@ -106,7 +106,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
<BaseModelSelect<LoRAConfig> control={control} name="base_model" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -5,7 +5,7 @@ import type { ChangeEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
@ -127,7 +127,7 @@ const ModelList = (props: ModelListProps) => {
export default memo(ModelList);
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
data: EntityState<T, string> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
@ -163,7 +163,7 @@ StyledModelContainer.displayName = 'StyledModelContainer';
type ModelListWrapperProps = {
title: string;
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[];
modelList: MainModelConfig[] | LoRAConfig[];
selected: ModelListProps;
};

View File

@ -15,11 +15,11 @@ import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models';
type ModelListItemProps = {
model: MainModelConfigEntity | LoRAModelConfigEntity;
model: MainModelConfig | LoRAConfig;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { ControlNetModelConfigEntity } from 'services/api/endpoints/models';
import type { ControlNetConfig } from 'services/api/endpoints/models';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const { data, isLoading } = useGetControlNetModelsQuery();
const _onChange = useCallback(
(value: ControlNetModelConfigEntity | null) => {
(value: ControlNetConfig | null) => {
if (!value) {
return;
}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { IPAdapterModelConfigEntity } from 'services/api/endpoints/models';
import type { IPAdapterConfig } from 'services/api/endpoints/models';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const IPAdapterModelFieldInputComponent = (
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const _onChange = useCallback(
(value: IPAdapterModelConfigEntity | null) => {
(value: IPAdapterConfig | null) => {
if (!value) {
return;
}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -16,7 +16,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const _onChange = useCallback(
(value: LoRAModelConfigEntity | null) => {
(value: LoRAConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const MainModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
const _onChange = useCallback(
(value: MainModelConfigEntity | null) => {
(value: MainModelConfig | null) => {
if (!value) {
return;
}

View File

@ -9,7 +9,7 @@ import type {
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -21,7 +21,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const _onChange = useCallback(
(value: MainModelConfigEntity | null) => {
(value: MainModelConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { SDXL_MAIN_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
const _onChange = useCallback(
(value: MainModelConfigEntity | null) => {
(value: MainModelConfig | null) => {
if (!value) {
return;
}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { T2IAdapterModelConfigEntity } from 'services/api/endpoints/models';
import type { T2IAdapterConfig } from 'services/api/endpoints/models';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const T2IAdapterModelFieldInputComponent = (
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const _onChange = useCallback(
(value: T2IAdapterModelConfigEntity | null) => {
(value: T2IAdapterConfig | null) => {
if (!value) {
return;
}

View File

@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { VaeModelConfigEntity } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/endpoints/models';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery();
const _onChange = useCallback(
(value: VaeModelConfigEntity | null) => {
(value: VAEConfig | null) => {
if (!value) {
return;
}

View File

@ -67,11 +67,13 @@ export const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
key: z.string().min(1),
});
export const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>;
export type ModelType = z.infer<typeof zModelType>;
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
export const zMainModelField = zModelIdentifier;
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
export const zMainModelField = zModelFieldBase;
export type MainModelField = z.infer<typeof zMainModelField>;
export const zSDXLRefinerModelField = zModelIdentifier;
@ -91,23 +93,23 @@ export const zSubModelType = z.enum([
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zVAEModelField = zModelIdentifier;
export const zVAEModelField = zModelFieldBase;
export const zModelInfo = zModelIdentifier.extend({
submodel_type: zSubModelType.nullish(),
});
export type ModelInfo = z.infer<typeof zModelInfo>;
export const zLoRAModelField = zModelIdentifier;
export const zLoRAModelField = zModelFieldBase;
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
export const zControlNetModelField = zModelIdentifier;
export const zControlNetModelField = zModelFieldBase;
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
export const zIPAdapterModelField = zModelIdentifier;
export const zIPAdapterModelField = zModelFieldBase;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
export const zT2IAdapterModelField = zModelIdentifier;
export const zT2IAdapterModelField = zModelFieldBase;
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
export const zLoraInfo = zModelInfo.extend({

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
if (validIPAdapters.length) {

View File

@ -28,6 +28,7 @@ export const addLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
// TODO(MM2): check base model
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -48,19 +49,19 @@ export const addLoRAsToGraph = (
const loraMetadata: CoreMetadataInvocation['loras'] = [];
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const { key, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
lora: { key },
weight,
};
loraMetadata.push({
lora: { model_name, base_model },
lora: { key },
weight,
});

View File

@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
// TODO(MM2): check base model
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = (
let currentLoraIndex = 0;
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const { key, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
lora: { key },
weight,
};
loraMetadata.push(
zLoRAMetadataItem.parse({
lora: { model_name, base_model },
lora: { key },
weight,
})
);

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
if (validT2IAdapters.length) {

View File

@ -19,7 +19,7 @@ export const buildCanvasGraph = (
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLTextToImageGraph(state);
} else {
graph = buildCanvasTextToImageGraph(state);
@ -28,7 +28,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
@ -37,7 +37,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
@ -46,7 +46,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);

View File

@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
});
}
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,

View File

@ -29,17 +29,17 @@ const ParamClipSkip = () => {
if (!model) {
return CLIP_SKIP_MAP['sd-1'].maxClip;
}
return CLIP_SKIP_MAP[model.base_model].maxClip;
return CLIP_SKIP_MAP[model.base].maxClip;
}, [model]);
const sliderMarks = useMemo(() => {
if (!model) {
return CLIP_SKIP_MAP['sd-1'].markers;
}
return CLIP_SKIP_MAP[model.base_model].markers;
return CLIP_SKIP_MAP[model.base].markers;
}, [model]);
if (model?.base_model === 'sdxl') {
if (model?.base === 'sdxl') {
return null;
}

View File

@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next';
export const ParamPositivePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector((s) => s.generation.positivePrompt);
const baseModel = useAppSelector((s) => s.generation.model)?.base_model;
const baseModel = useAppSelector((s) => s.generation.model)?.base;
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();

View File

@ -1,58 +1,45 @@
import { Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const model = useAppSelector(selectModel);
const selectedModel = useAppSelector(selectModel);
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
const tooltipLabel = useMemo(() => {
if (!data || !model) {
return;
}
return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description;
}, [data, model]);
const _onChange = useCallback(
(model: MainModelConfigEntity | null) => {
(model: MainModelConfig | null) => {
if (!model) {
return;
}
dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type'])));
dispatch(modelSelected({ key: model.key, base: model.base }));
},
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: model,
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
isLoading,
selectedModel,
onChange: _onChange,
});
return (
<Tooltip label={tooltipLabel}>
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
<FormLabel>{t('modelManager.model')}</FormLabel>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
<FormControl isDisabled={!items.length} isInvalid={!selectedItem || !items.length}>
<FormLabel>{t('modelManager.model')}</FormLabel>
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
</FormControl>
);
};

View File

@ -7,7 +7,7 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { VaeModelConfigEntity } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/endpoints/models';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery();
const getIsDisabled = useCallback(
(vae: VaeModelConfigEntity): boolean => {
(vae: VAEConfig): boolean => {
const isCompatible = model?.base_model === vae.base_model;
const hasMainModel = Boolean(model?.base_model);
return !hasMainModel || !isCompatible;
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
[model?.base_model]
);
const _onChange = useCallback(
(vae: VaeModelConfigEntity | null) => {
(vae: VAEConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null));
},
[dispatch]

View File

@ -464,17 +464,15 @@ export const useRecallParameters = () => {
return { lora: null, error: 'Invalid LoRA model' };
}
const { base_model, model_name } = loraMetadataItem.lora;
const { lora } = loraMetadataItem;
const matchingLoRA = loraModels
? loraModelsAdapterSelectors.selectById(loraModels, `${base_model}/lora/${model_name}`)
: undefined;
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
if (!matchingLoRA) {
return { lora: null, error: 'LoRA model is not installed' };
}
const isCompatibleBaseModel = matchingLoRA?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
@ -520,17 +518,14 @@ export const useRecallParameters = () => {
controlnetMetadataItem;
const matchingControlNetModel = controlNetModels
? controlNetModelsAdapterSelectors.selectById(
controlNetModels,
`${control_model.base_model}/controlnet/${control_model.model_name}`
)
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
: undefined;
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingControlNetModel?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
@ -597,17 +592,14 @@ export const useRecallParameters = () => {
t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapterSelectors.selectById(
t2iAdapterModels,
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
)
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
: undefined;
if (!matchingT2IAdapterModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
@ -672,17 +664,14 @@ export const useRecallParameters = () => {
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapterSelectors.selectById(
ipAdapterModels,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
)
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
: undefined;
if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
}
const isCompatibleBaseModel = matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {

View File

@ -1,6 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { ImageDTO, MainModelField } from 'services/api/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { ImageDTO } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
export const modelSelected = createAction<MainModelField>('generation/modelSelected');
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');

View File

@ -158,15 +158,15 @@ export const generationSlice = createSlice({
// Clamp ClipSkip Based On Selected Model
// TODO(psyche): remove this special handling when https://github.com/invoke-ai/InvokeAI/issues/4583 is resolved
// WIP PR here: https://github.com/invoke-ai/InvokeAI/pull/4624
if (newModel.base_model === 'sdxl') {
if (newModel.base === 'sdxl') {
// We don't support clip skip for SDXL yet - it's not in the graphs
state.clipSkip = 0;
} else {
const { maxClip } = CLIP_SKIP_MAP[newModel.base_model];
const { maxClip } = CLIP_SKIP_MAP[newModel.base];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
}
if (action.meta.previousModel?.base_model === newModel.base_model) {
if (action.meta.previousModel?.base === newModel.base) {
// The base model hasn't changed, we don't need to optimize the size
return;
}

View File

@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = {
*/
export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1',
'sd-2': 'SD2',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};

View File

@ -1,5 +1,6 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import {
zBaseModel,
zControlNetModelField,
zIPAdapterModelField,
zLoRAModelField,
@ -104,48 +105,48 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
// #endregion
// #region Model
export const zParameterModel = zMainModelField;
export const zParameterModel = zMainModelField.extend({ base: zBaseModel });
export type ParameterModel = z.infer<typeof zParameterModel>;
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
// #endregion
// #region SDXL Refiner Model
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField;
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel });
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel =>
zParameterSDXLRefinerModel.safeParse(val).success;
// #endregion
// #region VAE Model
export const zParameterVAEModel = zVAEModelField;
export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel });
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel =>
zParameterVAEModel.safeParse(val).success;
// #endregion
// #region LoRA Model
export const zParameterLoRAModel = zLoRAModelField;
export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel });
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel =>
zParameterLoRAModel.safeParse(val).success;
// #endregion
// #region ControlNet Model
export const zParameterControlNetModel = zControlNetModelField;
export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel });
export type ParameterControlNetModel = z.infer<typeof zParameterLoRAModel>;
export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel =>
zParameterControlNetModel.safeParse(val).success;
// #endregion
// #region IP Adapter Model
export const zParameterIPAdapterModel = zIPAdapterModelField;
export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel });
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel =>
zParameterIPAdapterModel.safeParse(val).success;
// #endregion
// #region T2I Adapter Model
export const zParameterT2IAdapterModel = zT2IAdapterModelField;
export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel });
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel =>
zParameterT2IAdapterModel.safeParse(val).success;

View File

@ -1,12 +1,12 @@
import type { ModelIdentifier } from 'features/nodes/types/common';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
/**
* Gets the optimal dimension for a givel model, based on the model's base_model
* @param model The model identifier
* @returns The optimal dimension for the model
*/
export const getOptimalDimension = (model?: ModelIdentifier | null): number =>
model?.base_model === 'sdxl' ? 1024 : 512;
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number =>
model?.base === 'sdxl' ? 1024 : 512;
const MIN_AREA_FACTOR = 0.8;
const MAX_AREA_FACTOR = 1.2;

View File

@ -6,12 +6,12 @@ import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSl
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
const optionsFilter = (model: MainModelConfigEntity) => model.base_model === 'sdxl-refiner';
const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner';
const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
@ -19,7 +19,7 @@ const ParamSDXLRefinerModelSelect = () => {
const { t } = useTranslation();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const _onChange = useCallback(
(model: MainModelConfigEntity | null) => {
(model: MainModelConfig | null) => {
if (!model) {
dispatch(refinerModelChanged(null));
return;

View File

@ -24,7 +24,8 @@ const formLabelProps2: FormLabelProps = {
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
const badges: (string | number)[] = [];
if (generation.vae) {
let vaeBadge = generation.vae.model_name;
// TODO(MM2): Fetch the vae name
let vaeBadge = generation.vae.key;
if (generation.vaePrecision === 'fp16') {
vaeBadge += ` ${generation.vaePrecision}`;
}

View File

@ -35,9 +35,10 @@ const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationS
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = [];
// TODO(MM2): fetch model name
if (generation.model) {
accordionBadges.push(generation.model.model_name);
accordionBadges.push(generation.model.base_model);
accordionBadges.push(generation.model.key);
accordionBadges.push(generation.model.base);
}
return { loraTabBadges, accordionBadges };

View File

@ -56,7 +56,7 @@ const selector = createMemoizedSelector(
if (hrfEnabled) {
badges.push('HiRes Fix');
}
return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' };
return { badges, activeTabName, isSDXL: model?.base === 'sdxl' };
}
);

View File

@ -22,7 +22,7 @@ const overlayScrollbarsStyles: CSSProperties = {
const ParametersPanel = () => {
const activeTabName = useAppSelector(activeTabNameSelector);
const isSDXL = useAppSelector((s) => s.generation.model?.base_model === 'sdxl');
const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl');
return (
<Flex w="full" h="full" flexDir="column" gap={2}>

View File

@ -3,27 +3,35 @@ import type { OpenAPIV3_1 } from 'openapi-types';
import type { paths } from 'services/api/schema';
import type { AppConfig, AppDependencyVersions, AppVersion } from 'services/api/types';
import { api } from '..';
import { api, buildV1Url } from '..';
/**
* Builds an endpoint URL for the app router
* @example
* buildAppInfoUrl('some-path')
* // '/api/v1/app/some-path'
*/
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
export const appInfoApi = api.injectEndpoints({
endpoints: (build) => ({
getAppVersion: build.query<AppVersion, void>({
query: () => ({
url: `app/version`,
url: buildAppInfoUrl('version'),
method: 'GET',
}),
providesTags: ['FetchOnReconnect'],
}),
getAppDeps: build.query<AppDependencyVersions, void>({
query: () => ({
url: `app/app_deps`,
url: buildAppInfoUrl('app_deps'),
method: 'GET',
}),
providesTags: ['FetchOnReconnect'],
}),
getAppConfig: build.query<AppConfig, void>({
query: () => ({
url: `app/config`,
url: buildAppInfoUrl('config'),
method: 'GET',
}),
providesTags: ['FetchOnReconnect'],
@ -33,28 +41,28 @@ export const appInfoApi = api.injectEndpoints({
void
>({
query: () => ({
url: `app/invocation_cache/status`,
url: buildAppInfoUrl('invocation_cache/status'),
method: 'GET',
}),
providesTags: ['InvocationCacheStatus', 'FetchOnReconnect'],
}),
clearInvocationCache: build.mutation<void, void>({
query: () => ({
url: `app/invocation_cache`,
url: buildAppInfoUrl('invocation_cache'),
method: 'DELETE',
}),
invalidatesTags: ['InvocationCacheStatus'],
}),
enableInvocationCache: build.mutation<void, void>({
query: () => ({
url: `app/invocation_cache/enable`,
url: buildAppInfoUrl('invocation_cache/enable'),
method: 'PUT',
}),
invalidatesTags: ['InvocationCacheStatus'],
}),
disableInvocationCache: build.mutation<void, void>({
query: () => ({
url: `app/invocation_cache/disable`,
url: buildAppInfoUrl('invocation_cache/disable'),
method: 'PUT',
}),
invalidatesTags: ['InvocationCacheStatus'],

View File

@ -9,7 +9,15 @@ import type {
import { getListImagesUrl } from 'services/api/util';
import type { ApiTagDescription } from '..';
import { api, LIST_TAG } from '..';
import { api, buildV1Url, LIST_TAG } from '..';
/**
* Builds an endpoint URL for the boards router
* @example
* buildBoardsUrl('some-path')
* // '/api/v1/boards/some-path'
*/
export const buildBoardsUrl = (path: string = '') => buildV1Url(`boards/${path}`);
export const boardsApi = api.injectEndpoints({
endpoints: (build) => ({
@ -17,7 +25,7 @@ export const boardsApi = api.injectEndpoints({
* Boards Queries
*/
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
query: (arg) => ({ url: 'boards/', params: arg }),
query: (arg) => ({ url: buildBoardsUrl(), params: arg }),
providesTags: (result) => {
// any list of boards
const tags: ApiTagDescription[] = [{ type: 'Board', id: LIST_TAG }, 'FetchOnReconnect'];
@ -38,7 +46,7 @@ export const boardsApi = api.injectEndpoints({
listAllBoards: build.query<Array<BoardDTO>, void>({
query: () => ({
url: 'boards/',
url: buildBoardsUrl(),
params: { all: true },
}),
providesTags: (result) => {
@ -61,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
listAllImageNamesForBoard: build.query<Array<string>, string>({
query: (board_id) => ({
url: `boards/${board_id}/image_names`,
url: buildBoardsUrl(`${board_id}/image_names`),
}),
providesTags: (result, error, arg) => [{ type: 'ImageNameList', id: arg }, 'FetchOnReconnect'],
keepUnusedDataFor: 0,
@ -107,7 +115,7 @@ export const boardsApi = api.injectEndpoints({
createBoard: build.mutation<BoardDTO, string>({
query: (board_name) => ({
url: `boards/`,
url: buildBoardsUrl(),
method: 'POST',
params: { board_name },
}),
@ -116,7 +124,7 @@ export const boardsApi = api.injectEndpoints({
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
query: ({ board_id, changes }) => ({
url: `boards/${board_id}`,
url: buildBoardsUrl(board_id),
method: 'PATCH',
body: changes,
}),

View File

@ -26,8 +26,24 @@ import {
} from 'services/api/util';
import type { ApiTagDescription } from '..';
import { api, LIST_TAG } from '..';
import { boardsApi } from './boards';
import { api, buildV1Url, LIST_TAG } from '..';
import { boardsApi, buildBoardsUrl } from './boards';
/**
* Builds an endpoint URL for the images router
* @example
* buildImagesUrl('some-path')
* // '/api/v1/images/some-path'
*/
const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`);
/**
* Builds an endpoint URL for the board_images router
* @example
* buildBoardImagesUrl('some-path')
* // '/api/v1/board_images/some-path'
*/
const buildBoardImagesUrl = (path: string = '') => buildV1Url(`board_images/${path}`);
export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({
@ -90,20 +106,20 @@ export const imagesApi = api.injectEndpoints({
keepUnusedDataFor: 86400,
}),
getIntermediatesCount: build.query<number, void>({
query: () => ({ url: 'images/intermediates' }),
query: () => ({ url: buildImagesUrl('intermediates') }),
providesTags: ['IntermediatesCount', 'FetchOnReconnect'],
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: `images/intermediates`, method: 'DELETE' }),
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
invalidatesTags: ['IntermediatesCount'],
}),
getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: `images/i/${image_name}` }),
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadata: build.query<CoreMetadata | undefined, string>({
query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }),
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
transformResponse: (
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
@ -130,7 +146,7 @@ export const imagesApi = api.injectEndpoints({
}),
deleteImage: build.mutation<void, ImageDTO>({
query: ({ image_name }) => ({
url: `images/i/${image_name}`,
url: buildImagesUrl(`i/${image_name}`),
method: 'DELETE',
}),
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
@ -185,7 +201,7 @@ export const imagesApi = api.injectEndpoints({
query: ({ imageDTOs }) => {
const image_names = imageDTOs.map((imageDTO) => imageDTO.image_name);
return {
url: `images/delete`,
url: buildImagesUrl('delete'),
method: 'POST',
body: {
image_names,
@ -258,7 +274,7 @@ export const imagesApi = api.injectEndpoints({
*/
changeImageIsIntermediate: build.mutation<ImageDTO, { imageDTO: ImageDTO; is_intermediate: boolean }>({
query: ({ imageDTO, is_intermediate }) => ({
url: `images/i/${imageDTO.image_name}`,
url: buildImagesUrl(`i/${imageDTO.image_name}`),
method: 'PATCH',
body: { is_intermediate },
}),
@ -380,7 +396,7 @@ export const imagesApi = api.injectEndpoints({
*/
changeImageSessionId: build.mutation<ImageDTO, { imageDTO: ImageDTO; session_id: string }>({
query: ({ imageDTO, session_id }) => ({
url: `images/i/${imageDTO.image_name}`,
url: buildImagesUrl(`i/${imageDTO.image_name}`),
method: 'PATCH',
body: { session_id },
}),
@ -417,7 +433,7 @@ export const imagesApi = api.injectEndpoints({
{ imageDTOs: ImageDTO[] }
>({
query: ({ imageDTOs: images }) => ({
url: `images/star`,
url: buildImagesUrl('star'),
method: 'POST',
body: { image_names: images.map((img) => img.image_name) },
}),
@ -511,7 +527,7 @@ export const imagesApi = api.injectEndpoints({
{ imageDTOs: ImageDTO[] }
>({
query: ({ imageDTOs: images }) => ({
url: `images/unstar`,
url: buildImagesUrl('unstar'),
method: 'POST',
body: { image_names: images.map((img) => img.image_name) },
}),
@ -611,7 +627,7 @@ export const imagesApi = api.injectEndpoints({
const formData = new FormData();
formData.append('file', file);
return {
url: `images/upload`,
url: buildImagesUrl('upload'),
method: 'POST',
body: formData,
params: {
@ -674,7 +690,7 @@ export const imagesApi = api.injectEndpoints({
}),
deleteBoard: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }),
invalidatesTags: () => [
{ type: 'Board', id: LIST_TAG },
// invalidate the 'No Board' cache
@ -764,7 +780,7 @@ export const imagesApi = api.injectEndpoints({
deleteBoardAndImages: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({
url: `boards/${board_id}`,
url: buildBoardsUrl(board_id),
method: 'DELETE',
params: { include_images: true },
}),
@ -840,7 +856,7 @@ export const imagesApi = api.injectEndpoints({
query: ({ board_id, imageDTO }) => {
const { image_name } = imageDTO;
return {
url: `board_images/`,
url: buildBoardImagesUrl(),
method: 'POST',
body: { board_id, image_name },
};
@ -961,7 +977,7 @@ export const imagesApi = api.injectEndpoints({
query: ({ imageDTO }) => {
const { image_name } = imageDTO;
return {
url: `board_images/`,
url: buildBoardImagesUrl(),
method: 'DELETE',
body: { image_name },
};
@ -1080,7 +1096,7 @@ export const imagesApi = api.injectEndpoints({
}
>({
query: ({ board_id, imageDTOs }) => ({
url: `board_images/batch`,
url: buildBoardImagesUrl('batch'),
method: 'POST',
body: {
image_names: imageDTOs.map((i) => i.image_name),
@ -1197,7 +1213,7 @@ export const imagesApi = api.injectEndpoints({
}
>({
query: ({ imageDTOs }) => ({
url: `board_images/batch/delete`,
url: buildBoardImagesUrl('batch/delete'),
method: 'POST',
body: {
image_names: imageDTOs.map((i) => i.image_name),
@ -1321,7 +1337,7 @@ export const imagesApi = api.injectEndpoints({
components['schemas']['Body_download_images_from_list']
>({
query: ({ image_names, board_id }) => ({
url: `images/download`,
url: buildImagesUrl('download'),
method: 'POST',
body: {
image_names,

View File

@ -1,63 +1,28 @@
import type { EntityState } from '@reduxjs/toolkit';
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { cloneDeep } from 'lodash-es';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
ControlNetConfig,
ImportModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
IPAdapterConfig,
LoRAConfig,
MainModelConfig,
MergeModelConfig,
ModelType,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
T2IAdapterConfig,
TextualInversionConfig,
VAEConfig,
} from 'services/api/types';
import type { ApiTagDescription } from '..';
import { api, LIST_TAG } from '..';
import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
id: string;
};
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
id: string;
};
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string;
};
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
export type AnyModelConfigEntity =
| MainModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| IPAdapterModelConfigEntity
| T2IAdapterModelConfigEntity
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
/* eslint-disable @typescript-eslint/no-explicit-any */
export const getModelId = (input: any): any => input;
type UpdateMainModelArg = {
base_model: BaseModelType;
@ -68,11 +33,13 @@ type UpdateMainModelArg = {
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
body: LoRAConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/models_v2/']['get']['parameters']['query']>;
type UpdateLoRAModelResponse = UpdateMainModelResponse;
@ -128,91 +95,95 @@ type CheckpointConfigsResponse =
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const getModelId = ({
base_model,
model_type,
model_name,
}: Pick<AnyModelConfig, 'base_model' | 'model_name' | 'model_type'>) => `${base_model}/${model_type}/${model_name}`;
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
const createModelEntities = <T extends AnyModelConfigEntity>(models: AnyModelConfig[]): T[] => {
const entityArray: T[] = [];
models.forEach((model) => {
const entity = {
...cloneDeep(model),
id: getModelId(model),
} as T;
entityArray.push(entity);
});
return entityArray;
};
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
/**
* Builds an endpoint URL for the models router
* @example
* buildModelsUrl('some-path')
* // '/api/v1/models/some-path'
*/
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity, string>, BaseModelType[]>({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params = {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'MainModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: MainModelConfig[] }) => {
const entities = createModelEntities<MainModelConfigEntity>(response.models);
return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities);
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}),
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
url: buildModelsUrl(`${base_model}/main/${model_name}`),
method: 'PATCH',
body: body,
};
@ -222,7 +193,7 @@ export const modelsApi = api.injectEndpoints({
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
query: ({ body }) => {
return {
url: `models/import`,
url: buildModelsUrl('import'),
method: 'POST',
body: body,
};
@ -232,7 +203,7 @@ export const modelsApi = api.injectEndpoints({
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
return {
url: `models/add`,
url: buildModelsUrl('add'),
method: 'POST',
body: body,
};
@ -242,7 +213,7 @@ export const modelsApi = api.injectEndpoints({
deleteMainModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ base_model, model_name, model_type }) => {
return {
url: `models/${base_model}/${model_type}/${model_name}`,
url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`),
method: 'DELETE',
};
},
@ -251,7 +222,7 @@ export const modelsApi = api.injectEndpoints({
convertMainModels: build.mutation<ConvertMainModelResponse, ConvertMainModelArg>({
query: ({ base_model, model_name, convert_dest_directory }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
url: buildModelsUrl(`convert/${base_model}/main/${model_name}`),
method: 'PUT',
params: { convert_dest_directory },
};
@ -261,7 +232,7 @@ export const modelsApi = api.injectEndpoints({
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
url: buildModelsUrl(`merge/${base_model}`),
method: 'PUT',
body: body,
};
@ -271,37 +242,21 @@ export const modelsApi = api.injectEndpoints({
syncModels: build.mutation<SyncModelsResponse, void>({
query: () => {
return {
url: `models/sync`,
url: buildModelsUrl('sync'),
method: 'POST',
};
},
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'LoRAModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: LoRAModelConfig[] }) => {
const entities = createModelEntities<LoRAModelConfigEntity>(response.models);
return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities);
},
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
}),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
method: 'PATCH',
body: body,
};
@ -311,129 +266,49 @@ export const modelsApi = api.injectEndpoints({
deleteLoRAModels: build.mutation<DeleteLoRAModelResponse, DeleteLoRAModelArg>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfigEntity, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'ControlNetModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(response.models);
return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities);
},
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
}),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfigEntity, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'IPAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
const entities = createModelEntities<IPAdapterModelConfigEntity>(response.models);
return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities);
},
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfigEntity, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'T2IAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
const entities = createModelEntities<T2IAdapterModelConfigEntity>(response.models);
return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities);
},
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'VaeModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: VaeModelConfig[] }) => {
const entities = createModelEntities<VaeModelConfigEntity>(response.models);
return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities);
},
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
}),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfigEntity, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'TextualInversionModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: TextualInversionModelConfig[] }) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(response.models);
return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities);
},
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {
const folderQueryStr = queryString.stringify(arg, {});
return {
url: `/models/search?${folderQueryStr}`,
url: buildModelsUrl(`search?${folderQueryStr}`),
};
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: `/models/ckpt_confs`,
url: buildModelsUrl(`ckpt_confs`),
};
},
}),

View File

@ -7,7 +7,15 @@ import queryString from 'query-string';
import type { components, paths } from 'services/api/schema';
import type { ApiTagDescription } from '..';
import { api } from '..';
import { api, buildV1Url } from '..';
/**
* Builds an endpoint URL for the queue router
* @example
* buildQueueUrl('some-path')
* // '/api/v1/queue/queue_id/some-path'
*/
const buildQueueUrl = (path: string = '') => buildV1Url(`queue/${$queueId.get()}/${path}`);
const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list']['get']['parameters']['query']) => {
const query = queryArgs
@ -17,10 +25,10 @@ const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list']
: undefined;
if (query) {
return `queue/${$queueId.get()}/list?${query}`;
return buildQueueUrl(`list?${query}`);
}
return `queue/${$queueId.get()}/list`;
return buildQueueUrl('list');
};
export type SessionQueueItemStatus = NonNullable<
@ -58,7 +66,7 @@ export const queueApi = api.injectEndpoints({
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json']
>({
query: (arg) => ({
url: `queue/${$queueId.get()}/enqueue_batch`,
url: buildQueueUrl('enqueue_batch'),
body: arg,
method: 'POST',
}),
@ -78,7 +86,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/processor/resume`,
url: buildQueueUrl('processor/resume'),
method: 'PUT',
}),
invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'],
@ -88,7 +96,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/processor/pause`,
url: buildQueueUrl('processor/pause'),
method: 'PUT',
}),
invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'],
@ -98,7 +106,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/prune`,
url: buildQueueUrl('prune'),
method: 'PUT',
}),
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
@ -117,7 +125,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/clear`,
url: buildQueueUrl('clear'),
method: 'PUT',
}),
invalidatesTags: [
@ -142,7 +150,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/current`,
url: buildQueueUrl('current'),
method: 'GET',
}),
providesTags: (result) => {
@ -158,7 +166,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/next`,
url: buildQueueUrl('next'),
method: 'GET',
}),
providesTags: (result) => {
@ -174,7 +182,7 @@ export const queueApi = api.injectEndpoints({
void
>({
query: () => ({
url: `queue/${$queueId.get()}/status`,
url: buildQueueUrl('status'),
method: 'GET',
}),
providesTags: ['SessionQueueStatus', 'FetchOnReconnect'],
@ -184,7 +192,7 @@ export const queueApi = api.injectEndpoints({
{ batch_id: string }
>({
query: ({ batch_id }) => ({
url: `queue/${$queueId.get()}/b/${batch_id}/status`,
url: buildQueueUrl(`/b/${batch_id}/status`),
method: 'GET',
}),
providesTags: (result) => {
@ -200,7 +208,7 @@ export const queueApi = api.injectEndpoints({
number
>({
query: (item_id) => ({
url: `queue/${$queueId.get()}/i/${item_id}`,
url: buildQueueUrl(`i/${item_id}`),
method: 'GET',
}),
providesTags: (result) => {
@ -216,7 +224,7 @@ export const queueApi = api.injectEndpoints({
number
>({
query: (item_id) => ({
url: `queue/${$queueId.get()}/i/${item_id}/cancel`,
url: buildQueueUrl(`i/${item_id}/cancel`),
method: 'PUT',
}),
onQueryStarted: async (item_id, { dispatch, queryFulfilled }) => {
@ -253,7 +261,7 @@ export const queueApi = api.injectEndpoints({
paths['/api/v1/queue/{queue_id}/cancel_by_batch_ids']['put']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: `queue/${$queueId.get()}/cancel_by_batch_ids`,
url: buildQueueUrl('cancel_by_batch_ids'),
method: 'PUT',
body,
}),
@ -279,7 +287,7 @@ export const queueApi = api.injectEndpoints({
method: 'GET',
}),
serializeQueryArgs: () => {
return `queue/${$queueId.get()}/list`;
return buildQueueUrl('list');
},
transformResponse: (response: components['schemas']['CursorPaginatedResults_SessionQueueItemDTO_']) =>
queueItemsAdapter.addMany(

View File

@ -1,6 +1,14 @@
import type { components } from 'services/api/schema';
import { api } from '..';
import { api, buildV1Url } from '..';
/**
* Builds an endpoint URL for the utilities router
* @example
* buildUtilitiesUrl('some-path')
* // '/api/v1/utilities/some-path'
*/
const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`);
export const utilitiesApi = api.injectEndpoints({
endpoints: (build) => ({
@ -9,7 +17,7 @@ export const utilitiesApi = api.injectEndpoints({
{ prompt: string; max_prompts: number }
>({
query: (arg) => ({
url: 'utilities/dynamicprompts',
url: buildUtilitiesUrl('dynamicprompts'),
body: arg,
method: 'POST',
}),

View File

@ -1,6 +1,14 @@
import type { paths } from 'services/api/schema';
import { api, LIST_TAG } from '..';
import { api, buildV1Url, LIST_TAG } from '..';
/**
* Builds an endpoint URL for the workflows router
* @example
* buildWorkflowsUrl('some-path')
* // '/api/v1/workflows/some-path'
*/
const buildWorkflowsUrl = (path: string = '') => buildV1Url(`workflows/${path}`);
export const workflowsApi = api.injectEndpoints({
endpoints: (build) => ({
@ -8,7 +16,7 @@ export const workflowsApi = api.injectEndpoints({
paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'],
string
>({
query: (workflow_id) => `workflows/i/${workflow_id}`,
query: (workflow_id) => buildWorkflowsUrl(`i/${workflow_id}`),
providesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }, 'FetchOnReconnect'],
onQueryStarted: async (arg, api) => {
const { dispatch, queryFulfilled } = api;
@ -22,7 +30,7 @@ export const workflowsApi = api.injectEndpoints({
}),
deleteWorkflow: build.mutation<void, string>({
query: (workflow_id) => ({
url: `workflows/i/${workflow_id}`,
url: buildWorkflowsUrl(`i/${workflow_id}`),
method: 'DELETE',
}),
invalidatesTags: (result, error, workflow_id) => [
@ -36,7 +44,7 @@ export const workflowsApi = api.injectEndpoints({
paths['/api/v1/workflows/']['post']['requestBody']['content']['application/json']['workflow']
>({
query: (workflow) => ({
url: 'workflows/',
url: buildWorkflowsUrl(),
method: 'POST',
body: { workflow },
}),
@ -50,7 +58,7 @@ export const workflowsApi = api.injectEndpoints({
paths['/api/v1/workflows/i/{workflow_id}']['patch']['requestBody']['content']['application/json']['workflow']
>({
query: (workflow) => ({
url: `workflows/i/${workflow.id}`,
url: buildWorkflowsUrl(`i/${workflow.id}`),
method: 'PATCH',
body: { workflow },
}),
@ -65,7 +73,7 @@ export const workflowsApi = api.injectEndpoints({
NonNullable<paths['/api/v1/workflows/']['get']['parameters']['query']>
>({
query: (params) => ({
url: 'workflows/',
url: buildWorkflowsUrl(),
params,
}),
providesTags: ['FetchOnReconnect', { type: 'Workflow', id: LIST_TAG }],

View File

@ -54,7 +54,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
const projectId = $projectId.get();
const fetchBaseQueryArgs: FetchBaseQueryArgs = {
baseUrl: baseUrl ? `${baseUrl}/api/v1` : `${window.location.href.replace(/\/$/, '')}/api/v1`,
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''),
prepareHeaders: (headers) => {
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
@ -108,3 +108,6 @@ function getCircularReplacer() {
return value;
};
}
export const buildV1Url = (path: string): string => `api/v1/${path}`;
export const buildV2Url = (path: string): string => `api/v2/${path}`;

File diff suppressed because one or more lines are too long

View File

@ -2,6 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { components, paths } from 'services/api/schema';
import type { O } from 'ts-toolbelt';
import type { SetRequired } from 'type-fest';
export type S = components['schemas'];
@ -54,40 +55,36 @@ export type LoRAModelFormat = S['LoRAModelFormat'];
export type ControlNetModelField = S['ControlNetModelField'];
export type IPAdapterModelField = S['IPAdapterModelField'];
export type T2IAdapterModelField = S['T2IAdapterModelField'];
export type ModelsList = S['invokeai__app__api__routers__models__ModelsList'];
export type ControlField = S['ControlField'];
export type IPAdapterField = S['IPAdapterField'];
// Model Configs
export type LoRAModelConfig = S['LoRAModelConfig'];
export type VaeModelConfig = S['VaeModelConfig'];
export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig;
export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig'];
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig'];
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
export type TextualInversionModelConfig = S['TextualInversionModelConfig'];
export type DiffusersModelConfig =
| S['StableDiffusion1ModelDiffusersConfig']
| S['StableDiffusion2ModelDiffusersConfig']
| S['StableDiffusionXLModelDiffusersConfig'];
export type CheckpointModelConfig =
| S['StableDiffusion1ModelCheckpointConfig']
| S['StableDiffusion2ModelCheckpointConfig']
| S['StableDiffusionXLModelCheckpointConfig'];
// TODO(MM2): Can we make key required in the pydantic model?
type KeyRequired<T extends { key?: string }> = SetRequired<T, 'key'>;
export type LoRAConfig = KeyRequired<S['LoRAConfig']>;
// TODO(MM2): Can we rename this from Vae -> VAE
export type VAEConfig = KeyRequired<S['VaeCheckpointConfig']> | KeyRequired<S['VaeDiffusersConfig']>;
export type ControlNetConfig =
| KeyRequired<S['ControlNetDiffusersConfig']>
| KeyRequired<S['ControlNetCheckpointConfig']>;
export type IPAdapterConfig = KeyRequired<S['IPAdapterConfig']>;
// TODO(MM2): Can we rename this to T2IAdapterConfig
export type T2IAdapterConfig = KeyRequired<S['T2IConfig']>;
export type TextualInversionConfig = KeyRequired<S['TextualInversionConfig']>;
export type DiffusersModelConfig = KeyRequired<S['MainDiffusersConfig']>;
export type CheckpointModelConfig = KeyRequired<S['MainCheckpointConfig']>;
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
| VaeModelConfig
| ControlNetModelConfig
| IPAdapterModelConfig
| T2IAdapterModelConfig
| TextualInversionModelConfig
| LoRAConfig
| VAEConfig
| ControlNetConfig
| IPAdapterConfig
| T2IAdapterConfig
| TextualInversionConfig
| MainModelConfig;
export type MergeModelConfig = S['Body_merge_models'];
export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model'];
// Graphs

View File

@ -3,6 +3,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { dateComparator } from 'common/util/dateComparator';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import queryString from 'query-string';
import { buildV1Url } from 'services/api';
import type { ImageCache, ImageDTO, ListImagesArgs } from './types';
@ -79,4 +80,4 @@ export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelector
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`;
buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`);

View File

@ -76,9 +76,9 @@ export default defineConfig(({ mode }) => {
changeOrigin: true,
},
// proxy nodes api
'/api/v1': {
target: 'http://127.0.0.1:9090/api/v1',
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
'/api/': {
target: 'http://127.0.0.1:9090/api/',
rewrite: (path) => path.replace(/^\/api/, ''),
changeOrigin: true,
},
},