feat(ui): refactor metadata handling

Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models.

- Create helpers to fetch a model outside react (e.g. not in a hook)
- Created helpers to parse model metadata
- Renamed a lot of types that were confusing and/or had naming collisions
This commit is contained in:
psychedelicious 2024-02-22 17:33:20 +11:00
parent 79b16596b5
commit 3ed2963f43
16 changed files with 443 additions and 486 deletions

View File

@ -8,4 +8,26 @@ declare global {
}
}
/**
* Raised when the redux store is unable to be retrieved.
*/
export class ReduxStoreNotInitialized extends Error {
/**
* Create ReduxStoreNotInitialized
* @param {String} message
*/
constructor(message = 'Redux store not initialized') {
super(message);
this.name = this.constructor.name;
}
}
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
export const getStore = () => {
const store = $store.get();
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@ -8,7 +8,7 @@ import { useControlAdapterType } from 'features/controlAdapters/hooks/useControl
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type ParamControlAdapterModelProps = {
id: string;
@ -24,7 +24,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const _onChange = useCallback(
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!model) {
return;
}

View File

@ -7,7 +7,7 @@ import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionConfig } from 'services/api/types';
import type { TextualInversionModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
@ -17,7 +17,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionConfig): boolean => {
(embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionConfig | null) => {
(embedding: TextualInversionModelConfig | null) => {
if (!embedding) {
return;
}

View File

@ -8,7 +8,7 @@ import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
import type { LoRAModelConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
@ -19,7 +19,7 @@ const LoRASelect = () => {
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = (lora: LoRAConfig): boolean => {
const getIsDisabled = (lora: LoRAModelConfig): boolean => {
const isCompatible = currentBaseModel === lora.base;
const isAdded = Boolean(addedLoRAs[lora.key]);
const hasMainModel = Boolean(currentBaseModel);
@ -27,7 +27,7 @@ const LoRASelect = () => {
};
const _onChange = useCallback(
(lora: LoRAConfig | null) => {
(lora: LoRAModelConfig | null) => {
if (!lora) {
return;
}

View File

@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAConfig } from 'services/api/types';
import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = ParameterLoRAModel & {
weight: number;
@ -28,13 +28,12 @@ export const loraSlice = createSlice({
name: 'lora',
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAConfig>) => {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const { key, base } = action.payload;
state.loras[key] = { key, base, ...defaultLoRAConfig };
},
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => {
const { key, base, weight } = action.payload;
state.loras[key] = { key, base, weight, isEnabled: true };
loraRecalled: (state, action: PayloadAction<LoRA>) => {
state.loras[action.payload.key] = action.payload;
},
loraRemoved: (state, action: PayloadAction<string>) => {
const key = action.payload;

View File

@ -8,12 +8,11 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
import type { LoRAModelConfig } from 'services/api/types';
type LoRAModelEditProps = {
model: LoRAConfig;
model: LoRAModelConfig;
};
const LoRAModelEdit = (props: LoRAModelEditProps) => {
@ -30,7 +29,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
control,
formState: { errors },
reset,
} = useForm<LoRAConfig>({
} = useForm<LoRAModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
@ -42,7 +41,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>(
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
@ -53,7 +52,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as LoRAConfig, { keepDefaultValues: true });
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
@ -106,7 +105,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<LoRAConfig> control={control} name="base_model" />
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -6,7 +6,7 @@ import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTempla
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import type { ControlNetConfig } from 'services/api/types';
import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const { data, isLoading } = useGetControlNetModelsQuery();
const _onChange = useCallback(
(value: ControlNetConfig | null) => {
(value: ControlNetModelConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import type { IPAdapterConfig } from 'services/api/types';
import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const IPAdapterModelFieldInputComponent = (
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const _onChange = useCallback(
(value: IPAdapterConfig | null) => {
(value: IPAdapterModelConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'f
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const _onChange = useCallback(
(value: LoRAConfig | null) => {
(value: LoRAModelConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTempla
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import type { T2IAdapterConfig } from 'services/api/types';
import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -19,7 +19,7 @@ const T2IAdapterModelFieldInputComponent = (
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const _onChange = useCallback(
(value: T2IAdapterConfig | null) => {
(value: T2IAdapterModelConfig | null) => {
if (!value) {
return;
}

View File

@ -7,7 +7,7 @@ import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'fea
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/types';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery();
const _onChange = useCallback(
(value: VAEConfig | null) => {
(value: VAEModelConfig | null) => {
if (!value) {
return;
}

View File

@ -8,7 +8,7 @@ import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/types';
import type { VAEModelConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
const { model, vae } = generation;
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery();
const getIsDisabled = useCallback(
(vae: VAEConfig): boolean => {
(vae: VAEModelConfig): boolean => {
const isCompatible = model?.base === vae.base;
const hasMainModel = Boolean(model?.base);
return !hasMainModel || !isCompatible;
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
[model?.base]
);
const _onChange = useCallback(
(vae: VAEConfig | null) => {
(vae: VAEModelConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
},
[dispatch]

View File

@ -1,17 +1,9 @@
import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
import type { ModelIdentifier } from 'features/nodes/types/common';
import { isModelIdentifier } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
@ -56,6 +48,14 @@ import {
isParameterStrength,
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
prepareLoRAMetadataItem,
prepareMainModelMetadataItem,
prepareT2IAdapterMetadataItem,
prepareVAEMetadataItem,
} from 'features/parameters/util/modelMetadataHelpers';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
@ -70,23 +70,7 @@ import {
import { isNil } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
loraModelsAdapterSelectors,
mainModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetVaeModelsQuery,
vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models';
import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@ -140,9 +124,6 @@ export const useRecallParameters = () => {
[t, toaster]
);
/**
* Recall both prompts with toast
*/
const recallBothPrompts = useCallback(
(positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => {
if (
@ -175,9 +156,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall positive prompt with toast
*/
const recallPositivePrompt = useCallback(
(positivePrompt: unknown) => {
if (!isParameterPositivePrompt(positivePrompt)) {
@ -190,9 +168,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall negative prompt with toast
*/
const recallNegativePrompt = useCallback(
(negativePrompt: unknown) => {
if (!isParameterNegativePrompt(negativePrompt)) {
@ -205,9 +180,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall SDXL Positive Style Prompt with toast
*/
const recallSDXLPositiveStylePrompt = useCallback(
(positiveStylePrompt: unknown) => {
if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
@ -220,9 +192,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall SDXL Negative Style Prompt with toast
*/
const recallSDXLNegativeStylePrompt = useCallback(
(negativeStylePrompt: unknown) => {
if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) {
@ -235,9 +204,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall seed with toast
*/
const recallSeed = useCallback(
(seed: unknown) => {
if (!isParameterSeed(seed)) {
@ -250,9 +216,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall CFG scale with toast
*/
const recallCfgScale = useCallback(
(cfgScale: unknown) => {
if (!isParameterCFGScale(cfgScale)) {
@ -265,9 +228,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall CFG rescale multiplier with toast
*/
const recallCfgRescaleMultiplier = useCallback(
(cfgRescaleMultiplier: unknown) => {
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
@ -280,9 +240,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall scheduler with toast
*/
const recallScheduler = useCallback(
(scheduler: unknown) => {
if (!isParameterScheduler(scheduler)) {
@ -295,9 +252,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall steps with toast
*/
const recallSteps = useCallback(
(steps: unknown) => {
if (!isParameterSteps(steps)) {
@ -310,9 +264,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall width with toast
*/
const recallWidth = useCallback(
(width: unknown) => {
if (!isParameterWidth(width)) {
@ -325,9 +276,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall height with toast
*/
const recallHeight = useCallback(
(height: unknown) => {
if (!isParameterHeight(height)) {
@ -340,9 +288,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall width and height with toast
*/
const recallWidthAndHeight = useCallback(
(width: unknown, height: unknown) => {
if (!isParameterWidth(width)) {
@ -360,9 +305,6 @@ export const useRecallParameters = () => {
[dispatch, allParameterSetToast, allParameterNotSetToast]
);
/**
* Recall strength with toast
*/
const recallStrength = useCallback(
(strength: unknown) => {
if (!isParameterStrength(strength)) {
@ -375,9 +317,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution enabled with toast
*/
const recallHrfEnabled = useCallback(
(hrfEnabled: unknown) => {
if (!isParameterHRFEnabled(hrfEnabled)) {
@ -390,9 +329,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution strength with toast
*/
const recallHrfStrength = useCallback(
(hrfStrength: unknown) => {
if (!isParameterStrength(hrfStrength)) {
@ -405,9 +341,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution method with toast
*/
const recallHrfMethod = useCallback(
(hrfMethod: unknown) => {
if (!isParameterHRFMethod(hrfMethod)) {
@ -420,358 +353,95 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const prepareMainModelMetadataItem = useCallback(
(model: ModelIdentifier) => {
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
if (!matchingModel) {
return { model: null, error: 'Model is not installed' };
}
return { model: matchingModel, error: null };
},
[mainModels]
);
/**
* Recall model with toast
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isModelIdentifier(model)) {
parameterNotSetToast();
async (modelMetadataItem: unknown) => {
try {
const model = await prepareMainModelMetadataItem(modelMetadataItem);
dispatch(modelSelected(model));
parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return;
}
const result = prepareMainModelMetadataItem(model);
if (!result.model) {
parameterNotSetToast(result.error);
return;
}
dispatch(modelSelected(result.model));
parameterSetToast();
},
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
[dispatch, parameterSetToast, parameterNotSetToast]
);
const { data: vaeModels } = useGetVaeModelsQuery();
const prepareVAEMetadataItem = useCallback(
(vae: ModelIdentifier, newModel?: ParameterModel) => {
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
if (!matchingModel) {
return { vae: null, error: 'VAE model is not installed' };
}
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
vae: null,
error: 'VAE incompatible with currently-selected model',
};
}
return { vae: matchingModel, error: null };
},
[model, vaeModels]
);
/**
* Recall vae model
*/
const recallVaeModel = useCallback(
(vae: unknown) => {
if (!isModelIdentifier(vae) && !isNil(vae)) {
parameterNotSetToast();
return;
}
if (isNil(vae)) {
async (vaeMetadataItem: unknown) => {
if (isNil(vaeMetadataItem)) {
dispatch(vaeSelected(null));
parameterSetToast();
return;
}
const result = prepareVAEMetadataItem(vae);
if (!result.vae) {
parameterNotSetToast(result.error);
try {
const vae = await prepareVAEMetadataItem(vaeMetadataItem);
dispatch(vaeSelected(vae));
parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return;
}
dispatch(vaeSelected(result.vae));
parameterSetToast();
},
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall LoRA with toast
*/
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' };
}
const { lora } = loraMetadataItem;
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
if (!matchingLoRA) {
return { lora: null, error: 'LoRA model is not installed' };
}
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
lora: null,
error: 'LoRA incompatible with currently-selected model',
};
}
return { lora: matchingLoRA, error: null };
},
[loraModels, model]
[dispatch, parameterSetToast, parameterNotSetToast]
);
const recallLoRA = useCallback(
(loraMetadataItem: LoRAMetadataItem) => {
const result = prepareLoRAMetadataItem(loraMetadataItem);
if (!result.lora) {
parameterNotSetToast(result.error);
async (loraMetadataItem: LoRAMetadataItem) => {
try {
const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base);
dispatch(loraRecalled(lora));
parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return;
}
dispatch(loraRecalled({ ...result.lora, weight: loraMetadataItem.weight }));
parameterSetToast();
},
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall ControlNet with toast
*/
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const { image, control_model, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
controlnetMetadataItem;
const matchingControlNetModel = controlNetModels
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
: undefined;
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
controlnet: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlnet: ControlNetConfig = {
type: 'controlnet',
isEnabled: true,
model: matchingControlNetModel,
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return { controlnet, error: null };
},
[controlNetModels, model]
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
const recallControlNet = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
if (!result.controlnet) {
parameterNotSetToast(result.error);
async (controlnetMetadataItem: ControlNetMetadataItem) => {
try {
const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base);
dispatch(controlAdapterRecalled(controlNetConfig));
parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return;
}
dispatch(controlAdapterRecalled(result.controlnet));
parameterSetToast();
},
[prepareControlNetMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall T2I Adapter with toast
*/
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const { image, t2i_adapter_model, weight, begin_step_percent, end_step_percent, resize_mode } =
t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
: undefined;
if (!matchingT2IAdapterModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
t2iAdapter: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: matchingT2IAdapterModel,
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return { t2iAdapter, error: null };
},
[model, t2iAdapterModels]
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
const recallT2IAdapter = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem);
if (!result.t2iAdapter) {
parameterNotSetToast(result.error);
async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
try {
const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base);
dispatch(controlAdapterRecalled(t2iAdapterConfig));
parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return;
}
dispatch(controlAdapterRecalled(result.t2iAdapter));
parameterSetToast();
},
[prepareT2IAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall IP Adapter with toast
*/
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
}
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
: undefined;
if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
}
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
ipAdapter: null,
error: 'IP Adapter incompatible with currently-selected model',
};
}
const ipAdapter: IPAdapterConfig = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
controlImage: image?.image_name ?? null,
model: matchingIPAdapterModel,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};
return { ipAdapter, error: null };
},
[ipAdapterModels, model]
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
const recallIPAdapter = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
if (!result.ipAdapter) {
parameterNotSetToast(result.error);
async (ipAdapterMetadataItem: IPAdapterMetadataItem) => {
try {
const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base);
dispatch(controlAdapterRecalled(ipAdapterConfig));
parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return;
}
dispatch(controlAdapterRecalled(result.ipAdapter));
parameterSetToast();
},
[prepareIPAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
/*
* Sets image as initial image with toast
*/
const sendToImageToImage = useCallback(
(image: ImageDTO) => {
dispatch(initialImageSelected(image));
@ -780,7 +450,7 @@ export const useRecallParameters = () => {
);
const recallAllParameters = useCallback(
(metadata: CoreMetadata | undefined) => {
async (metadata: CoreMetadata | undefined) => {
if (!metadata) {
allParameterNotSetToast();
return;
@ -820,10 +490,12 @@ export const useRecallParameters = () => {
let newModel: ParameterModel | undefined = undefined;
if (isModelIdentifier(model)) {
const result = prepareMainModelMetadataItem(model);
if (result.model) {
dispatch(modelSelected(result.model));
newModel = result.model;
try {
const _model = await prepareMainModelMetadataItem(model);
dispatch(modelSelected(_model));
newModel = _model;
} catch {
return;
}
}
@ -850,9 +522,11 @@ export const useRecallParameters = () => {
if (isNil(vae)) {
dispatch(vaeSelected(null));
} else {
const result = prepareVAEMetadataItem(vae, newModel);
if (result.vae) {
dispatch(vaeSelected(result.vae));
try {
const _vae = await prepareVAEMetadataItem(vae, newModel?.base);
dispatch(vaeSelected(_vae));
} catch {
return;
}
}
}
@ -926,48 +600,46 @@ export const useRecallParameters = () => {
}
dispatch(lorasCleared());
loras?.forEach((lora) => {
const result = prepareLoRAMetadataItem(lora, newModel);
if (result.lora) {
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
loras?.forEach(async (loraMetadataItem) => {
try {
const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base);
dispatch(loraRecalled(lora));
} catch {
return;
}
});
dispatch(controlAdaptersReset());
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet, newModel);
if (result.controlnet) {
dispatch(controlAdapterRecalled(result.controlnet));
controlnets?.forEach(async (controlNetMetadataItem) => {
try {
const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base);
dispatch(controlAdapterRecalled(controlNet));
} catch {
return;
}
});
ipAdapters?.forEach((ipAdapter) => {
const result = prepareIPAdapterMetadataItem(ipAdapter, newModel);
if (result.ipAdapter) {
dispatch(controlAdapterRecalled(result.ipAdapter));
ipAdapters?.forEach(async (ipAdapterMetadataItem) => {
try {
const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base);
dispatch(controlAdapterRecalled(ipAdapter));
} catch {
return;
}
});
t2iAdapters?.forEach((t2iAdapter) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel);
if (result.t2iAdapter) {
dispatch(controlAdapterRecalled(result.t2iAdapter));
t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => {
try {
const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base);
dispatch(controlAdapterRecalled(t2iAdapter));
} catch {
return;
}
});
allParameterSetToast();
},
[
dispatch,
allParameterSetToast,
allParameterNotSetToast,
prepareMainModelMetadataItem,
prepareVAEMetadataItem,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
prepareT2IAdapterMetadataItem,
]
[dispatch, allParameterSetToast, allParameterNotSetToast]
);
return {

View File

@ -0,0 +1,113 @@
import { getStore } from 'app/store/nanostores/store';
import { isModelIdentifier } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isT2IAdapterModelConfig,
isTextualInversionModelConfig,
isVAEModelConfig,
} from 'services/api/types';
/**
* Raised when a model config is unable to be fetched.
*/
export class ModelConfigNotFoundError extends Error {
/**
* Create ModelConfigNotFoundError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Raised when a fetched model config is of an unexpected type.
*/
export class InvalidModelConfigError extends Error {
/**
* Create InvalidModelConfigError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
req.unsubscribe();
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
}
};
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
key: string,
typeGuard: (config: AnyModelConfig) => config is T
) => {
const modelConfig = await fetchModelConfig(key);
if (!typeGuard(modelConfig)) {
throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`);
}
return modelConfig;
};
export const fetchMainModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
};
export const fetchRefinerModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
};
export const fetchVAEModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
};
export const fetchLoRAModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
};
export const fetchControlNetModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
};
export const fetchIPAdapterModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
};
export const fetchT2IAdapterModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
};
export const fetchTextualInversionModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
};
export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => {
return sourceBase === targetBase;
};
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
if (targetBase && !isBaseCompatible(sourceBase, targetBase)) {
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
}
};
export const getModelKey = (modelIdentifier: unknown, message?: string): string => {
if (!isModelIdentifier(modelIdentifier)) {
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
}
return modelIdentifier.key;
};

View File

@ -0,0 +1,150 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
IPAdapterMetadataItem,
LoRAMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/metadata';
import {
fetchControlNetModel,
fetchIPAdapterModel,
fetchLoRAModel,
fetchMainModel,
fetchRefinerModel,
fetchT2IAdapterModel,
fetchVAEModel,
getModelKey,
raiseIfBaseIncompatible,
} from 'features/parameters/util/modelFetchingHelpers';
import type { BaseModelType } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
export const prepareMainModelMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(model);
const mainModel = await fetchMainModel(key);
return zModelIdentifierWithBase.parse(mainModel);
};
export const prepareRefinerMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(model);
const refinerModel = await fetchRefinerModel(key);
return zModelIdentifierWithBase.parse(refinerModel);
};
export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(vae);
const vaeModel = await fetchVAEModel(key);
raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
return zModelIdentifierWithBase.parse(vaeModel);
};
export const prepareLoRAMetadataItem = async (
loraMetadataItem: LoRAMetadataItem,
base?: BaseModelType
): Promise<LoRA> => {
const key = getModelKey(loraMetadataItem.lora);
const loraModel = await fetchLoRAModel(key);
raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
};
export const prepareControlNetMetadataItem = async (
controlnetMetadataItem: ControlNetMetadataItem,
base?: BaseModelType
): Promise<ControlNetConfig> => {
const key = getModelKey(controlnetMetadataItem.control_model);
const controlNetModel = await fetchControlNetModel(key);
raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
controlnetMetadataItem;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlnet: ControlNetConfig = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return controlnet;
};
export const prepareT2IAdapterMetadataItem = async (
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
base?: BaseModelType
): Promise<T2IAdapterConfig> => {
const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
const t2iAdapterModel = await fetchT2IAdapterModel(key);
raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return t2iAdapter;
};
export const prepareIPAdapterMetadataItem = async (
ipAdapterMetadataItem: IPAdapterMetadataItem,
base?: BaseModelType
): Promise<IPAdapterConfig> => {
const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
const ipAdapterModel = await fetchIPAdapterModel(key);
raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const ipAdapter: IPAdapterConfig = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
controlImage: image?.image_name ?? null,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};
return ipAdapter;
};

View File

@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
ControlNetConfig,
ControlNetModelConfig,
ImportModelConfig,
IPAdapterConfig,
LoRAConfig,
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
ModelType,
T2IAdapterConfig,
TextualInversionConfig,
VAEConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import type { ApiTagDescription, tagTypes } from '..';
@ -30,7 +30,7 @@ type UpdateMainModelArg = {
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAConfig;
body: LoRAModelConfig;
};
type UpdateMainModelResponse =
@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap
undefined,
getSelectorsOptions
);
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
@ -162,6 +162,8 @@ const buildTransformResponse =
*/
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
}),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
}),
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
}),
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
}),
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {