feat(ui): refactor metadata handling

Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models.

- Create helpers to fetch a model outside react (e.g. not in a hook)
- Created helpers to parse model metadata
- Renamed a lot of types that were confusing and/or had naming collisions
This commit is contained in:
psychedelicious 2024-02-22 17:33:20 +11:00
parent 79b16596b5
commit 3ed2963f43
16 changed files with 443 additions and 486 deletions

View File

@ -8,4 +8,26 @@ declare global {
} }
} }
/**
* Raised when the redux store is unable to be retrieved.
*/
export class ReduxStoreNotInitialized extends Error {
/**
* Create ReduxStoreNotInitialized
* @param {String} message
*/
constructor(message = 'Redux store not initialized') {
super(message);
this.name = this.constructor.name;
}
}
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>(); export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
export const getStore = () => {
const store = $store.get();
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@ -8,7 +8,7 @@ import { useControlAdapterType } from 'features/controlAdapters/hooks/useControl
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type ParamControlAdapterModelProps = { type ParamControlAdapterModelProps = {
id: string; id: string;
@ -24,7 +24,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const _onChange = useCallback( const _onChange = useCallback(
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { (model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!model) { if (!model) {
return; return;
} }

View File

@ -7,7 +7,7 @@ import { t } from 'i18next';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionConfig } from 'services/api/types'; import type { TextualInversionModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding'); const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
@ -17,7 +17,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback( const getIsDisabled = useCallback(
(embedding: TextualInversionConfig): boolean => { (embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base; const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel); const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible; return !hasMainModel || !isCompatible;
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
const { data, isLoading } = useGetTextualInversionModelsQuery(); const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback( const _onChange = useCallback(
(embedding: TextualInversionConfig | null) => { (embedding: TextualInversionModelConfig | null) => {
if (!embedding) { if (!embedding) {
return; return;
} }

View File

@ -8,7 +8,7 @@ import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
@ -19,7 +19,7 @@ const LoRASelect = () => {
const addedLoRAs = useAppSelector(selectAddedLoRAs); const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = (lora: LoRAConfig): boolean => { const getIsDisabled = (lora: LoRAModelConfig): boolean => {
const isCompatible = currentBaseModel === lora.base; const isCompatible = currentBaseModel === lora.base;
const isAdded = Boolean(addedLoRAs[lora.key]); const isAdded = Boolean(addedLoRAs[lora.key]);
const hasMainModel = Boolean(currentBaseModel); const hasMainModel = Boolean(currentBaseModel);
@ -27,7 +27,7 @@ const LoRASelect = () => {
}; };
const _onChange = useCallback( const _onChange = useCallback(
(lora: LoRAConfig | null) => { (lora: LoRAModelConfig | null) => {
if (!lora) { if (!lora) {
return; return;
} }

View File

@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = ParameterLoRAModel & { export type LoRA = ParameterLoRAModel & {
weight: number; weight: number;
@ -28,13 +28,12 @@ export const loraSlice = createSlice({
name: 'lora', name: 'lora',
initialState: initialLoraState, initialState: initialLoraState,
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAConfig>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const { key, base } = action.payload; const { key, base } = action.payload;
state.loras[key] = { key, base, ...defaultLoRAConfig }; state.loras[key] = { key, base, ...defaultLoRAConfig };
}, },
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => { loraRecalled: (state, action: PayloadAction<LoRA>) => {
const { key, base, weight } = action.payload; state.loras[action.payload.key] = action.payload;
state.loras[key] = { key, base, weight, isEnabled: true };
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const key = action.payload; const key = action.payload;

View File

@ -8,12 +8,11 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
type LoRAModelEditProps = { type LoRAModelEditProps = {
model: LoRAConfig; model: LoRAModelConfig;
}; };
const LoRAModelEdit = (props: LoRAModelEditProps) => { const LoRAModelEdit = (props: LoRAModelEditProps) => {
@ -30,7 +29,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
control, control,
formState: { errors }, formState: { errors },
reset, reset,
} = useForm<LoRAConfig>({ } = useForm<LoRAModelConfig>({
defaultValues: { defaultValues: {
model_name: model.model_name ? model.model_name : '', model_name: model.model_name ? model.model_name : '',
base_model: model.base_model, base_model: model.base_model,
@ -42,7 +41,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
mode: 'onChange', mode: 'onChange',
}); });
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>( const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
(values) => { (values) => {
const responseBody = { const responseBody = {
base_model: model.base_model, base_model: model.base_model,
@ -53,7 +52,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
updateLoRAModel(responseBody) updateLoRAModel(responseBody)
.unwrap() .unwrap()
.then((payload) => { .then((payload) => {
reset(payload as LoRAConfig, { keepDefaultValues: true }); reset(payload as LoRAModelConfig, { keepDefaultValues: true });
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
@ -106,7 +105,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
<FormLabel>{t('modelManager.description')}</FormLabel> <FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} /> <Input {...register('description')} />
</FormControl> </FormControl>
<BaseModelSelect<LoRAConfig> control={control} name="base_model" /> <BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
<FormControl isInvalid={Boolean(errors.path)}> <FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel> <FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -6,7 +6,7 @@ import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTempla
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import type { ControlNetConfig } from 'services/api/types'; import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const { data, isLoading } = useGetControlNetModelsQuery(); const { data, isLoading } = useGetControlNetModelsQuery();
const _onChange = useCallback( const _onChange = useCallback(
(value: ControlNetConfig | null) => { (value: ControlNetModelConfig | null) => {
if (!value) { if (!value) {
return; return;
} }

View File

@ -6,7 +6,7 @@ import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import type { IPAdapterConfig } from 'services/api/types'; import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const IPAdapterModelFieldInputComponent = (
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const _onChange = useCallback( const _onChange = useCallback(
(value: IPAdapterConfig | null) => { (value: IPAdapterModelConfig | null) => {
if (!value) { if (!value) {
return; return;
} }

View File

@ -6,7 +6,7 @@ import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'f
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery(); const { data, isLoading } = useGetLoRAModelsQuery();
const _onChange = useCallback( const _onChange = useCallback(
(value: LoRAConfig | null) => { (value: LoRAModelConfig | null) => {
if (!value) { if (!value) {
return; return;
} }

View File

@ -6,7 +6,7 @@ import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTempla
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import type { T2IAdapterConfig } from 'services/api/types'; import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -19,7 +19,7 @@ const T2IAdapterModelFieldInputComponent = (
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const _onChange = useCallback( const _onChange = useCallback(
(value: T2IAdapterConfig | null) => { (value: T2IAdapterModelConfig | null) => {
if (!value) { if (!value) {
return; return;
} }

View File

@ -7,7 +7,7 @@ import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'fea
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/types'; import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery(); const { data, isLoading } = useGetVaeModelsQuery();
const _onChange = useCallback( const _onChange = useCallback(
(value: VAEConfig | null) => { (value: VAEModelConfig | null) => {
if (!value) { if (!value) {
return; return;
} }

View File

@ -8,7 +8,7 @@ import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/types'; import type { VAEModelConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
const { model, vae } = generation; const { model, vae } = generation;
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
const { model, vae } = useAppSelector(selector); const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery(); const { data, isLoading } = useGetVaeModelsQuery();
const getIsDisabled = useCallback( const getIsDisabled = useCallback(
(vae: VAEConfig): boolean => { (vae: VAEModelConfig): boolean => {
const isCompatible = model?.base === vae.base; const isCompatible = model?.base === vae.base;
const hasMainModel = Boolean(model?.base); const hasMainModel = Boolean(model?.base);
return !hasMainModel || !isCompatible; return !hasMainModel || !isCompatible;
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
[model?.base] [model?.base]
); );
const _onChange = useCallback( const _onChange = useCallback(
(vae: VAEConfig | null) => { (vae: VAEModelConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null)); dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
}, },
[dispatch] [dispatch]

View File

@ -1,17 +1,9 @@
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
import type { ModelIdentifier } from 'features/nodes/types/common';
import { isModelIdentifier } from 'features/nodes/types/common'; import { isModelIdentifier } from 'features/nodes/types/common';
import type { import type {
ControlNetMetadataItem, ControlNetMetadataItem,
@ -56,6 +48,14 @@ import {
isParameterStrength, isParameterStrength,
isParameterWidth, isParameterWidth,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import {
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
prepareLoRAMetadataItem,
prepareMainModelMetadataItem,
prepareT2IAdapterMetadataItem,
prepareVAEMetadataItem,
} from 'features/parameters/util/modelMetadataHelpers';
import { import {
refinerModelChanged, refinerModelChanged,
setNegativeStylePromptSDXL, setNegativeStylePromptSDXL,
@ -70,23 +70,7 @@ import {
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
loraModelsAdapterSelectors,
mainModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetVaeModelsQuery,
vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@ -140,9 +124,6 @@ export const useRecallParameters = () => {
[t, toaster] [t, toaster]
); );
/**
* Recall both prompts with toast
*/
const recallBothPrompts = useCallback( const recallBothPrompts = useCallback(
(positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => { (positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => {
if ( if (
@ -175,9 +156,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall positive prompt with toast
*/
const recallPositivePrompt = useCallback( const recallPositivePrompt = useCallback(
(positivePrompt: unknown) => { (positivePrompt: unknown) => {
if (!isParameterPositivePrompt(positivePrompt)) { if (!isParameterPositivePrompt(positivePrompt)) {
@ -190,9 +168,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall negative prompt with toast
*/
const recallNegativePrompt = useCallback( const recallNegativePrompt = useCallback(
(negativePrompt: unknown) => { (negativePrompt: unknown) => {
if (!isParameterNegativePrompt(negativePrompt)) { if (!isParameterNegativePrompt(negativePrompt)) {
@ -205,9 +180,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall SDXL Positive Style Prompt with toast
*/
const recallSDXLPositiveStylePrompt = useCallback( const recallSDXLPositiveStylePrompt = useCallback(
(positiveStylePrompt: unknown) => { (positiveStylePrompt: unknown) => {
if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
@ -220,9 +192,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall SDXL Negative Style Prompt with toast
*/
const recallSDXLNegativeStylePrompt = useCallback( const recallSDXLNegativeStylePrompt = useCallback(
(negativeStylePrompt: unknown) => { (negativeStylePrompt: unknown) => {
if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) {
@ -235,9 +204,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall seed with toast
*/
const recallSeed = useCallback( const recallSeed = useCallback(
(seed: unknown) => { (seed: unknown) => {
if (!isParameterSeed(seed)) { if (!isParameterSeed(seed)) {
@ -250,9 +216,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall CFG scale with toast
*/
const recallCfgScale = useCallback( const recallCfgScale = useCallback(
(cfgScale: unknown) => { (cfgScale: unknown) => {
if (!isParameterCFGScale(cfgScale)) { if (!isParameterCFGScale(cfgScale)) {
@ -265,9 +228,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall CFG rescale multiplier with toast
*/
const recallCfgRescaleMultiplier = useCallback( const recallCfgRescaleMultiplier = useCallback(
(cfgRescaleMultiplier: unknown) => { (cfgRescaleMultiplier: unknown) => {
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) { if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
@ -280,9 +240,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall scheduler with toast
*/
const recallScheduler = useCallback( const recallScheduler = useCallback(
(scheduler: unknown) => { (scheduler: unknown) => {
if (!isParameterScheduler(scheduler)) { if (!isParameterScheduler(scheduler)) {
@ -295,9 +252,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall steps with toast
*/
const recallSteps = useCallback( const recallSteps = useCallback(
(steps: unknown) => { (steps: unknown) => {
if (!isParameterSteps(steps)) { if (!isParameterSteps(steps)) {
@ -310,9 +264,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall width with toast
*/
const recallWidth = useCallback( const recallWidth = useCallback(
(width: unknown) => { (width: unknown) => {
if (!isParameterWidth(width)) { if (!isParameterWidth(width)) {
@ -325,9 +276,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall height with toast
*/
const recallHeight = useCallback( const recallHeight = useCallback(
(height: unknown) => { (height: unknown) => {
if (!isParameterHeight(height)) { if (!isParameterHeight(height)) {
@ -340,9 +288,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall width and height with toast
*/
const recallWidthAndHeight = useCallback( const recallWidthAndHeight = useCallback(
(width: unknown, height: unknown) => { (width: unknown, height: unknown) => {
if (!isParameterWidth(width)) { if (!isParameterWidth(width)) {
@ -360,9 +305,6 @@ export const useRecallParameters = () => {
[dispatch, allParameterSetToast, allParameterNotSetToast] [dispatch, allParameterSetToast, allParameterNotSetToast]
); );
/**
* Recall strength with toast
*/
const recallStrength = useCallback( const recallStrength = useCallback(
(strength: unknown) => { (strength: unknown) => {
if (!isParameterStrength(strength)) { if (!isParameterStrength(strength)) {
@ -375,9 +317,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall high resolution enabled with toast
*/
const recallHrfEnabled = useCallback( const recallHrfEnabled = useCallback(
(hrfEnabled: unknown) => { (hrfEnabled: unknown) => {
if (!isParameterHRFEnabled(hrfEnabled)) { if (!isParameterHRFEnabled(hrfEnabled)) {
@ -390,9 +329,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall high resolution strength with toast
*/
const recallHrfStrength = useCallback( const recallHrfStrength = useCallback(
(hrfStrength: unknown) => { (hrfStrength: unknown) => {
if (!isParameterStrength(hrfStrength)) { if (!isParameterStrength(hrfStrength)) {
@ -405,9 +341,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall high resolution method with toast
*/
const recallHrfMethod = useCallback( const recallHrfMethod = useCallback(
(hrfMethod: unknown) => { (hrfMethod: unknown) => {
if (!isParameterHRFMethod(hrfMethod)) { if (!isParameterHRFMethod(hrfMethod)) {
@ -420,358 +353,95 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const prepareMainModelMetadataItem = useCallback(
(model: ModelIdentifier) => {
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
if (!matchingModel) {
return { model: null, error: 'Model is not installed' };
}
return { model: matchingModel, error: null };
},
[mainModels]
);
/**
* Recall model with toast
*/
const recallModel = useCallback( const recallModel = useCallback(
(model: unknown) => { async (modelMetadataItem: unknown) => {
if (!isModelIdentifier(model)) { try {
parameterNotSetToast(); const model = await prepareMainModelMetadataItem(modelMetadataItem);
return; dispatch(modelSelected(model));
}
const result = prepareMainModelMetadataItem(model);
if (!result.model) {
parameterNotSetToast(result.error);
return;
}
dispatch(modelSelected(result.model));
parameterSetToast(); parameterSetToast();
}, } catch (e) {
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] parameterNotSetToast((e as unknown as Error).message);
);
const { data: vaeModels } = useGetVaeModelsQuery();
const prepareVAEMetadataItem = useCallback(
(vae: ModelIdentifier, newModel?: ParameterModel) => {
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
if (!matchingModel) {
return { vae: null, error: 'VAE model is not installed' };
}
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
vae: null,
error: 'VAE incompatible with currently-selected model',
};
}
return { vae: matchingModel, error: null };
},
[model, vaeModels]
);
/**
* Recall vae model
*/
const recallVaeModel = useCallback(
(vae: unknown) => {
if (!isModelIdentifier(vae) && !isNil(vae)) {
parameterNotSetToast();
return; return;
} }
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
if (isNil(vae)) { const recallVaeModel = useCallback(
async (vaeMetadataItem: unknown) => {
if (isNil(vaeMetadataItem)) {
dispatch(vaeSelected(null)); dispatch(vaeSelected(null));
parameterSetToast(); parameterSetToast();
return; return;
} }
try {
const result = prepareVAEMetadataItem(vae); const vae = await prepareVAEMetadataItem(vaeMetadataItem);
dispatch(vaeSelected(vae));
if (!result.vae) { parameterSetToast();
parameterNotSetToast(result.error); } catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return; return;
} }
dispatch(vaeSelected(result.vae));
parameterSetToast();
}, },
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall LoRA with toast
*/
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' };
}
const { lora } = loraMetadataItem;
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
if (!matchingLoRA) {
return { lora: null, error: 'LoRA model is not installed' };
}
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
lora: null,
error: 'LoRA incompatible with currently-selected model',
};
}
return { lora: matchingLoRA, error: null };
},
[loraModels, model]
); );
const recallLoRA = useCallback( const recallLoRA = useCallback(
(loraMetadataItem: LoRAMetadataItem) => { async (loraMetadataItem: LoRAMetadataItem) => {
const result = prepareLoRAMetadataItem(loraMetadataItem); try {
const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base);
if (!result.lora) { dispatch(loraRecalled(lora));
parameterNotSetToast(result.error); parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return; return;
} }
dispatch(loraRecalled({ ...result.lora, weight: loraMetadataItem.weight }));
parameterSetToast();
}, },
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall ControlNet with toast
*/
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const { image, control_model, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
controlnetMetadataItem;
const matchingControlNetModel = controlNetModels
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
: undefined;
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
controlnet: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlnet: ControlNetConfig = {
type: 'controlnet',
isEnabled: true,
model: matchingControlNetModel,
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return { controlnet, error: null };
},
[controlNetModels, model]
); );
const recallControlNet = useCallback( const recallControlNet = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => { async (controlnetMetadataItem: ControlNetMetadataItem) => {
const result = prepareControlNetMetadataItem(controlnetMetadataItem); try {
const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base);
if (!result.controlnet) { dispatch(controlAdapterRecalled(controlNetConfig));
parameterNotSetToast(result.error); parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return; return;
} }
dispatch(controlAdapterRecalled(result.controlnet));
parameterSetToast();
}, },
[prepareControlNetMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall T2I Adapter with toast
*/
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const { image, t2i_adapter_model, weight, begin_step_percent, end_step_percent, resize_mode } =
t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
: undefined;
if (!matchingT2IAdapterModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
t2iAdapter: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: matchingT2IAdapterModel,
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return { t2iAdapter, error: null };
},
[model, t2iAdapterModels]
); );
const recallT2IAdapter = useCallback( const recallT2IAdapter = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem); try {
const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base);
if (!result.t2iAdapter) { dispatch(controlAdapterRecalled(t2iAdapterConfig));
parameterNotSetToast(result.error); parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return; return;
} }
dispatch(controlAdapterRecalled(result.t2iAdapter));
parameterSetToast();
}, },
[prepareT2IAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall IP Adapter with toast
*/
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
}
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
: undefined;
if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
}
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
ipAdapter: null,
error: 'IP Adapter incompatible with currently-selected model',
};
}
const ipAdapter: IPAdapterConfig = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
controlImage: image?.image_name ?? null,
model: matchingIPAdapterModel,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};
return { ipAdapter, error: null };
},
[ipAdapterModels, model]
); );
const recallIPAdapter = useCallback( const recallIPAdapter = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => { async (ipAdapterMetadataItem: IPAdapterMetadataItem) => {
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem); try {
const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base);
if (!result.ipAdapter) { dispatch(controlAdapterRecalled(ipAdapterConfig));
parameterNotSetToast(result.error); parameterSetToast();
} catch (e) {
parameterNotSetToast((e as unknown as Error).message);
return; return;
} }
dispatch(controlAdapterRecalled(result.ipAdapter));
parameterSetToast();
}, },
[prepareIPAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
); );
/*
* Sets image as initial image with toast
*/
const sendToImageToImage = useCallback( const sendToImageToImage = useCallback(
(image: ImageDTO) => { (image: ImageDTO) => {
dispatch(initialImageSelected(image)); dispatch(initialImageSelected(image));
@ -780,7 +450,7 @@ export const useRecallParameters = () => {
); );
const recallAllParameters = useCallback( const recallAllParameters = useCallback(
(metadata: CoreMetadata | undefined) => { async (metadata: CoreMetadata | undefined) => {
if (!metadata) { if (!metadata) {
allParameterNotSetToast(); allParameterNotSetToast();
return; return;
@ -820,10 +490,12 @@ export const useRecallParameters = () => {
let newModel: ParameterModel | undefined = undefined; let newModel: ParameterModel | undefined = undefined;
if (isModelIdentifier(model)) { if (isModelIdentifier(model)) {
const result = prepareMainModelMetadataItem(model); try {
if (result.model) { const _model = await prepareMainModelMetadataItem(model);
dispatch(modelSelected(result.model)); dispatch(modelSelected(_model));
newModel = result.model; newModel = _model;
} catch {
return;
} }
} }
@ -850,9 +522,11 @@ export const useRecallParameters = () => {
if (isNil(vae)) { if (isNil(vae)) {
dispatch(vaeSelected(null)); dispatch(vaeSelected(null));
} else { } else {
const result = prepareVAEMetadataItem(vae, newModel); try {
if (result.vae) { const _vae = await prepareVAEMetadataItem(vae, newModel?.base);
dispatch(vaeSelected(result.vae)); dispatch(vaeSelected(_vae));
} catch {
return;
} }
} }
} }
@ -926,48 +600,46 @@ export const useRecallParameters = () => {
} }
dispatch(lorasCleared()); dispatch(lorasCleared());
loras?.forEach((lora) => { loras?.forEach(async (loraMetadataItem) => {
const result = prepareLoRAMetadataItem(lora, newModel); try {
if (result.lora) { const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base);
dispatch(loraRecalled({ ...result.lora, weight: lora.weight })); dispatch(loraRecalled(lora));
} catch {
return;
} }
}); });
dispatch(controlAdaptersReset()); dispatch(controlAdaptersReset());
controlnets?.forEach((controlnet) => { controlnets?.forEach(async (controlNetMetadataItem) => {
const result = prepareControlNetMetadataItem(controlnet, newModel); try {
if (result.controlnet) { const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base);
dispatch(controlAdapterRecalled(result.controlnet)); dispatch(controlAdapterRecalled(controlNet));
} catch {
return;
} }
}); });
ipAdapters?.forEach((ipAdapter) => { ipAdapters?.forEach(async (ipAdapterMetadataItem) => {
const result = prepareIPAdapterMetadataItem(ipAdapter, newModel); try {
if (result.ipAdapter) { const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base);
dispatch(controlAdapterRecalled(result.ipAdapter)); dispatch(controlAdapterRecalled(ipAdapter));
} catch {
return;
} }
}); });
t2iAdapters?.forEach((t2iAdapter) => { t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel); try {
if (result.t2iAdapter) { const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base);
dispatch(controlAdapterRecalled(result.t2iAdapter)); dispatch(controlAdapterRecalled(t2iAdapter));
} catch {
return;
} }
}); });
allParameterSetToast(); allParameterSetToast();
}, },
[ [dispatch, allParameterSetToast, allParameterNotSetToast]
dispatch,
allParameterSetToast,
allParameterNotSetToast,
prepareMainModelMetadataItem,
prepareVAEMetadataItem,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
prepareT2IAdapterMetadataItem,
]
); );
return { return {

View File

@ -0,0 +1,113 @@
import { getStore } from 'app/store/nanostores/store';
import { isModelIdentifier } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isT2IAdapterModelConfig,
isTextualInversionModelConfig,
isVAEModelConfig,
} from 'services/api/types';
/**
* Raised when a model config is unable to be fetched.
*/
export class ModelConfigNotFoundError extends Error {
/**
* Create ModelConfigNotFoundError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Raised when a fetched model config is of an unexpected type.
*/
export class InvalidModelConfigError extends Error {
/**
* Create InvalidModelConfigError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
req.unsubscribe();
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
}
};
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
key: string,
typeGuard: (config: AnyModelConfig) => config is T
) => {
const modelConfig = await fetchModelConfig(key);
if (!typeGuard(modelConfig)) {
throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`);
}
return modelConfig;
};
export const fetchMainModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
};
export const fetchRefinerModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
};
export const fetchVAEModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
};
export const fetchLoRAModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
};
export const fetchControlNetModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
};
export const fetchIPAdapterModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
};
export const fetchT2IAdapterModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
};
export const fetchTextualInversionModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
};
export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => {
return sourceBase === targetBase;
};
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
if (targetBase && !isBaseCompatible(sourceBase, targetBase)) {
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
}
};
export const getModelKey = (modelIdentifier: unknown, message?: string): string => {
if (!isModelIdentifier(modelIdentifier)) {
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
}
return modelIdentifier.key;
};

View File

@ -0,0 +1,150 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
IPAdapterMetadataItem,
LoRAMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/metadata';
import {
fetchControlNetModel,
fetchIPAdapterModel,
fetchLoRAModel,
fetchMainModel,
fetchRefinerModel,
fetchT2IAdapterModel,
fetchVAEModel,
getModelKey,
raiseIfBaseIncompatible,
} from 'features/parameters/util/modelFetchingHelpers';
import type { BaseModelType } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
export const prepareMainModelMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(model);
const mainModel = await fetchMainModel(key);
return zModelIdentifierWithBase.parse(mainModel);
};
export const prepareRefinerMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(model);
const refinerModel = await fetchRefinerModel(key);
return zModelIdentifierWithBase.parse(refinerModel);
};
export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(vae);
const vaeModel = await fetchVAEModel(key);
raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
return zModelIdentifierWithBase.parse(vaeModel);
};
export const prepareLoRAMetadataItem = async (
loraMetadataItem: LoRAMetadataItem,
base?: BaseModelType
): Promise<LoRA> => {
const key = getModelKey(loraMetadataItem.lora);
const loraModel = await fetchLoRAModel(key);
raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
};
export const prepareControlNetMetadataItem = async (
controlnetMetadataItem: ControlNetMetadataItem,
base?: BaseModelType
): Promise<ControlNetConfig> => {
const key = getModelKey(controlnetMetadataItem.control_model);
const controlNetModel = await fetchControlNetModel(key);
raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
controlnetMetadataItem;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlnet: ControlNetConfig = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return controlnet;
};
export const prepareT2IAdapterMetadataItem = async (
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
base?: BaseModelType
): Promise<T2IAdapterConfig> => {
const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
const t2iAdapterModel = await fetchT2IAdapterModel(key);
raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return t2iAdapter;
};
export const prepareIPAdapterMetadataItem = async (
ipAdapterMetadataItem: IPAdapterMetadataItem,
base?: BaseModelType
): Promise<IPAdapterConfig> => {
const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
const ipAdapterModel = await fetchIPAdapterModel(key);
raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const ipAdapter: IPAdapterConfig = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
controlImage: image?.image_name ?? null,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};
return ipAdapter;
};

View File

@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema';
import type { import type {
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ControlNetConfig, ControlNetModelConfig,
ImportModelConfig, ImportModelConfig,
IPAdapterConfig, IPAdapterModelConfig,
LoRAConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
MergeModelConfig, MergeModelConfig,
ModelType, ModelType,
T2IAdapterConfig, T2IAdapterModelConfig,
TextualInversionConfig, TextualInversionModelConfig,
VAEConfig, VAEModelConfig,
} from 'services/api/types'; } from 'services/api/types';
import type { ApiTagDescription, tagTypes } from '..'; import type { ApiTagDescription, tagTypes } from '..';
@ -30,7 +30,7 @@ type UpdateMainModelArg = {
type UpdateLoRAModelArg = { type UpdateLoRAModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
body: LoRAConfig; body: LoRAModelConfig;
}; };
type UpdateMainModelResponse = type UpdateMainModelResponse =
@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({ export const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({ export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({ export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({ export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({ export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap
undefined, undefined,
getSelectorsOptions getSelectorsOptions
); );
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({ export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
@ -162,6 +162,8 @@ const buildTransformResponse =
*/ */
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({ getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'), providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter), transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
}), }),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({ updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => { query: ({ base_model, model_name, body }) => {
@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}), }),
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({ getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'), providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter), transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
}), }),
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({ getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'), providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter), transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
}), }),
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({ getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'), providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter), transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
}), }),
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({ getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEConfig>('VaeModel'), providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter), transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
}), }),
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({ getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'), providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter), transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
}), }),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({ getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => { query: (arg) => {