mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refactor metadata handling
Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models. - Create helpers to fetch a model outside react (e.g. not in a hook) - Created helpers to parse model metadata - Renamed a lot of types that were confusing and/or had naming collisions
This commit is contained in:
parent
79b16596b5
commit
3ed2963f43
@ -8,4 +8,26 @@ declare global {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Raised when the redux store is unable to be retrieved.
|
||||
*/
|
||||
export class ReduxStoreNotInitialized extends Error {
|
||||
/**
|
||||
* Create ReduxStoreNotInitialized
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message = 'Redux store not initialized') {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
|
||||
|
||||
export const getStore = () => {
|
||||
const store = $store.get();
|
||||
if (!store) {
|
||||
throw new ReduxStoreNotInitialized();
|
||||
}
|
||||
return store;
|
||||
};
|
||||
|
@ -8,7 +8,7 @@ import { useControlAdapterType } from 'features/controlAdapters/hooks/useControl
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type ParamControlAdapterModelProps = {
|
||||
id: string;
|
||||
@ -24,7 +24,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
|
||||
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionConfig } from 'services/api/types';
|
||||
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
@ -17,7 +17,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionConfig): boolean => {
|
||||
(embedding: TextualInversionModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionConfig | null) => {
|
||||
(embedding: TextualInversionModelConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||
|
||||
@ -19,7 +19,7 @@ const LoRASelect = () => {
|
||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = (lora: LoRAConfig): boolean => {
|
||||
const getIsDisabled = (lora: LoRAModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === lora.base;
|
||||
const isAdded = Boolean(addedLoRAs[lora.key]);
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
@ -27,7 +27,7 @@ const LoRASelect = () => {
|
||||
};
|
||||
|
||||
const _onChange = useCallback(
|
||||
(lora: LoRAConfig | null) => {
|
||||
(lora: LoRAModelConfig | null) => {
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
export type LoRA = ParameterLoRAModel & {
|
||||
weight: number;
|
||||
@ -28,13 +28,12 @@ export const loraSlice = createSlice({
|
||||
name: 'lora',
|
||||
initialState: initialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAConfig>) => {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||
const { key, base } = action.payload;
|
||||
state.loras[key] = { key, base, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => {
|
||||
const { key, base, weight } = action.payload;
|
||||
state.loras[key] = { key, base, weight, isEnabled: true };
|
||||
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||
state.loras[action.payload.key] = action.payload;
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const key = action.payload;
|
||||
|
@ -8,12 +8,11 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
||||
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
type LoRAModelEditProps = {
|
||||
model: LoRAConfig;
|
||||
model: LoRAModelConfig;
|
||||
};
|
||||
|
||||
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
@ -30,7 +29,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<LoRAConfig>({
|
||||
} = useForm<LoRAModelConfig>({
|
||||
defaultValues: {
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
@ -42,7 +41,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>(
|
||||
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
@ -53,7 +52,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
updateLoRAModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as LoRAConfig, { keepDefaultValues: true });
|
||||
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
@ -106,7 +105,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<LoRAConfig> control={control} name="base_model" />
|
||||
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
|
||||
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
|
@ -6,7 +6,7 @@ import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTempla
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { ControlNetConfig } from 'services/api/types';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@ -18,7 +18,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const { data, isLoading } = useGetControlNetModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: ControlNetConfig | null) => {
|
||||
(value: ControlNetModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { IPAdapterConfig } from 'services/api/types';
|
||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@ -18,7 +18,7 @@ const IPAdapterModelFieldInputComponent = (
|
||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: IPAdapterConfig | null) => {
|
||||
(value: IPAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'f
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@ -17,7 +17,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||
const _onChange = useCallback(
|
||||
(value: LoRAConfig | null) => {
|
||||
(value: LoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTempla
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { T2IAdapterConfig } from 'services/api/types';
|
||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@ -19,7 +19,7 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: T2IAdapterConfig | null) => {
|
||||
(value: T2IAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'fea
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { VAEConfig } from 'services/api/types';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@ -18,7 +18,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEConfig | null) => {
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { VAEConfig } from 'services/api/types';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
const { model, vae } = generation;
|
||||
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
|
||||
const { model, vae } = useAppSelector(selector);
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const getIsDisabled = useCallback(
|
||||
(vae: VAEConfig): boolean => {
|
||||
(vae: VAEModelConfig): boolean => {
|
||||
const isCompatible = model?.base === vae.base;
|
||||
const hasMainModel = Boolean(model?.base);
|
||||
return !hasMainModel || !isCompatible;
|
||||
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
|
||||
[model?.base]
|
||||
);
|
||||
const _onChange = useCallback(
|
||||
(vae: VAEConfig | null) => {
|
||||
(vae: VAEModelConfig | null) => {
|
||||
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
||||
},
|
||||
[dispatch]
|
||||
|
@ -1,17 +1,9 @@
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
||||
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
|
||||
import type { ModelIdentifier } from 'features/nodes/types/common';
|
||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ControlNetMetadataItem,
|
||||
@ -56,6 +48,14 @@ import {
|
||||
isParameterStrength,
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
prepareControlNetMetadataItem,
|
||||
prepareIPAdapterMetadataItem,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareMainModelMetadataItem,
|
||||
prepareT2IAdapterMetadataItem,
|
||||
prepareVAEMetadataItem,
|
||||
} from 'features/parameters/util/modelMetadataHelpers';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@ -70,23 +70,7 @@ import {
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
controlNetModelsAdapterSelectors,
|
||||
ipAdapterModelsAdapterSelectors,
|
||||
loraModelsAdapterSelectors,
|
||||
mainModelsAdapterSelectors,
|
||||
t2iAdapterModelsAdapterSelectors,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
vaeModelsAdapterSelectors,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
|
||||
@ -140,9 +124,6 @@ export const useRecallParameters = () => {
|
||||
[t, toaster]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall both prompts with toast
|
||||
*/
|
||||
const recallBothPrompts = useCallback(
|
||||
(positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => {
|
||||
if (
|
||||
@ -175,9 +156,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall positive prompt with toast
|
||||
*/
|
||||
const recallPositivePrompt = useCallback(
|
||||
(positivePrompt: unknown) => {
|
||||
if (!isParameterPositivePrompt(positivePrompt)) {
|
||||
@ -190,9 +168,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall negative prompt with toast
|
||||
*/
|
||||
const recallNegativePrompt = useCallback(
|
||||
(negativePrompt: unknown) => {
|
||||
if (!isParameterNegativePrompt(negativePrompt)) {
|
||||
@ -205,9 +180,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall SDXL Positive Style Prompt with toast
|
||||
*/
|
||||
const recallSDXLPositiveStylePrompt = useCallback(
|
||||
(positiveStylePrompt: unknown) => {
|
||||
if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
|
||||
@ -220,9 +192,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall SDXL Negative Style Prompt with toast
|
||||
*/
|
||||
const recallSDXLNegativeStylePrompt = useCallback(
|
||||
(negativeStylePrompt: unknown) => {
|
||||
if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) {
|
||||
@ -235,9 +204,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall seed with toast
|
||||
*/
|
||||
const recallSeed = useCallback(
|
||||
(seed: unknown) => {
|
||||
if (!isParameterSeed(seed)) {
|
||||
@ -250,9 +216,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall CFG scale with toast
|
||||
*/
|
||||
const recallCfgScale = useCallback(
|
||||
(cfgScale: unknown) => {
|
||||
if (!isParameterCFGScale(cfgScale)) {
|
||||
@ -265,9 +228,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall CFG rescale multiplier with toast
|
||||
*/
|
||||
const recallCfgRescaleMultiplier = useCallback(
|
||||
(cfgRescaleMultiplier: unknown) => {
|
||||
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
||||
@ -280,9 +240,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall scheduler with toast
|
||||
*/
|
||||
const recallScheduler = useCallback(
|
||||
(scheduler: unknown) => {
|
||||
if (!isParameterScheduler(scheduler)) {
|
||||
@ -295,9 +252,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall steps with toast
|
||||
*/
|
||||
const recallSteps = useCallback(
|
||||
(steps: unknown) => {
|
||||
if (!isParameterSteps(steps)) {
|
||||
@ -310,9 +264,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall width with toast
|
||||
*/
|
||||
const recallWidth = useCallback(
|
||||
(width: unknown) => {
|
||||
if (!isParameterWidth(width)) {
|
||||
@ -325,9 +276,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall height with toast
|
||||
*/
|
||||
const recallHeight = useCallback(
|
||||
(height: unknown) => {
|
||||
if (!isParameterHeight(height)) {
|
||||
@ -340,9 +288,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall width and height with toast
|
||||
*/
|
||||
const recallWidthAndHeight = useCallback(
|
||||
(width: unknown, height: unknown) => {
|
||||
if (!isParameterWidth(width)) {
|
||||
@ -360,9 +305,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, allParameterSetToast, allParameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall strength with toast
|
||||
*/
|
||||
const recallStrength = useCallback(
|
||||
(strength: unknown) => {
|
||||
if (!isParameterStrength(strength)) {
|
||||
@ -375,9 +317,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall high resolution enabled with toast
|
||||
*/
|
||||
const recallHrfEnabled = useCallback(
|
||||
(hrfEnabled: unknown) => {
|
||||
if (!isParameterHRFEnabled(hrfEnabled)) {
|
||||
@ -390,9 +329,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall high resolution strength with toast
|
||||
*/
|
||||
const recallHrfStrength = useCallback(
|
||||
(hrfStrength: unknown) => {
|
||||
if (!isParameterStrength(hrfStrength)) {
|
||||
@ -405,9 +341,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall high resolution method with toast
|
||||
*/
|
||||
const recallHrfMethod = useCallback(
|
||||
(hrfMethod: unknown) => {
|
||||
if (!isParameterHRFMethod(hrfMethod)) {
|
||||
@ -420,358 +353,95 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
const prepareMainModelMetadataItem = useCallback(
|
||||
(model: ModelIdentifier) => {
|
||||
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
|
||||
|
||||
if (!matchingModel) {
|
||||
return { model: null, error: 'Model is not installed' };
|
||||
}
|
||||
|
||||
return { model: matchingModel, error: null };
|
||||
},
|
||||
[mainModels]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall model with toast
|
||||
*/
|
||||
const recallModel = useCallback(
|
||||
(model: unknown) => {
|
||||
if (!isModelIdentifier(model)) {
|
||||
parameterNotSetToast();
|
||||
async (modelMetadataItem: unknown) => {
|
||||
try {
|
||||
const model = await prepareMainModelMetadataItem(modelMetadataItem);
|
||||
dispatch(modelSelected(model));
|
||||
parameterSetToast();
|
||||
} catch (e) {
|
||||
parameterNotSetToast((e as unknown as Error).message);
|
||||
return;
|
||||
}
|
||||
|
||||
const result = prepareMainModelMetadataItem(model);
|
||||
|
||||
if (!result.model) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelSelected(result.model));
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
||||
const prepareVAEMetadataItem = useCallback(
|
||||
(vae: ModelIdentifier, newModel?: ParameterModel) => {
|
||||
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
|
||||
if (!matchingModel) {
|
||||
return { vae: null, error: 'VAE model is not installed' };
|
||||
}
|
||||
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
vae: null,
|
||||
error: 'VAE incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
return { vae: matchingModel, error: null };
|
||||
},
|
||||
[model, vaeModels]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall vae model
|
||||
*/
|
||||
const recallVaeModel = useCallback(
|
||||
(vae: unknown) => {
|
||||
if (!isModelIdentifier(vae) && !isNil(vae)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
if (isNil(vae)) {
|
||||
async (vaeMetadataItem: unknown) => {
|
||||
if (isNil(vaeMetadataItem)) {
|
||||
dispatch(vaeSelected(null));
|
||||
parameterSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const result = prepareVAEMetadataItem(vae);
|
||||
|
||||
if (!result.vae) {
|
||||
parameterNotSetToast(result.error);
|
||||
try {
|
||||
const vae = await prepareVAEMetadataItem(vaeMetadataItem);
|
||||
dispatch(vaeSelected(vae));
|
||||
parameterSetToast();
|
||||
} catch (e) {
|
||||
parameterNotSetToast((e as unknown as Error).message);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(result.vae));
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall LoRA with toast
|
||||
*/
|
||||
|
||||
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
|
||||
|
||||
const prepareLoRAMetadataItem = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isModelIdentifier(loraMetadataItem.lora)) {
|
||||
return { lora: null, error: 'Invalid LoRA model' };
|
||||
}
|
||||
|
||||
const { lora } = loraMetadataItem;
|
||||
|
||||
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
|
||||
|
||||
if (!matchingLoRA) {
|
||||
return { lora: null, error: 'LoRA model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
lora: null,
|
||||
error: 'LoRA incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
return { lora: matchingLoRA, error: null };
|
||||
},
|
||||
[loraModels, model]
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem) => {
|
||||
const result = prepareLoRAMetadataItem(loraMetadataItem);
|
||||
|
||||
if (!result.lora) {
|
||||
parameterNotSetToast(result.error);
|
||||
async (loraMetadataItem: LoRAMetadataItem) => {
|
||||
try {
|
||||
const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base);
|
||||
dispatch(loraRecalled(lora));
|
||||
parameterSetToast();
|
||||
} catch (e) {
|
||||
parameterNotSetToast((e as unknown as Error).message);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRecalled({ ...result.lora, weight: loraMetadataItem.weight }));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall ControlNet with toast
|
||||
*/
|
||||
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
|
||||
|
||||
const prepareControlNetMetadataItem = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
const { image, control_model, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
|
||||
controlnetMetadataItem;
|
||||
|
||||
const matchingControlNetModel = controlNetModels
|
||||
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
|
||||
: undefined;
|
||||
|
||||
if (!matchingControlNetModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
controlnet: null,
|
||||
error: 'ControlNet incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const controlnet: ControlNetConfig = {
|
||||
type: 'controlnet',
|
||||
isEnabled: true,
|
||||
model: matchingControlNetModel,
|
||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||
controlMode: control_mode || initialControlNet.controlMode,
|
||||
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return { controlnet, error: null };
|
||||
},
|
||||
[controlNetModels, model]
|
||||
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const recallControlNet = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
|
||||
|
||||
if (!result.controlnet) {
|
||||
parameterNotSetToast(result.error);
|
||||
async (controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
try {
|
||||
const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base);
|
||||
dispatch(controlAdapterRecalled(controlNetConfig));
|
||||
parameterSetToast();
|
||||
} catch (e) {
|
||||
parameterNotSetToast((e as unknown as Error).message);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterRecalled(result.controlnet));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareControlNetMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall T2I Adapter with toast
|
||||
*/
|
||||
|
||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
|
||||
|
||||
const prepareT2IAdapterMetadataItem = useCallback(
|
||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
const { image, t2i_adapter_model, weight, begin_step_percent, end_step_percent, resize_mode } =
|
||||
t2iAdapterMetadataItem;
|
||||
|
||||
const matchingT2IAdapterModel = t2iAdapterModels
|
||||
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
|
||||
: undefined;
|
||||
|
||||
if (!matchingT2IAdapterModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
t2iAdapter: null,
|
||||
error: 'ControlNet incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: matchingT2IAdapterModel,
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
|
||||
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return { t2iAdapter, error: null };
|
||||
},
|
||||
[model, t2iAdapterModels]
|
||||
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const recallT2IAdapter = useCallback(
|
||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
||||
const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem);
|
||||
|
||||
if (!result.t2iAdapter) {
|
||||
parameterNotSetToast(result.error);
|
||||
async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
||||
try {
|
||||
const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base);
|
||||
dispatch(controlAdapterRecalled(t2iAdapterConfig));
|
||||
parameterSetToast();
|
||||
} catch (e) {
|
||||
parameterNotSetToast((e as unknown as Error).message);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareT2IAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall IP Adapter with toast
|
||||
*/
|
||||
|
||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
|
||||
|
||||
const prepareIPAdapterMetadataItem = useCallback(
|
||||
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
|
||||
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
||||
}
|
||||
|
||||
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
|
||||
|
||||
const matchingIPAdapterModel = ipAdapterModels
|
||||
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
|
||||
: undefined;
|
||||
|
||||
if (!matchingIPAdapterModel) {
|
||||
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
ipAdapter: null,
|
||||
error: 'IP Adapter incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
const ipAdapter: IPAdapterConfig = {
|
||||
id: uuidv4(),
|
||||
type: 'ip_adapter',
|
||||
isEnabled: true,
|
||||
controlImage: image?.image_name ?? null,
|
||||
model: matchingIPAdapterModel,
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||
};
|
||||
|
||||
return { ipAdapter, error: null };
|
||||
},
|
||||
[ipAdapterModels, model]
|
||||
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const recallIPAdapter = useCallback(
|
||||
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
|
||||
|
||||
if (!result.ipAdapter) {
|
||||
parameterNotSetToast(result.error);
|
||||
async (ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||
try {
|
||||
const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base);
|
||||
dispatch(controlAdapterRecalled(ipAdapterConfig));
|
||||
parameterSetToast();
|
||||
} catch (e) {
|
||||
parameterNotSetToast((e as unknown as Error).message);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterRecalled(result.ipAdapter));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareIPAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/*
|
||||
* Sets image as initial image with toast
|
||||
*/
|
||||
const sendToImageToImage = useCallback(
|
||||
(image: ImageDTO) => {
|
||||
dispatch(initialImageSelected(image));
|
||||
@ -780,7 +450,7 @@ export const useRecallParameters = () => {
|
||||
);
|
||||
|
||||
const recallAllParameters = useCallback(
|
||||
(metadata: CoreMetadata | undefined) => {
|
||||
async (metadata: CoreMetadata | undefined) => {
|
||||
if (!metadata) {
|
||||
allParameterNotSetToast();
|
||||
return;
|
||||
@ -820,10 +490,12 @@ export const useRecallParameters = () => {
|
||||
let newModel: ParameterModel | undefined = undefined;
|
||||
|
||||
if (isModelIdentifier(model)) {
|
||||
const result = prepareMainModelMetadataItem(model);
|
||||
if (result.model) {
|
||||
dispatch(modelSelected(result.model));
|
||||
newModel = result.model;
|
||||
try {
|
||||
const _model = await prepareMainModelMetadataItem(model);
|
||||
dispatch(modelSelected(_model));
|
||||
newModel = _model;
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@ -850,9 +522,11 @@ export const useRecallParameters = () => {
|
||||
if (isNil(vae)) {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
const result = prepareVAEMetadataItem(vae, newModel);
|
||||
if (result.vae) {
|
||||
dispatch(vaeSelected(result.vae));
|
||||
try {
|
||||
const _vae = await prepareVAEMetadataItem(vae, newModel?.base);
|
||||
dispatch(vaeSelected(_vae));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -926,48 +600,46 @@ export const useRecallParameters = () => {
|
||||
}
|
||||
|
||||
dispatch(lorasCleared());
|
||||
loras?.forEach((lora) => {
|
||||
const result = prepareLoRAMetadataItem(lora, newModel);
|
||||
if (result.lora) {
|
||||
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
||||
loras?.forEach(async (loraMetadataItem) => {
|
||||
try {
|
||||
const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base);
|
||||
dispatch(loraRecalled(lora));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
dispatch(controlAdaptersReset());
|
||||
controlnets?.forEach((controlnet) => {
|
||||
const result = prepareControlNetMetadataItem(controlnet, newModel);
|
||||
if (result.controlnet) {
|
||||
dispatch(controlAdapterRecalled(result.controlnet));
|
||||
controlnets?.forEach(async (controlNetMetadataItem) => {
|
||||
try {
|
||||
const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base);
|
||||
dispatch(controlAdapterRecalled(controlNet));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
ipAdapters?.forEach((ipAdapter) => {
|
||||
const result = prepareIPAdapterMetadataItem(ipAdapter, newModel);
|
||||
if (result.ipAdapter) {
|
||||
dispatch(controlAdapterRecalled(result.ipAdapter));
|
||||
ipAdapters?.forEach(async (ipAdapterMetadataItem) => {
|
||||
try {
|
||||
const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base);
|
||||
dispatch(controlAdapterRecalled(ipAdapter));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
t2iAdapters?.forEach((t2iAdapter) => {
|
||||
const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel);
|
||||
if (result.t2iAdapter) {
|
||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
||||
t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => {
|
||||
try {
|
||||
const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base);
|
||||
dispatch(controlAdapterRecalled(t2iAdapter));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[
|
||||
dispatch,
|
||||
allParameterSetToast,
|
||||
allParameterNotSetToast,
|
||||
prepareMainModelMetadataItem,
|
||||
prepareVAEMetadataItem,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareControlNetMetadataItem,
|
||||
prepareIPAdapterMetadataItem,
|
||||
prepareT2IAdapterMetadataItem,
|
||||
]
|
||||
[dispatch, allParameterSetToast, allParameterNotSetToast]
|
||||
);
|
||||
|
||||
return {
|
||||
|
@ -0,0 +1,113 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isTextualInversionModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Raised when a model config is unable to be fetched.
|
||||
*/
|
||||
export class ModelConfigNotFoundError extends Error {
|
||||
/**
|
||||
* Create ModelConfigNotFoundError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Raised when a fetched model config is of an unexpected type.
|
||||
*/
|
||||
export class InvalidModelConfigError extends Error {
|
||||
/**
|
||||
* Create InvalidModelConfigError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
|
||||
req.unsubscribe();
|
||||
return await req.unwrap();
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
|
||||
}
|
||||
};
|
||||
|
||||
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
||||
key: string,
|
||||
typeGuard: (config: AnyModelConfig) => config is T
|
||||
) => {
|
||||
const modelConfig = await fetchModelConfig(key);
|
||||
if (!typeGuard(modelConfig)) {
|
||||
throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`);
|
||||
}
|
||||
return modelConfig;
|
||||
};
|
||||
|
||||
export const fetchMainModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||
};
|
||||
|
||||
export const fetchRefinerModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||
};
|
||||
|
||||
export const fetchVAEModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||
};
|
||||
|
||||
export const fetchLoRAModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
};
|
||||
|
||||
export const fetchControlNetModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||
};
|
||||
|
||||
export const fetchIPAdapterModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||
};
|
||||
|
||||
export const fetchT2IAdapterModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||
};
|
||||
|
||||
export const fetchTextualInversionModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
|
||||
};
|
||||
|
||||
export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => {
|
||||
return sourceBase === targetBase;
|
||||
};
|
||||
|
||||
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
|
||||
if (targetBase && !isBaseCompatible(sourceBase, targetBase)) {
|
||||
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
|
||||
}
|
||||
};
|
||||
|
||||
export const getModelKey = (modelIdentifier: unknown, message?: string): string => {
|
||||
if (!isModelIdentifier(modelIdentifier)) {
|
||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||
}
|
||||
return modelIdentifier.key;
|
||||
};
|
@ -0,0 +1,150 @@
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ControlNetMetadataItem,
|
||||
IPAdapterMetadataItem,
|
||||
LoRAMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/metadata';
|
||||
import {
|
||||
fetchControlNetModel,
|
||||
fetchIPAdapterModel,
|
||||
fetchLoRAModel,
|
||||
fetchMainModel,
|
||||
fetchRefinerModel,
|
||||
fetchT2IAdapterModel,
|
||||
fetchVAEModel,
|
||||
getModelKey,
|
||||
raiseIfBaseIncompatible,
|
||||
} from 'features/parameters/util/modelFetchingHelpers';
|
||||
import type { BaseModelType } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const prepareMainModelMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
|
||||
const key = getModelKey(model);
|
||||
const mainModel = await fetchMainModel(key);
|
||||
return zModelIdentifierWithBase.parse(mainModel);
|
||||
};
|
||||
|
||||
export const prepareRefinerMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
|
||||
const key = getModelKey(model);
|
||||
const refinerModel = await fetchRefinerModel(key);
|
||||
return zModelIdentifierWithBase.parse(refinerModel);
|
||||
};
|
||||
|
||||
export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise<ModelIdentifierWithBase> => {
|
||||
const key = getModelKey(vae);
|
||||
const vaeModel = await fetchVAEModel(key);
|
||||
raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
|
||||
return zModelIdentifierWithBase.parse(vaeModel);
|
||||
};
|
||||
|
||||
export const prepareLoRAMetadataItem = async (
|
||||
loraMetadataItem: LoRAMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<LoRA> => {
|
||||
const key = getModelKey(loraMetadataItem.lora);
|
||||
const loraModel = await fetchLoRAModel(key);
|
||||
raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
|
||||
return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
|
||||
};
|
||||
|
||||
export const prepareControlNetMetadataItem = async (
|
||||
controlnetMetadataItem: ControlNetMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<ControlNetConfig> => {
|
||||
const key = getModelKey(controlnetMetadataItem.control_model);
|
||||
const controlNetModel = await fetchControlNetModel(key);
|
||||
raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
|
||||
|
||||
const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
|
||||
controlnetMetadataItem;
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const controlnet: ControlNetConfig = {
|
||||
type: 'controlnet',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||
controlMode: control_mode || initialControlNet.controlMode,
|
||||
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return controlnet;
|
||||
};
|
||||
|
||||
export const prepareT2IAdapterMetadataItem = async (
|
||||
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<T2IAdapterConfig> => {
|
||||
const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
|
||||
const t2iAdapterModel = await fetchT2IAdapterModel(key);
|
||||
raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
|
||||
|
||||
const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
|
||||
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return t2iAdapter;
|
||||
};
|
||||
|
||||
export const prepareIPAdapterMetadataItem = async (
|
||||
ipAdapterMetadataItem: IPAdapterMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<IPAdapterConfig> => {
|
||||
const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
|
||||
const ipAdapterModel = await fetchIPAdapterModel(key);
|
||||
raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
|
||||
|
||||
const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
|
||||
|
||||
const ipAdapter: IPAdapterConfig = {
|
||||
id: uuidv4(),
|
||||
type: 'ip_adapter',
|
||||
isEnabled: true,
|
||||
controlImage: image?.image_name ?? null,
|
||||
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||
};
|
||||
|
||||
return ipAdapter;
|
||||
};
|
@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlNetConfig,
|
||||
ControlNetModelConfig,
|
||||
ImportModelConfig,
|
||||
IPAdapterConfig,
|
||||
LoRAConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
ModelType,
|
||||
T2IAdapterConfig,
|
||||
TextualInversionConfig,
|
||||
VAEConfig,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
@ -30,7 +30,7 @@ type UpdateMainModelArg = {
|
||||
type UpdateLoRAModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: LoRAConfig;
|
||||
body: LoRAModelConfig;
|
||||
};
|
||||
|
||||
type UpdateMainModelResponse =
|
||||
@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
|
||||
export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
@ -162,6 +162,8 @@ const buildTransformResponse =
|
||||
*/
|
||||
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||
|
||||
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
|
||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
||||
}),
|
||||
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
|
||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
|
||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
||||
}),
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
||||
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
|
||||
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
||||
}),
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
|
||||
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
|
||||
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
||||
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
|
||||
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
||||
}),
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
||||
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
|
||||
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
||||
}),
|
||||
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
||||
query: (arg) => {
|
||||
|
Loading…
Reference in New Issue
Block a user