mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): update model identifiers to use key (#5730)
## What type of PR is this? (check all applicable) - [x] Refactor ## Description - Update zod schemas & types to use key instead of name/base/type - Use new `CustomSelect` component instead of `ComboBox` for main model select and control adapter model selects (less jank, will switch to ComboBox based on CustomSelect for v4 so you can search the select) ## QA Instructions, Screenshots, Recordings If you hold your breath, you should be able to generate with a control adapter. <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> ## Merge Plan This PR can be merged when approved. Frontend tests not passing. <!-- A merge plan describes how this PR should be handled after it is approved. Example merge plans: - "This PR can be merged when approved" - "This must be squash-merged when approved" - "DO NOT MERGE - I will rebase and tidy commits before merging" - "#dev-chat on discord needs to be advised of this change when it is merged" A merge plan is particularly important for large PRs or PRs that touch the database in any way. -->
This commit is contained in:
commit
bc524026f9
@ -10,13 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
|
||||
const dispatch = useAppDispatch();
|
||||
useGlobalModifiersInit();
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
modelChanged({
|
||||
model_name: 'test_model',
|
||||
base_model: 'sd-1',
|
||||
model_type: 'main',
|
||||
})
|
||||
);
|
||||
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
|
||||
}, []);
|
||||
|
||||
return props.children;
|
||||
|
@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = () => {
|
||||
|
||||
let graph;
|
||||
|
||||
if (model && model.base_model === 'sdxl') {
|
||||
if (model && model.base === 'sdxl') {
|
||||
if (action.payload.tabName === 'txt2img') {
|
||||
graph = buildLinearSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
|
@ -30,8 +30,8 @@ export const addModelSelectedListener = () => {
|
||||
|
||||
const newModel = result.data;
|
||||
|
||||
const newBaseModel = newModel.base_model;
|
||||
const didBaseModelChange = state.generation.model?.base_model !== newBaseModel;
|
||||
const newBaseModel = newModel.base;
|
||||
const didBaseModelChange = state.generation.model?.base !== newBaseModel;
|
||||
|
||||
if (didBaseModelChange) {
|
||||
// we may need to reset some incompatible submodels
|
||||
@ -39,7 +39,7 @@ export const addModelSelectedListener = () => {
|
||||
|
||||
// handle incompatible loras
|
||||
forEach(state.lora.loras, (lora, id) => {
|
||||
if (lora.base_model !== newBaseModel) {
|
||||
if (lora.base !== newBaseModel) {
|
||||
dispatch(loraRemoved(id));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
@ -47,14 +47,14 @@ export const addModelSelectedListener = () => {
|
||||
|
||||
// handle incompatible vae
|
||||
const { vae } = state.generation;
|
||||
if (vae && vae.base_model !== newBaseModel) {
|
||||
if (vae && vae.base !== newBaseModel) {
|
||||
dispatch(vaeSelected(null));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
||||
// handle incompatible controlnets
|
||||
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
|
||||
if (ca.model?.base_model !== newBaseModel) {
|
||||
if (ca.model?.base !== newBaseModel) {
|
||||
dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
@ -34,14 +34,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentModelAvailable = currentModel
|
||||
? models.some(
|
||||
(m) =>
|
||||
m.model_name === currentModel.model_name &&
|
||||
m.base_model === currentModel.base_model &&
|
||||
m.model_type === currentModel.model_type
|
||||
)
|
||||
: false;
|
||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
@ -74,14 +67,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentModelAvailable = currentModel
|
||||
? models.some(
|
||||
(m) =>
|
||||
m.model_name === currentModel.model_name &&
|
||||
m.base_model === currentModel.base_model &&
|
||||
m.model_type === currentModel.model_type
|
||||
)
|
||||
: false;
|
||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||
|
||||
if (!isCurrentModelAvailable) {
|
||||
dispatch(refinerModelChanged(null));
|
||||
@ -103,10 +89,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentVAEAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) => m?.model_name === currentVae?.model_name && m?.base_model === currentVae?.base_model
|
||||
);
|
||||
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
|
||||
|
||||
if (isCurrentVAEAvailable) {
|
||||
return;
|
||||
@ -140,10 +123,7 @@ export const addModelsLoadedListener = () => {
|
||||
const loras = getState().lora.loras;
|
||||
|
||||
forEach(loras, (lora, id) => {
|
||||
const isLoRAAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) => m?.model_name === lora?.model_name && m?.base_model === lora?.base_model
|
||||
);
|
||||
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.key);
|
||||
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
@ -161,10 +141,7 @@ export const addModelsLoadedListener = () => {
|
||||
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
|
||||
);
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
@ -182,10 +159,7 @@ export const addModelsLoadedListener = () => {
|
||||
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
|
||||
);
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
@ -203,10 +177,7 @@ export const addModelsLoadedListener = () => {
|
||||
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
|
||||
);
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
|
@ -5,10 +5,10 @@ import type { GroupBase } from 'chakra-react-select';
|
||||
import { groupBy, map, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/endpoints/models';
|
||||
import { getModelId } from 'services/api/endpoints/models';
|
||||
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfigEntity> = {
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
||||
onChange: (value: T | null) => void;
|
||||
@ -24,7 +24,7 @@ type UseGroupedModelComboboxReturn = {
|
||||
noOptionsMessage: () => string;
|
||||
};
|
||||
|
||||
export const useGroupedModelCombobox = <T extends AnyModelConfigEntity>(
|
||||
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
arg: UseGroupedModelComboboxArg<T>
|
||||
): UseGroupedModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
|
@ -105,7 +105,7 @@ const selector = createMemoizedSelector(
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
} else if (ca.model.base_model !== model?.base_model) {
|
||||
} else if (ca.model.base !== model?.base) {
|
||||
// This should never happen, just a sanity check
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
|
||||
|
@ -3,10 +3,10 @@ import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/endpoints/models';
|
||||
import { getModelId } from 'services/api/endpoints/models';
|
||||
|
||||
type UseModelComboboxArg<T extends AnyModelConfigEntity> = {
|
||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
||||
onChange: (value: T | null) => void;
|
||||
@ -23,9 +23,7 @@ type UseModelComboboxReturn = {
|
||||
noOptionsMessage: () => string;
|
||||
};
|
||||
|
||||
export const useModelCombobox = <T extends AnyModelConfigEntity>(
|
||||
arg: UseModelComboboxArg<T>
|
||||
): UseModelComboboxReturn => {
|
||||
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||
const options = useMemo<ComboboxOption[]>(() => {
|
||||
|
@ -0,0 +1,88 @@
|
||||
import type { Item } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { filter } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||
data: EntityState<T, string> | undefined;
|
||||
isLoading: boolean;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
onChange: (value: T | null) => void;
|
||||
modelFilter?: (model: T) => boolean;
|
||||
isModelDisabled?: (model: T) => boolean;
|
||||
};
|
||||
|
||||
type UseModelCustomSelectReturn = {
|
||||
selectedItem: Item | null;
|
||||
items: Item[];
|
||||
onChange: (item: Item | null) => void;
|
||||
placeholder: string;
|
||||
};
|
||||
|
||||
const modelFilterDefault = () => true;
|
||||
const isModelDisabledDefault = () => false;
|
||||
|
||||
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||
data,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange,
|
||||
modelFilter = modelFilterDefault,
|
||||
isModelDisabled = isModelDisabledDefault,
|
||||
}: UseModelCustomSelectArg<T>): UseModelCustomSelectReturn => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const items: Item[] = useMemo(
|
||||
() =>
|
||||
data
|
||||
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
||||
label: m.name,
|
||||
value: m.key,
|
||||
description: m.description,
|
||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||
isDisabled: isModelDisabled(m),
|
||||
}))
|
||||
: EMPTY_ARRAY,
|
||||
[data, isModelDisabled, modelFilter]
|
||||
);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(item: Item | null) => {
|
||||
if (!item || !data) {
|
||||
return;
|
||||
}
|
||||
const model = data.entities[item.value];
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[data, onChange]
|
||||
);
|
||||
|
||||
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return t('common.loading');
|
||||
}
|
||||
|
||||
if (items.length === 0) {
|
||||
return t('models.noModelsAvailable');
|
||||
}
|
||||
|
||||
return t('models.selectModel');
|
||||
}, [isLoading, items, t]);
|
||||
|
||||
return {
|
||||
items,
|
||||
onChange: _onChange,
|
||||
selectedItem,
|
||||
placeholder,
|
||||
};
|
||||
};
|
@ -626,7 +626,7 @@ export const canvasSlice = createSlice({
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(modelChanged, (state, action) => {
|
||||
if (action.meta.previousModel?.base_model === action.payload?.base_model) {
|
||||
if (action.meta.previousModel?.base === action.payload?.base) {
|
||||
// The base model hasn't changed, we don't need to optimize the size
|
||||
return;
|
||||
}
|
||||
|
@ -1,49 +1,37 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { CustomSelect, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities';
|
||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type {
|
||||
ControlNetModelConfigEntity,
|
||||
IPAdapterModelConfigEntity,
|
||||
T2IAdapterModelConfigEntity,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
||||
|
||||
type ParamControlAdapterModelProps = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
|
||||
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const controlAdapterType = useControlAdapterType(id);
|
||||
const model = useControlAdapterModel(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
|
||||
const mainModel = useAppSelector(selectMainModel);
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const models = useControlAdapterModelEntities(controlAdapterType);
|
||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: ControlNetModelConfigEntity | IPAdapterModelConfigEntity | T2IAdapterModelConfigEntity | null) => {
|
||||
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
controlAdapterModelChanged({
|
||||
id,
|
||||
model: pick(model, 'base_model', 'model_name'),
|
||||
model: pick(model, 'base', 'key'),
|
||||
})
|
||||
);
|
||||
},
|
||||
@ -55,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
[controlAdapterType, model]
|
||||
);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base_model;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: models,
|
||||
onChange: _onChange,
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
data,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
modelFilter: (model) => model.base === currentBaseModel,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base_model !== model?.base_model}>
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('controlnet.selectModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
|
||||
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -6,14 +6,14 @@ import { useCallback, useMemo } from 'react';
|
||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
||||
|
||||
export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
const baseModel = useAppSelector((s) => s.generation.model?.base_model);
|
||||
const baseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const models = useControlAdapterModels(type);
|
||||
|
||||
const firstModel = useMemo(() => {
|
||||
// prefer to use a model that matches the base model
|
||||
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base_model === baseModel : true))[0];
|
||||
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
|
||||
|
||||
if (firstCompatibleModel) {
|
||||
return firstCompatibleModel;
|
||||
|
@ -1,23 +0,0 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
export const useControlAdapterModelEntities = (type?: ControlAdapterType) => {
|
||||
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
|
||||
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
|
||||
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
|
||||
|
||||
if (type === 'controlnet') {
|
||||
return controlNetModelsData;
|
||||
}
|
||||
if (type === 't2i_adapter') {
|
||||
return t2iAdapterModelsData;
|
||||
}
|
||||
if (type === 'ip_adapter') {
|
||||
return ipAdapterModelsData;
|
||||
}
|
||||
return;
|
||||
};
|
@ -0,0 +1,26 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
|
||||
const controlNetModelsQuery = useGetControlNetModelsQuery();
|
||||
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
|
||||
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
|
||||
|
||||
if (type === 'controlnet') {
|
||||
return controlNetModelsQuery;
|
||||
}
|
||||
if (type === 't2i_adapter') {
|
||||
return t2iAdapterModelsQuery;
|
||||
}
|
||||
if (type === 'ip_adapter') {
|
||||
return ipAdapterModelsQuery;
|
||||
}
|
||||
|
||||
// Assert that the end of the function is not reachable.
|
||||
const exhaustiveCheck: never = type;
|
||||
return exhaustiveCheck;
|
||||
};
|
@ -5,14 +5,16 @@ import {
|
||||
selectControlAdaptersSlice,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useControlAdapterType = (id: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(
|
||||
selectControlAdaptersSlice,
|
||||
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type
|
||||
),
|
||||
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||
const type = selectControlAdapterById(controlAdapters, id)?.type;
|
||||
assert(type !== undefined, `Control adapter with id ${id} not found`);
|
||||
return type;
|
||||
}),
|
||||
[id]
|
||||
);
|
||||
|
||||
|
@ -236,7 +236,8 @@ export const controlAdaptersSlice = createSlice({
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (model.model_name.includes(modelSubstring)) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (model.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
@ -359,7 +360,8 @@ export const controlAdaptersSlice = createSlice({
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (cn.model?.model_name.includes(modelSubstring)) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (cn.model?.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
|
@ -6,18 +6,18 @@ import type { EmbeddingSelectProps } from 'features/embedding/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { TextualInversionModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionModelConfigEntity): boolean => {
|
||||
(embedding: TextualInversionConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base_model;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionModelConfigEntity | null) => {
|
||||
(embedding: TextualInversionConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
|
@ -208,8 +208,8 @@ const ImageMetadataActions = (props: Props) => {
|
||||
{metadata.seed !== undefined && metadata.seed !== null && (
|
||||
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
|
||||
)}
|
||||
{metadata.model !== undefined && metadata.model !== null && metadata.model.model_name && (
|
||||
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.model_name} onClick={handleRecallModel} />
|
||||
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
|
||||
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.key} onClick={handleRecallModel} />
|
||||
)}
|
||||
{metadata.width && (
|
||||
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
|
||||
@ -222,7 +222,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
)}
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.vae')}
|
||||
value={metadata.vae?.model_name ?? 'Default'}
|
||||
value={metadata.vae?.key ?? 'Default'}
|
||||
onClick={handleRecallVaeModel}
|
||||
/>
|
||||
{metadata.steps && (
|
||||
@ -269,7 +269,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="LoRA"
|
||||
value={`${lora.lora.model_name} - ${lora.weight}`}
|
||||
value={`${lora.lora.key} - ${lora.weight}`}
|
||||
onClick={handleRecallLoRA.bind(null, lora)}
|
||||
/>
|
||||
);
|
||||
@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="ControlNet"
|
||||
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
|
||||
value={`${controlnet.control_model?.key} - ${controlnet.control_weight}`}
|
||||
onClick={handleRecallControlNet.bind(null, controlnet)}
|
||||
/>
|
||||
))}
|
||||
@ -287,7 +287,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="IP Adapter"
|
||||
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
|
||||
value={`${ipAdapter.ip_adapter_model?.key} - ${ipAdapter.weight}`}
|
||||
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
|
||||
/>
|
||||
))}
|
||||
@ -295,7 +295,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="T2I Adapter"
|
||||
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
|
||||
value={`${t2iAdapter.t2i_adapter_model?.key} - ${t2iAdapter.weight}`}
|
||||
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
|
||||
/>
|
||||
))}
|
||||
|
@ -43,7 +43,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
<CardHeader>
|
||||
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
||||
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
||||
{lora.model_name}
|
||||
{lora.key}
|
||||
</Text>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
||||
|
@ -18,7 +18,7 @@ export const LoRAList = memo(() => {
|
||||
return (
|
||||
<Flex flexWrap="wrap" gap={2}>
|
||||
{lorasArray.map((lora) => (
|
||||
<LoRACard key={lora.model_name} lora={lora} />
|
||||
<LoRACard key={lora.key} lora={lora} />
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -6,7 +6,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||
@ -18,7 +18,7 @@ const LoRASelect = () => {
|
||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
|
||||
|
||||
const getIsDisabled = (lora: LoRAModelConfigEntity): boolean => {
|
||||
const getIsDisabled = (lora: LoRAConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === lora.base_model;
|
||||
const isAdded = Boolean(addedLoRAs[lora.id]);
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
@ -26,7 +26,7 @@ const LoRASelect = () => {
|
||||
};
|
||||
|
||||
const _onChange = useCallback(
|
||||
(lora: LoRAModelConfigEntity | null) => {
|
||||
(lora: LoRAConfig | null) => {
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
|
@ -2,10 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
|
||||
export type LoRA = ParameterLoRAModel & {
|
||||
id: string;
|
||||
weight: number;
|
||||
isEnabled?: boolean;
|
||||
};
|
||||
@ -29,40 +28,40 @@ export const loraSlice = createSlice({
|
||||
name: 'lora',
|
||||
initialState: initialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||
const { model_name, id, base_model } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||
loraAdded: (state, action: PayloadAction<LoRAConfig>) => {
|
||||
const { key, base } = action.payload;
|
||||
state.loras[key] = { key, base, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<LoRAModelConfigEntity & { weight: number }>) => {
|
||||
const { model_name, id, base_model, weight } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, weight, isEnabled: true };
|
||||
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => {
|
||||
const { key, base, weight } = action.payload;
|
||||
state.loras[key] = { key, base, weight, isEnabled: true };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
delete state.loras[id];
|
||||
const key = action.payload;
|
||||
delete state.loras[key];
|
||||
},
|
||||
lorasCleared: (state) => {
|
||||
state.loras = {};
|
||||
},
|
||||
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
|
||||
const { id, weight } = action.payload;
|
||||
const lora = state.loras[id];
|
||||
loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => {
|
||||
const { key, weight } = action.payload;
|
||||
const lora = state.loras[key];
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.weight = weight;
|
||||
},
|
||||
loraWeightReset: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
const lora = state.loras[id];
|
||||
const key = action.payload;
|
||||
const lora = state.loras[key];
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.weight = defaultLoRAConfig.weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
|
||||
const { id, isEnabled } = action.payload;
|
||||
const lora = state.loras[id];
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'key' | 'isEnabled'>>) => {
|
||||
const { key, isEnabled } = action.payload;
|
||||
const lora = state.loras[key];
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
|
@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import type {
|
||||
DiffusersModelConfigEntity,
|
||||
LoRAModelConfigEntity,
|
||||
MainModelConfigEntity,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
@ -38,7 +34,7 @@ const ModelManagerPanel = () => {
|
||||
};
|
||||
|
||||
type ModelEditProps = {
|
||||
model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
|
||||
model: MainModelConfig | LoRAConfig | undefined;
|
||||
};
|
||||
|
||||
const ModelEdit = (props: ModelEditProps) => {
|
||||
@ -50,7 +46,7 @@ const ModelEdit = (props: ModelEditProps) => {
|
||||
}
|
||||
|
||||
if (model?.model_format === 'diffusers') {
|
||||
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfigEntity} />;
|
||||
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfig} />;
|
||||
}
|
||||
|
||||
if (model?.model_type === 'lora') {
|
||||
|
@ -21,14 +21,14 @@ import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelConvert from './ModelConvert';
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
model: CheckpointModelConfigEntity;
|
||||
model: CheckpointModelConfig;
|
||||
};
|
||||
|
||||
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
|
@ -9,12 +9,12 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/endpoints/models';
|
||||
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
model: DiffusersModelConfigEntity;
|
||||
model: DiffusersModelConfig;
|
||||
};
|
||||
|
||||
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
||||
|
@ -8,12 +8,12 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
||||
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
|
||||
type LoRAModelEditProps = {
|
||||
model: LoRAModelConfigEntity;
|
||||
model: LoRAConfig;
|
||||
};
|
||||
|
||||
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
@ -30,7 +30,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<LoRAModelConfig>({
|
||||
} = useForm<LoRAConfig>({
|
||||
defaultValues: {
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
@ -42,7 +42,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
|
||||
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
@ -53,7 +53,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
updateLoRAModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
|
||||
reset(payload as LoRAConfig, { keepDefaultValues: true });
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
@ -106,7 +106,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
|
||||
<BaseModelSelect<LoRAConfig> control={control} name="base_model" />
|
||||
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
|
@ -5,7 +5,7 @@ import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import ModelListItem from './ModelListItem';
|
||||
@ -127,7 +127,7 @@ const ModelList = (props: ModelListProps) => {
|
||||
|
||||
export default memo(ModelList);
|
||||
|
||||
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
|
||||
const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
|
||||
data: EntityState<T, string> | undefined,
|
||||
model_type: ModelType,
|
||||
model_format: ModelFormat | undefined,
|
||||
@ -163,7 +163,7 @@ StyledModelContainer.displayName = 'StyledModelContainer';
|
||||
|
||||
type ModelListWrapperProps = {
|
||||
title: string;
|
||||
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[];
|
||||
modelList: MainModelConfig[] | LoRAConfig[];
|
||||
selected: ModelListProps;
|
||||
};
|
||||
|
||||
|
@ -15,11 +15,11 @@ import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: MainModelConfigEntity | LoRAModelConfigEntity;
|
||||
model: MainModelConfig | LoRAConfig;
|
||||
isSelected: boolean;
|
||||
setSelectedModelId: (v: string | undefined) => void;
|
||||
};
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { ControlNetModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { ControlNetConfig } from 'services/api/endpoints/models';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -17,7 +17,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const { data, isLoading } = useGetControlNetModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: ControlNetModelConfigEntity | null) => {
|
||||
(value: ControlNetConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { IPAdapterModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { IPAdapterConfig } from 'services/api/endpoints/models';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -17,7 +17,7 @@ const IPAdapterModelFieldInputComponent = (
|
||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: IPAdapterModelConfigEntity | null) => {
|
||||
(value: IPAdapterConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -16,7 +16,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||
const _onChange = useCallback(
|
||||
(value: LoRAModelConfigEntity | null) => {
|
||||
(value: LoRAConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
|
||||
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -18,7 +18,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfigEntity | null) => {
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import type {
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -21,7 +21,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfigEntity | null) => {
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { SDXL_MAIN_MODELS } from 'services/api/constants';
|
||||
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -18,7 +18,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfigEntity | null) => {
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { T2IAdapterModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { T2IAdapterConfig } from 'services/api/endpoints/models';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -18,7 +18,7 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: T2IAdapterModelConfigEntity | null) => {
|
||||
(value: T2IAdapterConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { VaeModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { VAEConfig } from 'services/api/endpoints/models';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -17,7 +17,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const _onChange = useCallback(
|
||||
(value: VaeModelConfigEntity | null) => {
|
||||
(value: VAEConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -67,11 +67,13 @@ export const zModelName = z.string().min(3);
|
||||
export const zModelIdentifier = z.object({
|
||||
key: z.string().min(1),
|
||||
});
|
||||
export const zModelFieldBase = zModelIdentifier;
|
||||
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
export type ModelType = z.infer<typeof zModelType>;
|
||||
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
|
||||
|
||||
export const zMainModelField = zModelIdentifier;
|
||||
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
|
||||
export const zMainModelField = zModelFieldBase;
|
||||
export type MainModelField = z.infer<typeof zMainModelField>;
|
||||
|
||||
export const zSDXLRefinerModelField = zModelIdentifier;
|
||||
@ -91,23 +93,23 @@ export const zSubModelType = z.enum([
|
||||
]);
|
||||
export type SubModelType = z.infer<typeof zSubModelType>;
|
||||
|
||||
export const zVAEModelField = zModelIdentifier;
|
||||
export const zVAEModelField = zModelFieldBase;
|
||||
|
||||
export const zModelInfo = zModelIdentifier.extend({
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
export type ModelInfo = z.infer<typeof zModelInfo>;
|
||||
|
||||
export const zLoRAModelField = zModelIdentifier;
|
||||
export const zLoRAModelField = zModelFieldBase;
|
||||
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
|
||||
|
||||
export const zControlNetModelField = zModelIdentifier;
|
||||
export const zControlNetModelField = zModelFieldBase;
|
||||
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
|
||||
|
||||
export const zIPAdapterModelField = zModelIdentifier;
|
||||
export const zIPAdapterModelField = zModelFieldBase;
|
||||
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
|
||||
|
||||
export const zT2IAdapterModelField = zModelIdentifier;
|
||||
export const zT2IAdapterModelField = zModelFieldBase;
|
||||
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
|
||||
|
||||
export const zLoraInfo = zModelInfo.extend({
|
||||
|
@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
|
||||
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
|
@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
|
||||
if (validIPAdapters.length) {
|
||||
|
@ -28,6 +28,7 @@ export const addLoRAsToGraph = (
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
// TODO(MM2): check base model
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
@ -48,19 +49,19 @@ export const addLoRAsToGraph = (
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
const { model_name, base_model, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||
const { key, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
|
||||
const loraLoaderNode: LoraLoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
};
|
||||
|
||||
loraMetadata.push({
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
});
|
||||
|
||||
|
@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = (
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
// TODO(MM2): check base model
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = (
|
||||
let currentLoraIndex = 0;
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
const { model_name, base_model, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||
const { key, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
|
||||
const loraLoaderNode: SDXLLoraLoaderInvocation = {
|
||||
type: 'sdxl_lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
};
|
||||
|
||||
loraMetadata.push(
|
||||
zLoRAMetadataItem.parse({
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
})
|
||||
);
|
||||
|
@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
|
||||
if (validT2IAdapters.length) {
|
||||
|
@ -19,7 +19,7 @@ export const buildCanvasGraph = (
|
||||
let graph: NonNullableGraph;
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildCanvasTextToImageGraph(state);
|
||||
@ -28,7 +28,7 @@ export const buildCanvasGraph = (
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||
} else {
|
||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
@ -37,7 +37,7 @@ export const buildCanvasGraph = (
|
||||
if (!canvasInitImage || !canvasMaskImage) {
|
||||
throw new Error('Missing canvas init and mask images');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
@ -46,7 +46,7 @@ export const buildCanvasGraph = (
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
|
@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
|
||||
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {
|
||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: POSITIVE_CONDITIONING,
|
||||
|
@ -29,17 +29,17 @@ const ParamClipSkip = () => {
|
||||
if (!model) {
|
||||
return CLIP_SKIP_MAP['sd-1'].maxClip;
|
||||
}
|
||||
return CLIP_SKIP_MAP[model.base_model].maxClip;
|
||||
return CLIP_SKIP_MAP[model.base].maxClip;
|
||||
}, [model]);
|
||||
|
||||
const sliderMarks = useMemo(() => {
|
||||
if (!model) {
|
||||
return CLIP_SKIP_MAP['sd-1'].markers;
|
||||
}
|
||||
return CLIP_SKIP_MAP[model.base_model].markers;
|
||||
return CLIP_SKIP_MAP[model.base].markers;
|
||||
}, [model]);
|
||||
|
||||
if (model?.base_model === 'sdxl') {
|
||||
if (model?.base === 'sdxl') {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next';
|
||||
export const ParamPositivePrompt = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const prompt = useAppSelector((s) => s.generation.positivePrompt);
|
||||
const baseModel = useAppSelector((s) => s.generation.model)?.base_model;
|
||||
const baseModel = useAppSelector((s) => s.generation.model)?.base;
|
||||
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { t } = useTranslation();
|
||||
|
@ -1,58 +1,45 @@
|
||||
import { Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
|
||||
const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const model = useAppSelector(selectModel);
|
||||
const selectedModel = useAppSelector(selectModel);
|
||||
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
||||
const tooltipLabel = useMemo(() => {
|
||||
if (!data || !model) {
|
||||
return;
|
||||
}
|
||||
return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description;
|
||||
}, [data, model]);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: MainModelConfigEntity | null) => {
|
||||
(model: MainModelConfig | null) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type'])));
|
||||
dispatch(modelSelected({ key: model.key, base: model.base }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: model,
|
||||
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
data,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltipLabel}>
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
|
||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<FormControl isDisabled={!items.length} isInvalid={!selectedItem || !items.length}>
|
||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
||||
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -7,7 +7,7 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { VaeModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { VAEConfig } from 'services/api/endpoints/models';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
|
||||
const { model, vae } = useAppSelector(selector);
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const getIsDisabled = useCallback(
|
||||
(vae: VaeModelConfigEntity): boolean => {
|
||||
(vae: VAEConfig): boolean => {
|
||||
const isCompatible = model?.base_model === vae.base_model;
|
||||
const hasMainModel = Boolean(model?.base_model);
|
||||
return !hasMainModel || !isCompatible;
|
||||
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
|
||||
[model?.base_model]
|
||||
);
|
||||
const _onChange = useCallback(
|
||||
(vae: VaeModelConfigEntity | null) => {
|
||||
(vae: VAEConfig | null) => {
|
||||
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null));
|
||||
},
|
||||
[dispatch]
|
||||
|
@ -464,17 +464,15 @@ export const useRecallParameters = () => {
|
||||
return { lora: null, error: 'Invalid LoRA model' };
|
||||
}
|
||||
|
||||
const { base_model, model_name } = loraMetadataItem.lora;
|
||||
const { lora } = loraMetadataItem;
|
||||
|
||||
const matchingLoRA = loraModels
|
||||
? loraModelsAdapterSelectors.selectById(loraModels, `${base_model}/lora/${model_name}`)
|
||||
: undefined;
|
||||
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
|
||||
|
||||
if (!matchingLoRA) {
|
||||
return { lora: null, error: 'LoRA model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingLoRA?.base_model === (newModel ?? model)?.base_model;
|
||||
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
@ -520,17 +518,14 @@ export const useRecallParameters = () => {
|
||||
controlnetMetadataItem;
|
||||
|
||||
const matchingControlNetModel = controlNetModels
|
||||
? controlNetModelsAdapterSelectors.selectById(
|
||||
controlNetModels,
|
||||
`${control_model.base_model}/controlnet/${control_model.model_name}`
|
||||
)
|
||||
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
|
||||
: undefined;
|
||||
|
||||
if (!matchingControlNetModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingControlNetModel?.base_model === (newModel ?? model)?.base_model;
|
||||
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
@ -597,17 +592,14 @@ export const useRecallParameters = () => {
|
||||
t2iAdapterMetadataItem;
|
||||
|
||||
const matchingT2IAdapterModel = t2iAdapterModels
|
||||
? t2iAdapterModelsAdapterSelectors.selectById(
|
||||
t2iAdapterModels,
|
||||
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
|
||||
)
|
||||
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
|
||||
: undefined;
|
||||
|
||||
if (!matchingT2IAdapterModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model;
|
||||
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
@ -672,17 +664,14 @@ export const useRecallParameters = () => {
|
||||
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
|
||||
|
||||
const matchingIPAdapterModel = ipAdapterModels
|
||||
? ipAdapterModelsAdapterSelectors.selectById(
|
||||
ipAdapterModels,
|
||||
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
|
||||
)
|
||||
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
|
||||
: undefined;
|
||||
|
||||
if (!matchingIPAdapterModel) {
|
||||
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model;
|
||||
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { ImageDTO, MainModelField } from 'services/api/types';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<MainModelField>('generation/modelSelected');
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
@ -158,15 +158,15 @@ export const generationSlice = createSlice({
|
||||
// Clamp ClipSkip Based On Selected Model
|
||||
// TODO(psyche): remove this special handling when https://github.com/invoke-ai/InvokeAI/issues/4583 is resolved
|
||||
// WIP PR here: https://github.com/invoke-ai/InvokeAI/pull/4624
|
||||
if (newModel.base_model === 'sdxl') {
|
||||
if (newModel.base === 'sdxl') {
|
||||
// We don't support clip skip for SDXL yet - it's not in the graphs
|
||||
state.clipSkip = 0;
|
||||
} else {
|
||||
const { maxClip } = CLIP_SKIP_MAP[newModel.base_model];
|
||||
const { maxClip } = CLIP_SKIP_MAP[newModel.base];
|
||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||
}
|
||||
|
||||
if (action.meta.previousModel?.base_model === newModel.base_model) {
|
||||
if (action.meta.previousModel?.base === newModel.base) {
|
||||
// The base model hasn't changed, we don't need to optimize the size
|
||||
return;
|
||||
}
|
||||
|
@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = {
|
||||
*/
|
||||
export const MODEL_TYPE_SHORT_MAP = {
|
||||
any: 'Any',
|
||||
'sd-1': 'SD1',
|
||||
'sd-2': 'SD2',
|
||||
'sd-1': 'SD1.X',
|
||||
'sd-2': 'SD2.X',
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import {
|
||||
zBaseModel,
|
||||
zControlNetModelField,
|
||||
zIPAdapterModelField,
|
||||
zLoRAModelField,
|
||||
@ -104,48 +105,48 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
|
||||
// #endregion
|
||||
|
||||
// #region Model
|
||||
export const zParameterModel = zMainModelField;
|
||||
export const zParameterModel = zMainModelField.extend({ base: zBaseModel });
|
||||
export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Model
|
||||
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField;
|
||||
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel });
|
||||
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
|
||||
export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel =>
|
||||
zParameterSDXLRefinerModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region VAE Model
|
||||
export const zParameterVAEModel = zVAEModelField;
|
||||
export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel });
|
||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||
export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel =>
|
||||
zParameterVAEModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA Model
|
||||
export const zParameterLoRAModel = zLoRAModelField;
|
||||
export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel });
|
||||
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
||||
export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel =>
|
||||
zParameterLoRAModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ControlNet Model
|
||||
export const zParameterControlNetModel = zControlNetModelField;
|
||||
export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel });
|
||||
export type ParameterControlNetModel = z.infer<typeof zParameterLoRAModel>;
|
||||
export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel =>
|
||||
zParameterControlNetModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region IP Adapter Model
|
||||
export const zParameterIPAdapterModel = zIPAdapterModelField;
|
||||
export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel });
|
||||
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
|
||||
export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel =>
|
||||
zParameterIPAdapterModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region T2I Adapter Model
|
||||
export const zParameterT2IAdapterModel = zT2IAdapterModelField;
|
||||
export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel });
|
||||
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
|
||||
export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel =>
|
||||
zParameterT2IAdapterModel.safeParse(val).success;
|
||||
|
@ -1,12 +1,12 @@
|
||||
import type { ModelIdentifier } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
|
||||
/**
|
||||
* Gets the optimal dimension for a givel model, based on the model's base_model
|
||||
* @param model The model identifier
|
||||
* @returns The optimal dimension for the model
|
||||
*/
|
||||
export const getOptimalDimension = (model?: ModelIdentifier | null): number =>
|
||||
model?.base_model === 'sdxl' ? 1024 : 512;
|
||||
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number =>
|
||||
model?.base === 'sdxl' ? 1024 : 512;
|
||||
|
||||
const MIN_AREA_FACTOR = 0.8;
|
||||
const MAX_AREA_FACTOR = 1.2;
|
||||
|
@ -6,12 +6,12 @@ import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSl
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
|
||||
|
||||
const optionsFilter = (model: MainModelConfigEntity) => model.base_model === 'sdxl-refiner';
|
||||
const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner';
|
||||
|
||||
const ParamSDXLRefinerModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -19,7 +19,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
const _onChange = useCallback(
|
||||
(model: MainModelConfigEntity | null) => {
|
||||
(model: MainModelConfig | null) => {
|
||||
if (!model) {
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
|
@ -24,7 +24,8 @@ const formLabelProps2: FormLabelProps = {
|
||||
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
const badges: (string | number)[] = [];
|
||||
if (generation.vae) {
|
||||
let vaeBadge = generation.vae.model_name;
|
||||
// TODO(MM2): Fetch the vae name
|
||||
let vaeBadge = generation.vae.key;
|
||||
if (generation.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${generation.vaePrecision}`;
|
||||
}
|
||||
|
@ -35,9 +35,10 @@ const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationS
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
|
||||
const accordionBadges: (string | number)[] = [];
|
||||
// TODO(MM2): fetch model name
|
||||
if (generation.model) {
|
||||
accordionBadges.push(generation.model.model_name);
|
||||
accordionBadges.push(generation.model.base_model);
|
||||
accordionBadges.push(generation.model.key);
|
||||
accordionBadges.push(generation.model.base);
|
||||
}
|
||||
|
||||
return { loraTabBadges, accordionBadges };
|
||||
|
@ -56,7 +56,7 @@ const selector = createMemoizedSelector(
|
||||
if (hrfEnabled) {
|
||||
badges.push('HiRes Fix');
|
||||
}
|
||||
return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' };
|
||||
return { badges, activeTabName, isSDXL: model?.base === 'sdxl' };
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -22,7 +22,7 @@ const overlayScrollbarsStyles: CSSProperties = {
|
||||
|
||||
const ParametersPanel = () => {
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const isSDXL = useAppSelector((s) => s.generation.model?.base_model === 'sdxl');
|
||||
const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl');
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" flexDir="column" gap={2}>
|
||||
|
@ -3,27 +3,35 @@ import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import type { paths } from 'services/api/schema';
|
||||
import type { AppConfig, AppDependencyVersions, AppVersion } from 'services/api/types';
|
||||
|
||||
import { api } from '..';
|
||||
import { api, buildV1Url } from '..';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the app router
|
||||
* @example
|
||||
* buildAppInfoUrl('some-path')
|
||||
* // '/api/v1/app/some-path'
|
||||
*/
|
||||
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
|
||||
|
||||
export const appInfoApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getAppVersion: build.query<AppVersion, void>({
|
||||
query: () => ({
|
||||
url: `app/version`,
|
||||
url: buildAppInfoUrl('version'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['FetchOnReconnect'],
|
||||
}),
|
||||
getAppDeps: build.query<AppDependencyVersions, void>({
|
||||
query: () => ({
|
||||
url: `app/app_deps`,
|
||||
url: buildAppInfoUrl('app_deps'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['FetchOnReconnect'],
|
||||
}),
|
||||
getAppConfig: build.query<AppConfig, void>({
|
||||
query: () => ({
|
||||
url: `app/config`,
|
||||
url: buildAppInfoUrl('config'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['FetchOnReconnect'],
|
||||
@ -33,28 +41,28 @@ export const appInfoApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `app/invocation_cache/status`,
|
||||
url: buildAppInfoUrl('invocation_cache/status'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['InvocationCacheStatus', 'FetchOnReconnect'],
|
||||
}),
|
||||
clearInvocationCache: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: `app/invocation_cache`,
|
||||
url: buildAppInfoUrl('invocation_cache'),
|
||||
method: 'DELETE',
|
||||
}),
|
||||
invalidatesTags: ['InvocationCacheStatus'],
|
||||
}),
|
||||
enableInvocationCache: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: `app/invocation_cache/enable`,
|
||||
url: buildAppInfoUrl('invocation_cache/enable'),
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['InvocationCacheStatus'],
|
||||
}),
|
||||
disableInvocationCache: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: `app/invocation_cache/disable`,
|
||||
url: buildAppInfoUrl('invocation_cache/disable'),
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['InvocationCacheStatus'],
|
||||
|
@ -9,7 +9,15 @@ import type {
|
||||
import { getListImagesUrl } from 'services/api/util';
|
||||
|
||||
import type { ApiTagDescription } from '..';
|
||||
import { api, LIST_TAG } from '..';
|
||||
import { api, buildV1Url, LIST_TAG } from '..';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the boards router
|
||||
* @example
|
||||
* buildBoardsUrl('some-path')
|
||||
* // '/api/v1/boards/some-path'
|
||||
*/
|
||||
export const buildBoardsUrl = (path: string = '') => buildV1Url(`boards/${path}`);
|
||||
|
||||
export const boardsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -17,7 +25,7 @@ export const boardsApi = api.injectEndpoints({
|
||||
* Boards Queries
|
||||
*/
|
||||
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
|
||||
query: (arg) => ({ url: 'boards/', params: arg }),
|
||||
query: (arg) => ({ url: buildBoardsUrl(), params: arg }),
|
||||
providesTags: (result) => {
|
||||
// any list of boards
|
||||
const tags: ApiTagDescription[] = [{ type: 'Board', id: LIST_TAG }, 'FetchOnReconnect'];
|
||||
@ -38,7 +46,7 @@ export const boardsApi = api.injectEndpoints({
|
||||
|
||||
listAllBoards: build.query<Array<BoardDTO>, void>({
|
||||
query: () => ({
|
||||
url: 'boards/',
|
||||
url: buildBoardsUrl(),
|
||||
params: { all: true },
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
@ -61,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
|
||||
|
||||
listAllImageNamesForBoard: build.query<Array<string>, string>({
|
||||
query: (board_id) => ({
|
||||
url: `boards/${board_id}/image_names`,
|
||||
url: buildBoardsUrl(`${board_id}/image_names`),
|
||||
}),
|
||||
providesTags: (result, error, arg) => [{ type: 'ImageNameList', id: arg }, 'FetchOnReconnect'],
|
||||
keepUnusedDataFor: 0,
|
||||
@ -107,7 +115,7 @@ export const boardsApi = api.injectEndpoints({
|
||||
|
||||
createBoard: build.mutation<BoardDTO, string>({
|
||||
query: (board_name) => ({
|
||||
url: `boards/`,
|
||||
url: buildBoardsUrl(),
|
||||
method: 'POST',
|
||||
params: { board_name },
|
||||
}),
|
||||
@ -116,7 +124,7 @@ export const boardsApi = api.injectEndpoints({
|
||||
|
||||
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
|
||||
query: ({ board_id, changes }) => ({
|
||||
url: `boards/${board_id}`,
|
||||
url: buildBoardsUrl(board_id),
|
||||
method: 'PATCH',
|
||||
body: changes,
|
||||
}),
|
||||
|
@ -26,8 +26,24 @@ import {
|
||||
} from 'services/api/util';
|
||||
|
||||
import type { ApiTagDescription } from '..';
|
||||
import { api, LIST_TAG } from '..';
|
||||
import { boardsApi } from './boards';
|
||||
import { api, buildV1Url, LIST_TAG } from '..';
|
||||
import { boardsApi, buildBoardsUrl } from './boards';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the images router
|
||||
* @example
|
||||
* buildImagesUrl('some-path')
|
||||
* // '/api/v1/images/some-path'
|
||||
*/
|
||||
const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`);
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the board_images router
|
||||
* @example
|
||||
* buildBoardImagesUrl('some-path')
|
||||
* // '/api/v1/board_images/some-path'
|
||||
*/
|
||||
const buildBoardImagesUrl = (path: string = '') => buildV1Url(`board_images/${path}`);
|
||||
|
||||
export const imagesApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -90,20 +106,20 @@ export const imagesApi = api.injectEndpoints({
|
||||
keepUnusedDataFor: 86400,
|
||||
}),
|
||||
getIntermediatesCount: build.query<number, void>({
|
||||
query: () => ({ url: 'images/intermediates' }),
|
||||
query: () => ({ url: buildImagesUrl('intermediates') }),
|
||||
providesTags: ['IntermediatesCount', 'FetchOnReconnect'],
|
||||
}),
|
||||
clearIntermediates: build.mutation<number, void>({
|
||||
query: () => ({ url: `images/intermediates`, method: 'DELETE' }),
|
||||
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
|
||||
invalidatesTags: ['IntermediatesCount'],
|
||||
}),
|
||||
getImageDTO: build.query<ImageDTO, string>({
|
||||
query: (image_name) => ({ url: `images/i/${image_name}` }),
|
||||
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
|
||||
providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageMetadata: build.query<CoreMetadata | undefined, string>({
|
||||
query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
|
||||
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }),
|
||||
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
|
||||
transformResponse: (
|
||||
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
|
||||
@ -130,7 +146,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
}),
|
||||
deleteImage: build.mutation<void, ImageDTO>({
|
||||
query: ({ image_name }) => ({
|
||||
url: `images/i/${image_name}`,
|
||||
url: buildImagesUrl(`i/${image_name}`),
|
||||
method: 'DELETE',
|
||||
}),
|
||||
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
||||
@ -185,7 +201,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
query: ({ imageDTOs }) => {
|
||||
const image_names = imageDTOs.map((imageDTO) => imageDTO.image_name);
|
||||
return {
|
||||
url: `images/delete`,
|
||||
url: buildImagesUrl('delete'),
|
||||
method: 'POST',
|
||||
body: {
|
||||
image_names,
|
||||
@ -258,7 +274,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
*/
|
||||
changeImageIsIntermediate: build.mutation<ImageDTO, { imageDTO: ImageDTO; is_intermediate: boolean }>({
|
||||
query: ({ imageDTO, is_intermediate }) => ({
|
||||
url: `images/i/${imageDTO.image_name}`,
|
||||
url: buildImagesUrl(`i/${imageDTO.image_name}`),
|
||||
method: 'PATCH',
|
||||
body: { is_intermediate },
|
||||
}),
|
||||
@ -380,7 +396,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
*/
|
||||
changeImageSessionId: build.mutation<ImageDTO, { imageDTO: ImageDTO; session_id: string }>({
|
||||
query: ({ imageDTO, session_id }) => ({
|
||||
url: `images/i/${imageDTO.image_name}`,
|
||||
url: buildImagesUrl(`i/${imageDTO.image_name}`),
|
||||
method: 'PATCH',
|
||||
body: { session_id },
|
||||
}),
|
||||
@ -417,7 +433,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ imageDTOs: ImageDTO[] }
|
||||
>({
|
||||
query: ({ imageDTOs: images }) => ({
|
||||
url: `images/star`,
|
||||
url: buildImagesUrl('star'),
|
||||
method: 'POST',
|
||||
body: { image_names: images.map((img) => img.image_name) },
|
||||
}),
|
||||
@ -511,7 +527,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ imageDTOs: ImageDTO[] }
|
||||
>({
|
||||
query: ({ imageDTOs: images }) => ({
|
||||
url: `images/unstar`,
|
||||
url: buildImagesUrl('unstar'),
|
||||
method: 'POST',
|
||||
body: { image_names: images.map((img) => img.image_name) },
|
||||
}),
|
||||
@ -611,7 +627,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
return {
|
||||
url: `images/upload`,
|
||||
url: buildImagesUrl('upload'),
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
params: {
|
||||
@ -674,7 +690,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
}),
|
||||
|
||||
deleteBoard: build.mutation<DeleteBoardResult, string>({
|
||||
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
|
||||
query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }),
|
||||
invalidatesTags: () => [
|
||||
{ type: 'Board', id: LIST_TAG },
|
||||
// invalidate the 'No Board' cache
|
||||
@ -764,7 +780,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
|
||||
deleteBoardAndImages: build.mutation<DeleteBoardResult, string>({
|
||||
query: (board_id) => ({
|
||||
url: `boards/${board_id}`,
|
||||
url: buildBoardsUrl(board_id),
|
||||
method: 'DELETE',
|
||||
params: { include_images: true },
|
||||
}),
|
||||
@ -840,7 +856,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
query: ({ board_id, imageDTO }) => {
|
||||
const { image_name } = imageDTO;
|
||||
return {
|
||||
url: `board_images/`,
|
||||
url: buildBoardImagesUrl(),
|
||||
method: 'POST',
|
||||
body: { board_id, image_name },
|
||||
};
|
||||
@ -961,7 +977,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
query: ({ imageDTO }) => {
|
||||
const { image_name } = imageDTO;
|
||||
return {
|
||||
url: `board_images/`,
|
||||
url: buildBoardImagesUrl(),
|
||||
method: 'DELETE',
|
||||
body: { image_name },
|
||||
};
|
||||
@ -1080,7 +1096,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
}
|
||||
>({
|
||||
query: ({ board_id, imageDTOs }) => ({
|
||||
url: `board_images/batch`,
|
||||
url: buildBoardImagesUrl('batch'),
|
||||
method: 'POST',
|
||||
body: {
|
||||
image_names: imageDTOs.map((i) => i.image_name),
|
||||
@ -1197,7 +1213,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
}
|
||||
>({
|
||||
query: ({ imageDTOs }) => ({
|
||||
url: `board_images/batch/delete`,
|
||||
url: buildBoardImagesUrl('batch/delete'),
|
||||
method: 'POST',
|
||||
body: {
|
||||
image_names: imageDTOs.map((i) => i.image_name),
|
||||
@ -1321,7 +1337,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
components['schemas']['Body_download_images_from_list']
|
||||
>({
|
||||
query: ({ image_names, board_id }) => ({
|
||||
url: `images/download`,
|
||||
url: buildImagesUrl('download'),
|
||||
method: 'POST',
|
||||
body: {
|
||||
image_names,
|
||||
|
@ -1,63 +1,28 @@
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import queryString from 'query-string';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
DiffusersModelConfig,
|
||||
ControlNetConfig,
|
||||
ImportModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
IPAdapterConfig,
|
||||
LoRAConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
ModelType,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VaeModelConfig,
|
||||
T2IAdapterConfig,
|
||||
TextualInversionConfig,
|
||||
VAEConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import type { ApiTagDescription } from '..';
|
||||
import { api, LIST_TAG } from '..';
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
import { api, buildV2Url, LIST_TAG } from '..';
|
||||
|
||||
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
|
||||
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity;
|
||||
|
||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||
|
||||
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
||||
|
||||
export type AnyModelConfigEntity =
|
||||
| MainModelConfigEntity
|
||||
| LoRAModelConfigEntity
|
||||
| ControlNetModelConfigEntity
|
||||
| IPAdapterModelConfigEntity
|
||||
| T2IAdapterModelConfigEntity
|
||||
| TextualInversionModelConfigEntity
|
||||
| VaeModelConfigEntity;
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
export const getModelId = (input: any): any => input;
|
||||
|
||||
type UpdateMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
@ -68,11 +33,13 @@ type UpdateMainModelArg = {
|
||||
type UpdateLoRAModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: LoRAModelConfig;
|
||||
body: LoRAConfig;
|
||||
};
|
||||
|
||||
type UpdateMainModelResponse =
|
||||
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
|
||||
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/models_v2/']['get']['parameters']['query']>;
|
||||
|
||||
type UpdateLoRAModelResponse = UpdateMainModelResponse;
|
||||
|
||||
@ -128,91 +95,95 @@ type CheckpointConfigsResponse =
|
||||
|
||||
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
export const getModelId = ({
|
||||
base_model,
|
||||
model_type,
|
||||
model_name,
|
||||
}: Pick<AnyModelConfig, 'base_model' | 'model_name' | 'model_type'>) => `${base_model}/${model_type}/${model_name}`;
|
||||
const buildProvidesTags =
|
||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
|
||||
const createModelEntities = <T extends AnyModelConfigEntity>(models: AnyModelConfig[]): T[] => {
|
||||
const entityArray: T[] = [];
|
||||
models.forEach((model) => {
|
||||
const entity = {
|
||||
...cloneDeep(model),
|
||||
id: getModelId(model),
|
||||
} as T;
|
||||
entityArray.push(entity);
|
||||
});
|
||||
return entityArray;
|
||||
};
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
};
|
||||
|
||||
const buildTransformResponse =
|
||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the models router
|
||||
* @example
|
||||
* buildModelsUrl('some-path')
|
||||
* // '/api/v1/models/some-path'
|
||||
*/
|
||||
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity, string>, BaseModelType[]>({
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params = {
|
||||
const params: ListModelsArg = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return `models/?${query}`;
|
||||
},
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'MainModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: MainModelConfig[] }) => {
|
||||
const entities = createModelEntities<MainModelConfigEntity>(response.models);
|
||||
return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities);
|
||||
return buildModelsUrl(`?${query}`);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
}),
|
||||
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
return {
|
||||
url: `models/${base_model}/main/${model_name}`,
|
||||
url: buildModelsUrl(`${base_model}/main/${model_name}`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
@ -222,7 +193,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
||||
query: ({ body }) => {
|
||||
return {
|
||||
url: `models/import`,
|
||||
url: buildModelsUrl('import'),
|
||||
method: 'POST',
|
||||
body: body,
|
||||
};
|
||||
@ -232,7 +203,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||
query: ({ body }) => {
|
||||
return {
|
||||
url: `models/add`,
|
||||
url: buildModelsUrl('add'),
|
||||
method: 'POST',
|
||||
body: body,
|
||||
};
|
||||
@ -242,7 +213,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
deleteMainModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
query: ({ base_model, model_name, model_type }) => {
|
||||
return {
|
||||
url: `models/${base_model}/${model_type}/${model_name}`,
|
||||
url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
@ -251,7 +222,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
convertMainModels: build.mutation<ConvertMainModelResponse, ConvertMainModelArg>({
|
||||
query: ({ base_model, model_name, convert_dest_directory }) => {
|
||||
return {
|
||||
url: `models/convert/${base_model}/main/${model_name}`,
|
||||
url: buildModelsUrl(`convert/${base_model}/main/${model_name}`),
|
||||
method: 'PUT',
|
||||
params: { convert_dest_directory },
|
||||
};
|
||||
@ -261,7 +232,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||
query: ({ base_model, body }) => {
|
||||
return {
|
||||
url: `models/merge/${base_model}`,
|
||||
url: buildModelsUrl(`merge/${base_model}`),
|
||||
method: 'PUT',
|
||||
body: body,
|
||||
};
|
||||
@ -271,37 +242,21 @@ export const modelsApi = api.injectEndpoints({
|
||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: `models/sync`,
|
||||
url: buildModelsUrl('sync'),
|
||||
method: 'POST',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'LoRAModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: LoRAModelConfig[] }) => {
|
||||
const entities = createModelEntities<LoRAModelConfigEntity>(response.models);
|
||||
return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
|
||||
}),
|
||||
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
return {
|
||||
url: `models/${base_model}/lora/${model_name}`,
|
||||
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
@ -311,129 +266,49 @@ export const modelsApi = api.injectEndpoints({
|
||||
deleteLoRAModels: build.mutation<DeleteLoRAModelResponse, DeleteLoRAModelArg>({
|
||||
query: ({ base_model, model_name }) => {
|
||||
return {
|
||||
url: `models/${base_model}/lora/${model_name}`,
|
||||
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
getControlNetModels: build.query<EntityState<ControlNetModelConfigEntity, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'ControlNetModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
|
||||
const entities = createModelEntities<ControlNetModelConfigEntity>(response.models);
|
||||
return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
|
||||
}),
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfigEntity, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'IPAdapterModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
|
||||
const entities = createModelEntities<IPAdapterModelConfigEntity>(response.models);
|
||||
return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
||||
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
|
||||
}),
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfigEntity, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'T2IAdapterModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
|
||||
const entities = createModelEntities<T2IAdapterModelConfigEntity>(response.models);
|
||||
return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'VaeModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: VaeModelConfig[] }) => {
|
||||
const entities = createModelEntities<VaeModelConfigEntity>(response.models);
|
||||
return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
||||
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
|
||||
}),
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfigEntity, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'TextualInversionModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: TextualInversionModelConfig[] }) => {
|
||||
const entities = createModelEntities<TextualInversionModelConfigEntity>(response.models);
|
||||
return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
||||
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
|
||||
}),
|
||||
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
||||
query: (arg) => {
|
||||
const folderQueryStr = queryString.stringify(arg, {});
|
||||
return {
|
||||
url: `/models/search?${folderQueryStr}`,
|
||||
url: buildModelsUrl(`search?${folderQueryStr}`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: `/models/ckpt_confs`,
|
||||
url: buildModelsUrl(`ckpt_confs`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
|
@ -7,7 +7,15 @@ import queryString from 'query-string';
|
||||
import type { components, paths } from 'services/api/schema';
|
||||
|
||||
import type { ApiTagDescription } from '..';
|
||||
import { api } from '..';
|
||||
import { api, buildV1Url } from '..';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the queue router
|
||||
* @example
|
||||
* buildQueueUrl('some-path')
|
||||
* // '/api/v1/queue/queue_id/some-path'
|
||||
*/
|
||||
const buildQueueUrl = (path: string = '') => buildV1Url(`queue/${$queueId.get()}/${path}`);
|
||||
|
||||
const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list']['get']['parameters']['query']) => {
|
||||
const query = queryArgs
|
||||
@ -17,10 +25,10 @@ const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list']
|
||||
: undefined;
|
||||
|
||||
if (query) {
|
||||
return `queue/${$queueId.get()}/list?${query}`;
|
||||
return buildQueueUrl(`list?${query}`);
|
||||
}
|
||||
|
||||
return `queue/${$queueId.get()}/list`;
|
||||
return buildQueueUrl('list');
|
||||
};
|
||||
|
||||
export type SessionQueueItemStatus = NonNullable<
|
||||
@ -58,7 +66,7 @@ export const queueApi = api.injectEndpoints({
|
||||
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json']
|
||||
>({
|
||||
query: (arg) => ({
|
||||
url: `queue/${$queueId.get()}/enqueue_batch`,
|
||||
url: buildQueueUrl('enqueue_batch'),
|
||||
body: arg,
|
||||
method: 'POST',
|
||||
}),
|
||||
@ -78,7 +86,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/processor/resume`,
|
||||
url: buildQueueUrl('processor/resume'),
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'],
|
||||
@ -88,7 +96,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/processor/pause`,
|
||||
url: buildQueueUrl('processor/pause'),
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'],
|
||||
@ -98,7 +106,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/prune`,
|
||||
url: buildQueueUrl('prune'),
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
|
||||
@ -117,7 +125,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/clear`,
|
||||
url: buildQueueUrl('clear'),
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: [
|
||||
@ -142,7 +150,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/current`,
|
||||
url: buildQueueUrl('current'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
@ -158,7 +166,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/next`,
|
||||
url: buildQueueUrl('next'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
@ -174,7 +182,7 @@ export const queueApi = api.injectEndpoints({
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/status`,
|
||||
url: buildQueueUrl('status'),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['SessionQueueStatus', 'FetchOnReconnect'],
|
||||
@ -184,7 +192,7 @@ export const queueApi = api.injectEndpoints({
|
||||
{ batch_id: string }
|
||||
>({
|
||||
query: ({ batch_id }) => ({
|
||||
url: `queue/${$queueId.get()}/b/${batch_id}/status`,
|
||||
url: buildQueueUrl(`/b/${batch_id}/status`),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
@ -200,7 +208,7 @@ export const queueApi = api.injectEndpoints({
|
||||
number
|
||||
>({
|
||||
query: (item_id) => ({
|
||||
url: `queue/${$queueId.get()}/i/${item_id}`,
|
||||
url: buildQueueUrl(`i/${item_id}`),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
@ -216,7 +224,7 @@ export const queueApi = api.injectEndpoints({
|
||||
number
|
||||
>({
|
||||
query: (item_id) => ({
|
||||
url: `queue/${$queueId.get()}/i/${item_id}/cancel`,
|
||||
url: buildQueueUrl(`i/${item_id}/cancel`),
|
||||
method: 'PUT',
|
||||
}),
|
||||
onQueryStarted: async (item_id, { dispatch, queryFulfilled }) => {
|
||||
@ -253,7 +261,7 @@ export const queueApi = api.injectEndpoints({
|
||||
paths['/api/v1/queue/{queue_id}/cancel_by_batch_ids']['put']['requestBody']['content']['application/json']
|
||||
>({
|
||||
query: (body) => ({
|
||||
url: `queue/${$queueId.get()}/cancel_by_batch_ids`,
|
||||
url: buildQueueUrl('cancel_by_batch_ids'),
|
||||
method: 'PUT',
|
||||
body,
|
||||
}),
|
||||
@ -279,7 +287,7 @@ export const queueApi = api.injectEndpoints({
|
||||
method: 'GET',
|
||||
}),
|
||||
serializeQueryArgs: () => {
|
||||
return `queue/${$queueId.get()}/list`;
|
||||
return buildQueueUrl('list');
|
||||
},
|
||||
transformResponse: (response: components['schemas']['CursorPaginatedResults_SessionQueueItemDTO_']) =>
|
||||
queueItemsAdapter.addMany(
|
||||
|
@ -1,6 +1,14 @@
|
||||
import type { components } from 'services/api/schema';
|
||||
|
||||
import { api } from '..';
|
||||
import { api, buildV1Url } from '..';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the utilities router
|
||||
* @example
|
||||
* buildUtilitiesUrl('some-path')
|
||||
* // '/api/v1/utilities/some-path'
|
||||
*/
|
||||
const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`);
|
||||
|
||||
export const utilitiesApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -9,7 +17,7 @@ export const utilitiesApi = api.injectEndpoints({
|
||||
{ prompt: string; max_prompts: number }
|
||||
>({
|
||||
query: (arg) => ({
|
||||
url: 'utilities/dynamicprompts',
|
||||
url: buildUtilitiesUrl('dynamicprompts'),
|
||||
body: arg,
|
||||
method: 'POST',
|
||||
}),
|
||||
|
@ -1,6 +1,14 @@
|
||||
import type { paths } from 'services/api/schema';
|
||||
|
||||
import { api, LIST_TAG } from '..';
|
||||
import { api, buildV1Url, LIST_TAG } from '..';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the workflows router
|
||||
* @example
|
||||
* buildWorkflowsUrl('some-path')
|
||||
* // '/api/v1/workflows/some-path'
|
||||
*/
|
||||
const buildWorkflowsUrl = (path: string = '') => buildV1Url(`workflows/${path}`);
|
||||
|
||||
export const workflowsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -8,7 +16,7 @@ export const workflowsApi = api.injectEndpoints({
|
||||
paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'],
|
||||
string
|
||||
>({
|
||||
query: (workflow_id) => `workflows/i/${workflow_id}`,
|
||||
query: (workflow_id) => buildWorkflowsUrl(`i/${workflow_id}`),
|
||||
providesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }, 'FetchOnReconnect'],
|
||||
onQueryStarted: async (arg, api) => {
|
||||
const { dispatch, queryFulfilled } = api;
|
||||
@ -22,7 +30,7 @@ export const workflowsApi = api.injectEndpoints({
|
||||
}),
|
||||
deleteWorkflow: build.mutation<void, string>({
|
||||
query: (workflow_id) => ({
|
||||
url: `workflows/i/${workflow_id}`,
|
||||
url: buildWorkflowsUrl(`i/${workflow_id}`),
|
||||
method: 'DELETE',
|
||||
}),
|
||||
invalidatesTags: (result, error, workflow_id) => [
|
||||
@ -36,7 +44,7 @@ export const workflowsApi = api.injectEndpoints({
|
||||
paths['/api/v1/workflows/']['post']['requestBody']['content']['application/json']['workflow']
|
||||
>({
|
||||
query: (workflow) => ({
|
||||
url: 'workflows/',
|
||||
url: buildWorkflowsUrl(),
|
||||
method: 'POST',
|
||||
body: { workflow },
|
||||
}),
|
||||
@ -50,7 +58,7 @@ export const workflowsApi = api.injectEndpoints({
|
||||
paths['/api/v1/workflows/i/{workflow_id}']['patch']['requestBody']['content']['application/json']['workflow']
|
||||
>({
|
||||
query: (workflow) => ({
|
||||
url: `workflows/i/${workflow.id}`,
|
||||
url: buildWorkflowsUrl(`i/${workflow.id}`),
|
||||
method: 'PATCH',
|
||||
body: { workflow },
|
||||
}),
|
||||
@ -65,7 +73,7 @@ export const workflowsApi = api.injectEndpoints({
|
||||
NonNullable<paths['/api/v1/workflows/']['get']['parameters']['query']>
|
||||
>({
|
||||
query: (params) => ({
|
||||
url: 'workflows/',
|
||||
url: buildWorkflowsUrl(),
|
||||
params,
|
||||
}),
|
||||
providesTags: ['FetchOnReconnect', { type: 'Workflow', id: LIST_TAG }],
|
||||
|
@ -54,7 +54,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
|
||||
const projectId = $projectId.get();
|
||||
|
||||
const fetchBaseQueryArgs: FetchBaseQueryArgs = {
|
||||
baseUrl: baseUrl ? `${baseUrl}/api/v1` : `${window.location.href.replace(/\/$/, '')}/api/v1`,
|
||||
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''),
|
||||
prepareHeaders: (headers) => {
|
||||
if (authToken) {
|
||||
headers.set('Authorization', `Bearer ${authToken}`);
|
||||
@ -108,3 +108,6 @@ function getCircularReplacer() {
|
||||
return value;
|
||||
};
|
||||
}
|
||||
|
||||
export const buildV1Url = (path: string): string => `api/v1/${path}`;
|
||||
export const buildV2Url = (path: string): string => `api/v2/${path}`;
|
||||
|
File diff suppressed because one or more lines are too long
@ -2,6 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { components, paths } from 'services/api/schema';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import type { SetRequired } from 'type-fest';
|
||||
|
||||
export type S = components['schemas'];
|
||||
|
||||
@ -54,40 +55,36 @@ export type LoRAModelFormat = S['LoRAModelFormat'];
|
||||
export type ControlNetModelField = S['ControlNetModelField'];
|
||||
export type IPAdapterModelField = S['IPAdapterModelField'];
|
||||
export type T2IAdapterModelField = S['T2IAdapterModelField'];
|
||||
export type ModelsList = S['invokeai__app__api__routers__models__ModelsList'];
|
||||
export type ControlField = S['ControlField'];
|
||||
export type IPAdapterField = S['IPAdapterField'];
|
||||
|
||||
// Model Configs
|
||||
export type LoRAModelConfig = S['LoRAModelConfig'];
|
||||
export type VaeModelConfig = S['VaeModelConfig'];
|
||||
export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig'];
|
||||
export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig'];
|
||||
export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig;
|
||||
export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig'];
|
||||
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
|
||||
export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig'];
|
||||
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
|
||||
export type TextualInversionModelConfig = S['TextualInversionModelConfig'];
|
||||
export type DiffusersModelConfig =
|
||||
| S['StableDiffusion1ModelDiffusersConfig']
|
||||
| S['StableDiffusion2ModelDiffusersConfig']
|
||||
| S['StableDiffusionXLModelDiffusersConfig'];
|
||||
export type CheckpointModelConfig =
|
||||
| S['StableDiffusion1ModelCheckpointConfig']
|
||||
| S['StableDiffusion2ModelCheckpointConfig']
|
||||
| S['StableDiffusionXLModelCheckpointConfig'];
|
||||
|
||||
// TODO(MM2): Can we make key required in the pydantic model?
|
||||
type KeyRequired<T extends { key?: string }> = SetRequired<T, 'key'>;
|
||||
export type LoRAConfig = KeyRequired<S['LoRAConfig']>;
|
||||
// TODO(MM2): Can we rename this from Vae -> VAE
|
||||
export type VAEConfig = KeyRequired<S['VaeCheckpointConfig']> | KeyRequired<S['VaeDiffusersConfig']>;
|
||||
export type ControlNetConfig =
|
||||
| KeyRequired<S['ControlNetDiffusersConfig']>
|
||||
| KeyRequired<S['ControlNetCheckpointConfig']>;
|
||||
export type IPAdapterConfig = KeyRequired<S['IPAdapterConfig']>;
|
||||
// TODO(MM2): Can we rename this to T2IAdapterConfig
|
||||
export type T2IAdapterConfig = KeyRequired<S['T2IConfig']>;
|
||||
export type TextualInversionConfig = KeyRequired<S['TextualInversionConfig']>;
|
||||
export type DiffusersModelConfig = KeyRequired<S['MainDiffusersConfig']>;
|
||||
export type CheckpointModelConfig = KeyRequired<S['MainCheckpointConfig']>;
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type AnyModelConfig =
|
||||
| LoRAModelConfig
|
||||
| VaeModelConfig
|
||||
| ControlNetModelConfig
|
||||
| IPAdapterModelConfig
|
||||
| T2IAdapterModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| LoRAConfig
|
||||
| VAEConfig
|
||||
| ControlNetConfig
|
||||
| IPAdapterConfig
|
||||
| T2IAdapterConfig
|
||||
| TextualInversionConfig
|
||||
| MainModelConfig;
|
||||
|
||||
export type MergeModelConfig = S['Body_merge_models'];
|
||||
export type MergeModelConfig = S['Body_merge'];
|
||||
export type ImportModelConfig = S['Body_import_model'];
|
||||
|
||||
// Graphs
|
||||
|
@ -3,6 +3,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import queryString from 'query-string';
|
||||
import { buildV1Url } from 'services/api';
|
||||
|
||||
import type { ImageCache, ImageDTO, ListImagesArgs } from './types';
|
||||
|
||||
@ -79,4 +80,4 @@ export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelector
|
||||
|
||||
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
|
||||
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
|
||||
`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`;
|
||||
buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`);
|
||||
|
@ -76,9 +76,9 @@ export default defineConfig(({ mode }) => {
|
||||
changeOrigin: true,
|
||||
},
|
||||
// proxy nodes api
|
||||
'/api/v1': {
|
||||
target: 'http://127.0.0.1:9090/api/v1',
|
||||
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
|
||||
'/api/': {
|
||||
target: 'http://127.0.0.1:9090/api/',
|
||||
rewrite: (path) => path.replace(/^\/api/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user