mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refactor metadata handling
Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models. - Create helpers to fetch a model outside react (e.g. not in a hook) - Created helpers to parse model metadata - Renamed a lot of types that were confusing and/or had naming collisions
This commit is contained in:
parent
79b16596b5
commit
3ed2963f43
@ -8,4 +8,26 @@ declare global {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raised when the redux store is unable to be retrieved.
|
||||||
|
*/
|
||||||
|
export class ReduxStoreNotInitialized extends Error {
|
||||||
|
/**
|
||||||
|
* Create ReduxStoreNotInitialized
|
||||||
|
* @param {String} message
|
||||||
|
*/
|
||||||
|
constructor(message = 'Redux store not initialized') {
|
||||||
|
super(message);
|
||||||
|
this.name = this.constructor.name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
|
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
|
||||||
|
|
||||||
|
export const getStore = () => {
|
||||||
|
const store = $store.get();
|
||||||
|
if (!store) {
|
||||||
|
throw new ReduxStoreNotInitialized();
|
||||||
|
}
|
||||||
|
return store;
|
||||||
|
};
|
||||||
|
@ -8,7 +8,7 @@ import { useControlAdapterType } from 'features/controlAdapters/hooks/useControl
|
|||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type ParamControlAdapterModelProps = {
|
type ParamControlAdapterModelProps = {
|
||||||
id: string;
|
id: string;
|
||||||
@ -24,7 +24,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
|
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import { t } from 'i18next';
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { TextualInversionConfig } from 'services/api/types';
|
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
|
|||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
|
||||||
const getIsDisabled = useCallback(
|
const getIsDisabled = useCallback(
|
||||||
(embedding: TextualInversionConfig): boolean => {
|
(embedding: TextualInversionModelConfig): boolean => {
|
||||||
const isCompatible = currentBaseModel === embedding.base;
|
const isCompatible = currentBaseModel === embedding.base;
|
||||||
const hasMainModel = Boolean(currentBaseModel);
|
const hasMainModel = Boolean(currentBaseModel);
|
||||||
return !hasMainModel || !isCompatible;
|
return !hasMainModel || !isCompatible;
|
||||||
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
|
|||||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(embedding: TextualInversionConfig | null) => {
|
(embedding: TextualInversionModelConfig | null) => {
|
||||||
if (!embedding) {
|
if (!embedding) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
|||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { LoRAConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ const LoRASelect = () => {
|
|||||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
|
||||||
const getIsDisabled = (lora: LoRAConfig): boolean => {
|
const getIsDisabled = (lora: LoRAModelConfig): boolean => {
|
||||||
const isCompatible = currentBaseModel === lora.base;
|
const isCompatible = currentBaseModel === lora.base;
|
||||||
const isAdded = Boolean(addedLoRAs[lora.key]);
|
const isAdded = Boolean(addedLoRAs[lora.key]);
|
||||||
const hasMainModel = Boolean(currentBaseModel);
|
const hasMainModel = Boolean(currentBaseModel);
|
||||||
@ -27,7 +27,7 @@ const LoRASelect = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(lora: LoRAConfig | null) => {
|
(lora: LoRAModelConfig | null) => {
|
||||||
if (!lora) {
|
if (!lora) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
|||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import type { LoRAConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export type LoRA = ParameterLoRAModel & {
|
export type LoRA = ParameterLoRAModel & {
|
||||||
weight: number;
|
weight: number;
|
||||||
@ -28,13 +28,12 @@ export const loraSlice = createSlice({
|
|||||||
name: 'lora',
|
name: 'lora',
|
||||||
initialState: initialLoraState,
|
initialState: initialLoraState,
|
||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<LoRAConfig>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||||
const { key, base } = action.payload;
|
const { key, base } = action.payload;
|
||||||
state.loras[key] = { key, base, ...defaultLoRAConfig };
|
state.loras[key] = { key, base, ...defaultLoRAConfig };
|
||||||
},
|
},
|
||||||
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => {
|
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||||
const { key, base, weight } = action.payload;
|
state.loras[action.payload.key] = action.payload;
|
||||||
state.loras[key] = { key, base, weight, isEnabled: true };
|
|
||||||
},
|
},
|
||||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
const key = action.payload;
|
const key = action.payload;
|
||||||
|
@ -8,12 +8,11 @@ import { memo, useCallback } from 'react';
|
|||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
|
||||||
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
|
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
|
||||||
import type { LoRAConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type LoRAModelEditProps = {
|
type LoRAModelEditProps = {
|
||||||
model: LoRAConfig;
|
model: LoRAModelConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||||
@ -30,7 +29,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
|||||||
control,
|
control,
|
||||||
formState: { errors },
|
formState: { errors },
|
||||||
reset,
|
reset,
|
||||||
} = useForm<LoRAConfig>({
|
} = useForm<LoRAModelConfig>({
|
||||||
defaultValues: {
|
defaultValues: {
|
||||||
model_name: model.model_name ? model.model_name : '',
|
model_name: model.model_name ? model.model_name : '',
|
||||||
base_model: model.base_model,
|
base_model: model.base_model,
|
||||||
@ -42,7 +41,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
|||||||
mode: 'onChange',
|
mode: 'onChange',
|
||||||
});
|
});
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>(
|
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
|
||||||
(values) => {
|
(values) => {
|
||||||
const responseBody = {
|
const responseBody = {
|
||||||
base_model: model.base_model,
|
base_model: model.base_model,
|
||||||
@ -53,7 +52,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
|||||||
updateLoRAModel(responseBody)
|
updateLoRAModel(responseBody)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((payload) => {
|
.then((payload) => {
|
||||||
reset(payload as LoRAConfig, { keepDefaultValues: true });
|
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
makeToast({
|
makeToast({
|
||||||
@ -106,7 +105,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
|||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||||
<Input {...register('description')} />
|
<Input {...register('description')} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<BaseModelSelect<LoRAConfig> control={control} name="base_model" />
|
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
|
||||||
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
<FormControl isInvalid={Boolean(errors.path)}>
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||||
|
@ -6,7 +6,7 @@ import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTempla
|
|||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { ControlNetConfig } from 'services/api/types';
|
import type { ControlNetModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
|||||||
const { data, isLoading } = useGetControlNetModelsQuery();
|
const { data, isLoading } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: ControlNetConfig | null) => {
|
(value: ControlNetModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate
|
|||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { IPAdapterConfig } from 'services/api/types';
|
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ const IPAdapterModelFieldInputComponent = (
|
|||||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: IPAdapterConfig | null) => {
|
(value: IPAdapterModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'f
|
|||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { LoRAConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: LoRAConfig | null) => {
|
(value: LoRAModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTempla
|
|||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { T2IAdapterConfig } from 'services/api/types';
|
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ const T2IAdapterModelFieldInputComponent = (
|
|||||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: T2IAdapterConfig | null) => {
|
(value: T2IAdapterModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'fea
|
|||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { VAEConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const { data, isLoading } = useGetVaeModelsQuery();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: VAEConfig | null) => {
|
(value: VAEModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import { pick } from 'lodash-es';
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { VAEConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||||
const { model, vae } = generation;
|
const { model, vae } = generation;
|
||||||
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
const { model, vae } = useAppSelector(selector);
|
const { model, vae } = useAppSelector(selector);
|
||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const { data, isLoading } = useGetVaeModelsQuery();
|
||||||
const getIsDisabled = useCallback(
|
const getIsDisabled = useCallback(
|
||||||
(vae: VAEConfig): boolean => {
|
(vae: VAEModelConfig): boolean => {
|
||||||
const isCompatible = model?.base === vae.base;
|
const isCompatible = model?.base === vae.base;
|
||||||
const hasMainModel = Boolean(model?.base);
|
const hasMainModel = Boolean(model?.base);
|
||||||
return !hasMainModel || !isCompatible;
|
return !hasMainModel || !isCompatible;
|
||||||
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
[model?.base]
|
[model?.base]
|
||||||
);
|
);
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(vae: VAEConfig | null) => {
|
(vae: VAEModelConfig | null) => {
|
||||||
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
|
||||||
import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
|
||||||
import {
|
|
||||||
initialControlNet,
|
|
||||||
initialIPAdapter,
|
|
||||||
initialT2IAdapter,
|
|
||||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
|
||||||
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
||||||
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
|
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
|
||||||
import type { ModelIdentifier } from 'features/nodes/types/common';
|
|
||||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||||
import type {
|
import type {
|
||||||
ControlNetMetadataItem,
|
ControlNetMetadataItem,
|
||||||
@ -56,6 +48,14 @@ import {
|
|||||||
isParameterStrength,
|
isParameterStrength,
|
||||||
isParameterWidth,
|
isParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import {
|
||||||
|
prepareControlNetMetadataItem,
|
||||||
|
prepareIPAdapterMetadataItem,
|
||||||
|
prepareLoRAMetadataItem,
|
||||||
|
prepareMainModelMetadataItem,
|
||||||
|
prepareT2IAdapterMetadataItem,
|
||||||
|
prepareVAEMetadataItem,
|
||||||
|
} from 'features/parameters/util/modelMetadataHelpers';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
@ -70,23 +70,7 @@ import {
|
|||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import {
|
|
||||||
controlNetModelsAdapterSelectors,
|
|
||||||
ipAdapterModelsAdapterSelectors,
|
|
||||||
loraModelsAdapterSelectors,
|
|
||||||
mainModelsAdapterSelectors,
|
|
||||||
t2iAdapterModelsAdapterSelectors,
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetMainModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
useGetVaeModelsQuery,
|
|
||||||
vaeModelsAdapterSelectors,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
import type { ImageDTO } from 'services/api/types';
|
import type { ImageDTO } from 'services/api/types';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
|
|
||||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||||
|
|
||||||
@ -140,9 +124,6 @@ export const useRecallParameters = () => {
|
|||||||
[t, toaster]
|
[t, toaster]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall both prompts with toast
|
|
||||||
*/
|
|
||||||
const recallBothPrompts = useCallback(
|
const recallBothPrompts = useCallback(
|
||||||
(positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => {
|
(positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => {
|
||||||
if (
|
if (
|
||||||
@ -175,9 +156,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall positive prompt with toast
|
|
||||||
*/
|
|
||||||
const recallPositivePrompt = useCallback(
|
const recallPositivePrompt = useCallback(
|
||||||
(positivePrompt: unknown) => {
|
(positivePrompt: unknown) => {
|
||||||
if (!isParameterPositivePrompt(positivePrompt)) {
|
if (!isParameterPositivePrompt(positivePrompt)) {
|
||||||
@ -190,9 +168,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall negative prompt with toast
|
|
||||||
*/
|
|
||||||
const recallNegativePrompt = useCallback(
|
const recallNegativePrompt = useCallback(
|
||||||
(negativePrompt: unknown) => {
|
(negativePrompt: unknown) => {
|
||||||
if (!isParameterNegativePrompt(negativePrompt)) {
|
if (!isParameterNegativePrompt(negativePrompt)) {
|
||||||
@ -205,9 +180,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall SDXL Positive Style Prompt with toast
|
|
||||||
*/
|
|
||||||
const recallSDXLPositiveStylePrompt = useCallback(
|
const recallSDXLPositiveStylePrompt = useCallback(
|
||||||
(positiveStylePrompt: unknown) => {
|
(positiveStylePrompt: unknown) => {
|
||||||
if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
|
if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
|
||||||
@ -220,9 +192,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall SDXL Negative Style Prompt with toast
|
|
||||||
*/
|
|
||||||
const recallSDXLNegativeStylePrompt = useCallback(
|
const recallSDXLNegativeStylePrompt = useCallback(
|
||||||
(negativeStylePrompt: unknown) => {
|
(negativeStylePrompt: unknown) => {
|
||||||
if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) {
|
if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) {
|
||||||
@ -235,9 +204,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall seed with toast
|
|
||||||
*/
|
|
||||||
const recallSeed = useCallback(
|
const recallSeed = useCallback(
|
||||||
(seed: unknown) => {
|
(seed: unknown) => {
|
||||||
if (!isParameterSeed(seed)) {
|
if (!isParameterSeed(seed)) {
|
||||||
@ -250,9 +216,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall CFG scale with toast
|
|
||||||
*/
|
|
||||||
const recallCfgScale = useCallback(
|
const recallCfgScale = useCallback(
|
||||||
(cfgScale: unknown) => {
|
(cfgScale: unknown) => {
|
||||||
if (!isParameterCFGScale(cfgScale)) {
|
if (!isParameterCFGScale(cfgScale)) {
|
||||||
@ -265,9 +228,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall CFG rescale multiplier with toast
|
|
||||||
*/
|
|
||||||
const recallCfgRescaleMultiplier = useCallback(
|
const recallCfgRescaleMultiplier = useCallback(
|
||||||
(cfgRescaleMultiplier: unknown) => {
|
(cfgRescaleMultiplier: unknown) => {
|
||||||
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
||||||
@ -280,9 +240,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall scheduler with toast
|
|
||||||
*/
|
|
||||||
const recallScheduler = useCallback(
|
const recallScheduler = useCallback(
|
||||||
(scheduler: unknown) => {
|
(scheduler: unknown) => {
|
||||||
if (!isParameterScheduler(scheduler)) {
|
if (!isParameterScheduler(scheduler)) {
|
||||||
@ -295,9 +252,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall steps with toast
|
|
||||||
*/
|
|
||||||
const recallSteps = useCallback(
|
const recallSteps = useCallback(
|
||||||
(steps: unknown) => {
|
(steps: unknown) => {
|
||||||
if (!isParameterSteps(steps)) {
|
if (!isParameterSteps(steps)) {
|
||||||
@ -310,9 +264,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall width with toast
|
|
||||||
*/
|
|
||||||
const recallWidth = useCallback(
|
const recallWidth = useCallback(
|
||||||
(width: unknown) => {
|
(width: unknown) => {
|
||||||
if (!isParameterWidth(width)) {
|
if (!isParameterWidth(width)) {
|
||||||
@ -325,9 +276,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall height with toast
|
|
||||||
*/
|
|
||||||
const recallHeight = useCallback(
|
const recallHeight = useCallback(
|
||||||
(height: unknown) => {
|
(height: unknown) => {
|
||||||
if (!isParameterHeight(height)) {
|
if (!isParameterHeight(height)) {
|
||||||
@ -340,9 +288,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall width and height with toast
|
|
||||||
*/
|
|
||||||
const recallWidthAndHeight = useCallback(
|
const recallWidthAndHeight = useCallback(
|
||||||
(width: unknown, height: unknown) => {
|
(width: unknown, height: unknown) => {
|
||||||
if (!isParameterWidth(width)) {
|
if (!isParameterWidth(width)) {
|
||||||
@ -360,9 +305,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, allParameterSetToast, allParameterNotSetToast]
|
[dispatch, allParameterSetToast, allParameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall strength with toast
|
|
||||||
*/
|
|
||||||
const recallStrength = useCallback(
|
const recallStrength = useCallback(
|
||||||
(strength: unknown) => {
|
(strength: unknown) => {
|
||||||
if (!isParameterStrength(strength)) {
|
if (!isParameterStrength(strength)) {
|
||||||
@ -375,9 +317,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall high resolution enabled with toast
|
|
||||||
*/
|
|
||||||
const recallHrfEnabled = useCallback(
|
const recallHrfEnabled = useCallback(
|
||||||
(hrfEnabled: unknown) => {
|
(hrfEnabled: unknown) => {
|
||||||
if (!isParameterHRFEnabled(hrfEnabled)) {
|
if (!isParameterHRFEnabled(hrfEnabled)) {
|
||||||
@ -390,9 +329,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall high resolution strength with toast
|
|
||||||
*/
|
|
||||||
const recallHrfStrength = useCallback(
|
const recallHrfStrength = useCallback(
|
||||||
(hrfStrength: unknown) => {
|
(hrfStrength: unknown) => {
|
||||||
if (!isParameterStrength(hrfStrength)) {
|
if (!isParameterStrength(hrfStrength)) {
|
||||||
@ -405,9 +341,6 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall high resolution method with toast
|
|
||||||
*/
|
|
||||||
const recallHrfMethod = useCallback(
|
const recallHrfMethod = useCallback(
|
||||||
(hrfMethod: unknown) => {
|
(hrfMethod: unknown) => {
|
||||||
if (!isParameterHRFMethod(hrfMethod)) {
|
if (!isParameterHRFMethod(hrfMethod)) {
|
||||||
@ -420,358 +353,95 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
|
||||||
|
|
||||||
const prepareMainModelMetadataItem = useCallback(
|
|
||||||
(model: ModelIdentifier) => {
|
|
||||||
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
|
|
||||||
|
|
||||||
if (!matchingModel) {
|
|
||||||
return { model: null, error: 'Model is not installed' };
|
|
||||||
}
|
|
||||||
|
|
||||||
return { model: matchingModel, error: null };
|
|
||||||
},
|
|
||||||
[mainModels]
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall model with toast
|
|
||||||
*/
|
|
||||||
const recallModel = useCallback(
|
const recallModel = useCallback(
|
||||||
(model: unknown) => {
|
async (modelMetadataItem: unknown) => {
|
||||||
if (!isModelIdentifier(model)) {
|
try {
|
||||||
parameterNotSetToast();
|
const model = await prepareMainModelMetadataItem(modelMetadataItem);
|
||||||
return;
|
dispatch(modelSelected(model));
|
||||||
}
|
|
||||||
|
|
||||||
const result = prepareMainModelMetadataItem(model);
|
|
||||||
|
|
||||||
if (!result.model) {
|
|
||||||
parameterNotSetToast(result.error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(modelSelected(result.model));
|
|
||||||
parameterSetToast();
|
parameterSetToast();
|
||||||
},
|
} catch (e) {
|
||||||
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
parameterNotSetToast((e as unknown as Error).message);
|
||||||
);
|
|
||||||
|
|
||||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
|
||||||
|
|
||||||
const prepareVAEMetadataItem = useCallback(
|
|
||||||
(vae: ModelIdentifier, newModel?: ParameterModel) => {
|
|
||||||
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
|
|
||||||
if (!matchingModel) {
|
|
||||||
return { vae: null, error: 'VAE model is not installed' };
|
|
||||||
}
|
|
||||||
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
|
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
|
||||||
return {
|
|
||||||
vae: null,
|
|
||||||
error: 'VAE incompatible with currently-selected model',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { vae: matchingModel, error: null };
|
|
||||||
},
|
|
||||||
[model, vaeModels]
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall vae model
|
|
||||||
*/
|
|
||||||
const recallVaeModel = useCallback(
|
|
||||||
(vae: unknown) => {
|
|
||||||
if (!isModelIdentifier(vae) && !isNil(vae)) {
|
|
||||||
parameterNotSetToast();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
|
);
|
||||||
|
|
||||||
if (isNil(vae)) {
|
const recallVaeModel = useCallback(
|
||||||
|
async (vaeMetadataItem: unknown) => {
|
||||||
|
if (isNil(vaeMetadataItem)) {
|
||||||
dispatch(vaeSelected(null));
|
dispatch(vaeSelected(null));
|
||||||
parameterSetToast();
|
parameterSetToast();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
const result = prepareVAEMetadataItem(vae);
|
const vae = await prepareVAEMetadataItem(vaeMetadataItem);
|
||||||
|
dispatch(vaeSelected(vae));
|
||||||
if (!result.vae) {
|
parameterSetToast();
|
||||||
parameterNotSetToast(result.error);
|
} catch (e) {
|
||||||
|
parameterNotSetToast((e as unknown as Error).message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(vaeSelected(result.vae));
|
|
||||||
parameterSetToast();
|
|
||||||
},
|
},
|
||||||
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall LoRA with toast
|
|
||||||
*/
|
|
||||||
|
|
||||||
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
|
|
||||||
|
|
||||||
const prepareLoRAMetadataItem = useCallback(
|
|
||||||
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
|
|
||||||
if (!isModelIdentifier(loraMetadataItem.lora)) {
|
|
||||||
return { lora: null, error: 'Invalid LoRA model' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const { lora } = loraMetadataItem;
|
|
||||||
|
|
||||||
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
|
|
||||||
|
|
||||||
if (!matchingLoRA) {
|
|
||||||
return { lora: null, error: 'LoRA model is not installed' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
|
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
|
||||||
return {
|
|
||||||
lora: null,
|
|
||||||
error: 'LoRA incompatible with currently-selected model',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { lora: matchingLoRA, error: null };
|
|
||||||
},
|
|
||||||
[loraModels, model]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallLoRA = useCallback(
|
const recallLoRA = useCallback(
|
||||||
(loraMetadataItem: LoRAMetadataItem) => {
|
async (loraMetadataItem: LoRAMetadataItem) => {
|
||||||
const result = prepareLoRAMetadataItem(loraMetadataItem);
|
try {
|
||||||
|
const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base);
|
||||||
if (!result.lora) {
|
dispatch(loraRecalled(lora));
|
||||||
parameterNotSetToast(result.error);
|
parameterSetToast();
|
||||||
|
} catch (e) {
|
||||||
|
parameterNotSetToast((e as unknown as Error).message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(loraRecalled({ ...result.lora, weight: loraMetadataItem.weight }));
|
|
||||||
|
|
||||||
parameterSetToast();
|
|
||||||
},
|
},
|
||||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall ControlNet with toast
|
|
||||||
*/
|
|
||||||
|
|
||||||
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
|
|
||||||
|
|
||||||
const prepareControlNetMetadataItem = useCallback(
|
|
||||||
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
|
|
||||||
if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
|
|
||||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const { image, control_model, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
|
|
||||||
controlnetMetadataItem;
|
|
||||||
|
|
||||||
const matchingControlNetModel = controlNetModels
|
|
||||||
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
|
|
||||||
: undefined;
|
|
||||||
|
|
||||||
if (!matchingControlNetModel) {
|
|
||||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
|
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
|
||||||
return {
|
|
||||||
controlnet: null,
|
|
||||||
error: 'ControlNet incompatible with currently-selected model',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't save the original image that was processed into a control image, only the processed image
|
|
||||||
const processorType = 'none';
|
|
||||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
|
||||||
|
|
||||||
const controlnet: ControlNetConfig = {
|
|
||||||
type: 'controlnet',
|
|
||||||
isEnabled: true,
|
|
||||||
model: matchingControlNetModel,
|
|
||||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
|
||||||
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
|
||||||
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
|
||||||
controlMode: control_mode || initialControlNet.controlMode,
|
|
||||||
resizeMode: resize_mode || initialControlNet.resizeMode,
|
|
||||||
controlImage: image?.image_name || null,
|
|
||||||
processedControlImage: image?.image_name || null,
|
|
||||||
processorType,
|
|
||||||
processorNode,
|
|
||||||
shouldAutoConfig: true,
|
|
||||||
id: uuidv4(),
|
|
||||||
};
|
|
||||||
|
|
||||||
return { controlnet, error: null };
|
|
||||||
},
|
|
||||||
[controlNetModels, model]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallControlNet = useCallback(
|
const recallControlNet = useCallback(
|
||||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
async (controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||||
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
|
try {
|
||||||
|
const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base);
|
||||||
if (!result.controlnet) {
|
dispatch(controlAdapterRecalled(controlNetConfig));
|
||||||
parameterNotSetToast(result.error);
|
parameterSetToast();
|
||||||
|
} catch (e) {
|
||||||
|
parameterNotSetToast((e as unknown as Error).message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(controlAdapterRecalled(result.controlnet));
|
|
||||||
|
|
||||||
parameterSetToast();
|
|
||||||
},
|
},
|
||||||
[prepareControlNetMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall T2I Adapter with toast
|
|
||||||
*/
|
|
||||||
|
|
||||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
|
|
||||||
|
|
||||||
const prepareT2IAdapterMetadataItem = useCallback(
|
|
||||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
|
|
||||||
if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
|
|
||||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const { image, t2i_adapter_model, weight, begin_step_percent, end_step_percent, resize_mode } =
|
|
||||||
t2iAdapterMetadataItem;
|
|
||||||
|
|
||||||
const matchingT2IAdapterModel = t2iAdapterModels
|
|
||||||
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
|
|
||||||
: undefined;
|
|
||||||
|
|
||||||
if (!matchingT2IAdapterModel) {
|
|
||||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
|
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
|
||||||
return {
|
|
||||||
t2iAdapter: null,
|
|
||||||
error: 'ControlNet incompatible with currently-selected model',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't save the original image that was processed into a control image, only the processed image
|
|
||||||
const processorType = 'none';
|
|
||||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
|
||||||
|
|
||||||
const t2iAdapter: T2IAdapterConfig = {
|
|
||||||
type: 't2i_adapter',
|
|
||||||
isEnabled: true,
|
|
||||||
model: matchingT2IAdapterModel,
|
|
||||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
|
||||||
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
|
|
||||||
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
|
|
||||||
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
|
|
||||||
controlImage: image?.image_name || null,
|
|
||||||
processedControlImage: image?.image_name || null,
|
|
||||||
processorType,
|
|
||||||
processorNode,
|
|
||||||
shouldAutoConfig: true,
|
|
||||||
id: uuidv4(),
|
|
||||||
};
|
|
||||||
|
|
||||||
return { t2iAdapter, error: null };
|
|
||||||
},
|
|
||||||
[model, t2iAdapterModels]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallT2IAdapter = useCallback(
|
const recallT2IAdapter = useCallback(
|
||||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
||||||
const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem);
|
try {
|
||||||
|
const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base);
|
||||||
if (!result.t2iAdapter) {
|
dispatch(controlAdapterRecalled(t2iAdapterConfig));
|
||||||
parameterNotSetToast(result.error);
|
parameterSetToast();
|
||||||
|
} catch (e) {
|
||||||
|
parameterNotSetToast((e as unknown as Error).message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
|
||||||
|
|
||||||
parameterSetToast();
|
|
||||||
},
|
},
|
||||||
[prepareT2IAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recall IP Adapter with toast
|
|
||||||
*/
|
|
||||||
|
|
||||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
|
|
||||||
|
|
||||||
const prepareIPAdapterMetadataItem = useCallback(
|
|
||||||
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
|
|
||||||
if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
|
|
||||||
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
|
|
||||||
|
|
||||||
const matchingIPAdapterModel = ipAdapterModels
|
|
||||||
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
|
|
||||||
: undefined;
|
|
||||||
|
|
||||||
if (!matchingIPAdapterModel) {
|
|
||||||
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
|
|
||||||
}
|
|
||||||
|
|
||||||
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
|
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
|
||||||
return {
|
|
||||||
ipAdapter: null,
|
|
||||||
error: 'IP Adapter incompatible with currently-selected model',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const ipAdapter: IPAdapterConfig = {
|
|
||||||
id: uuidv4(),
|
|
||||||
type: 'ip_adapter',
|
|
||||||
isEnabled: true,
|
|
||||||
controlImage: image?.image_name ?? null,
|
|
||||||
model: matchingIPAdapterModel,
|
|
||||||
weight: weight ?? initialIPAdapter.weight,
|
|
||||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
|
||||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
|
||||||
};
|
|
||||||
|
|
||||||
return { ipAdapter, error: null };
|
|
||||||
},
|
|
||||||
[ipAdapterModels, model]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallIPAdapter = useCallback(
|
const recallIPAdapter = useCallback(
|
||||||
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
async (ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||||
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
|
try {
|
||||||
|
const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base);
|
||||||
if (!result.ipAdapter) {
|
dispatch(controlAdapterRecalled(ipAdapterConfig));
|
||||||
parameterNotSetToast(result.error);
|
parameterSetToast();
|
||||||
|
} catch (e) {
|
||||||
|
parameterNotSetToast((e as unknown as Error).message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(controlAdapterRecalled(result.ipAdapter));
|
|
||||||
|
|
||||||
parameterSetToast();
|
|
||||||
},
|
},
|
||||||
[prepareIPAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[model?.base, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/*
|
|
||||||
* Sets image as initial image with toast
|
|
||||||
*/
|
|
||||||
const sendToImageToImage = useCallback(
|
const sendToImageToImage = useCallback(
|
||||||
(image: ImageDTO) => {
|
(image: ImageDTO) => {
|
||||||
dispatch(initialImageSelected(image));
|
dispatch(initialImageSelected(image));
|
||||||
@ -780,7 +450,7 @@ export const useRecallParameters = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const recallAllParameters = useCallback(
|
const recallAllParameters = useCallback(
|
||||||
(metadata: CoreMetadata | undefined) => {
|
async (metadata: CoreMetadata | undefined) => {
|
||||||
if (!metadata) {
|
if (!metadata) {
|
||||||
allParameterNotSetToast();
|
allParameterNotSetToast();
|
||||||
return;
|
return;
|
||||||
@ -820,10 +490,12 @@ export const useRecallParameters = () => {
|
|||||||
let newModel: ParameterModel | undefined = undefined;
|
let newModel: ParameterModel | undefined = undefined;
|
||||||
|
|
||||||
if (isModelIdentifier(model)) {
|
if (isModelIdentifier(model)) {
|
||||||
const result = prepareMainModelMetadataItem(model);
|
try {
|
||||||
if (result.model) {
|
const _model = await prepareMainModelMetadataItem(model);
|
||||||
dispatch(modelSelected(result.model));
|
dispatch(modelSelected(_model));
|
||||||
newModel = result.model;
|
newModel = _model;
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -850,9 +522,11 @@ export const useRecallParameters = () => {
|
|||||||
if (isNil(vae)) {
|
if (isNil(vae)) {
|
||||||
dispatch(vaeSelected(null));
|
dispatch(vaeSelected(null));
|
||||||
} else {
|
} else {
|
||||||
const result = prepareVAEMetadataItem(vae, newModel);
|
try {
|
||||||
if (result.vae) {
|
const _vae = await prepareVAEMetadataItem(vae, newModel?.base);
|
||||||
dispatch(vaeSelected(result.vae));
|
dispatch(vaeSelected(_vae));
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -926,48 +600,46 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dispatch(lorasCleared());
|
dispatch(lorasCleared());
|
||||||
loras?.forEach((lora) => {
|
loras?.forEach(async (loraMetadataItem) => {
|
||||||
const result = prepareLoRAMetadataItem(lora, newModel);
|
try {
|
||||||
if (result.lora) {
|
const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base);
|
||||||
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
dispatch(loraRecalled(lora));
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
dispatch(controlAdaptersReset());
|
dispatch(controlAdaptersReset());
|
||||||
controlnets?.forEach((controlnet) => {
|
controlnets?.forEach(async (controlNetMetadataItem) => {
|
||||||
const result = prepareControlNetMetadataItem(controlnet, newModel);
|
try {
|
||||||
if (result.controlnet) {
|
const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base);
|
||||||
dispatch(controlAdapterRecalled(result.controlnet));
|
dispatch(controlAdapterRecalled(controlNet));
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
ipAdapters?.forEach((ipAdapter) => {
|
ipAdapters?.forEach(async (ipAdapterMetadataItem) => {
|
||||||
const result = prepareIPAdapterMetadataItem(ipAdapter, newModel);
|
try {
|
||||||
if (result.ipAdapter) {
|
const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base);
|
||||||
dispatch(controlAdapterRecalled(result.ipAdapter));
|
dispatch(controlAdapterRecalled(ipAdapter));
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
t2iAdapters?.forEach((t2iAdapter) => {
|
t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => {
|
||||||
const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel);
|
try {
|
||||||
if (result.t2iAdapter) {
|
const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base);
|
||||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
dispatch(controlAdapterRecalled(t2iAdapter));
|
||||||
|
} catch {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
allParameterSetToast();
|
allParameterSetToast();
|
||||||
},
|
},
|
||||||
[
|
[dispatch, allParameterSetToast, allParameterNotSetToast]
|
||||||
dispatch,
|
|
||||||
allParameterSetToast,
|
|
||||||
allParameterNotSetToast,
|
|
||||||
prepareMainModelMetadataItem,
|
|
||||||
prepareVAEMetadataItem,
|
|
||||||
prepareLoRAMetadataItem,
|
|
||||||
prepareControlNetMetadataItem,
|
|
||||||
prepareIPAdapterMetadataItem,
|
|
||||||
prepareT2IAdapterMetadataItem,
|
|
||||||
]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -0,0 +1,113 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||||
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
|
||||||
|
import {
|
||||||
|
isControlNetModelConfig,
|
||||||
|
isIPAdapterModelConfig,
|
||||||
|
isLoRAModelConfig,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
isRefinerMainModelModelConfig,
|
||||||
|
isT2IAdapterModelConfig,
|
||||||
|
isTextualInversionModelConfig,
|
||||||
|
isVAEModelConfig,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raised when a model config is unable to be fetched.
|
||||||
|
*/
|
||||||
|
export class ModelConfigNotFoundError extends Error {
|
||||||
|
/**
|
||||||
|
* Create ModelConfigNotFoundError
|
||||||
|
* @param {String} message
|
||||||
|
*/
|
||||||
|
constructor(message: string) {
|
||||||
|
super(message);
|
||||||
|
this.name = this.constructor.name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raised when a fetched model config is of an unexpected type.
|
||||||
|
*/
|
||||||
|
export class InvalidModelConfigError extends Error {
|
||||||
|
/**
|
||||||
|
* Create InvalidModelConfigError
|
||||||
|
* @param {String} message
|
||||||
|
*/
|
||||||
|
constructor(message: string) {
|
||||||
|
super(message);
|
||||||
|
this.name = this.constructor.name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
|
||||||
|
const { dispatch } = getStore();
|
||||||
|
try {
|
||||||
|
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
|
||||||
|
req.unsubscribe();
|
||||||
|
return await req.unwrap();
|
||||||
|
} catch {
|
||||||
|
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
||||||
|
key: string,
|
||||||
|
typeGuard: (config: AnyModelConfig) => config is T
|
||||||
|
) => {
|
||||||
|
const modelConfig = await fetchModelConfig(key);
|
||||||
|
if (!typeGuard(modelConfig)) {
|
||||||
|
throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`);
|
||||||
|
}
|
||||||
|
return modelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchMainModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchRefinerModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchVAEModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchLoRAModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchControlNetModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchIPAdapterModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchT2IAdapterModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fetchTextualInversionModel = async (key: string) => {
|
||||||
|
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => {
|
||||||
|
return sourceBase === targetBase;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
|
||||||
|
if (targetBase && !isBaseCompatible(sourceBase, targetBase)) {
|
||||||
|
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export const getModelKey = (modelIdentifier: unknown, message?: string): string => {
|
||||||
|
if (!isModelIdentifier(modelIdentifier)) {
|
||||||
|
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||||
|
}
|
||||||
|
return modelIdentifier.key;
|
||||||
|
};
|
@ -0,0 +1,150 @@
|
|||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||||
|
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
|
import {
|
||||||
|
initialControlNet,
|
||||||
|
initialIPAdapter,
|
||||||
|
initialT2IAdapter,
|
||||||
|
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||||
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
|
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
|
import type {
|
||||||
|
ControlNetMetadataItem,
|
||||||
|
IPAdapterMetadataItem,
|
||||||
|
LoRAMetadataItem,
|
||||||
|
T2IAdapterMetadataItem,
|
||||||
|
} from 'features/nodes/types/metadata';
|
||||||
|
import {
|
||||||
|
fetchControlNetModel,
|
||||||
|
fetchIPAdapterModel,
|
||||||
|
fetchLoRAModel,
|
||||||
|
fetchMainModel,
|
||||||
|
fetchRefinerModel,
|
||||||
|
fetchT2IAdapterModel,
|
||||||
|
fetchVAEModel,
|
||||||
|
getModelKey,
|
||||||
|
raiseIfBaseIncompatible,
|
||||||
|
} from 'features/parameters/util/modelFetchingHelpers';
|
||||||
|
import type { BaseModelType } from 'services/api/types';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
export const prepareMainModelMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
|
||||||
|
const key = getModelKey(model);
|
||||||
|
const mainModel = await fetchMainModel(key);
|
||||||
|
return zModelIdentifierWithBase.parse(mainModel);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const prepareRefinerMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
|
||||||
|
const key = getModelKey(model);
|
||||||
|
const refinerModel = await fetchRefinerModel(key);
|
||||||
|
return zModelIdentifierWithBase.parse(refinerModel);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise<ModelIdentifierWithBase> => {
|
||||||
|
const key = getModelKey(vae);
|
||||||
|
const vaeModel = await fetchVAEModel(key);
|
||||||
|
raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
|
||||||
|
return zModelIdentifierWithBase.parse(vaeModel);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const prepareLoRAMetadataItem = async (
|
||||||
|
loraMetadataItem: LoRAMetadataItem,
|
||||||
|
base?: BaseModelType
|
||||||
|
): Promise<LoRA> => {
|
||||||
|
const key = getModelKey(loraMetadataItem.lora);
|
||||||
|
const loraModel = await fetchLoRAModel(key);
|
||||||
|
raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
|
||||||
|
return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
|
||||||
|
};
|
||||||
|
|
||||||
|
export const prepareControlNetMetadataItem = async (
|
||||||
|
controlnetMetadataItem: ControlNetMetadataItem,
|
||||||
|
base?: BaseModelType
|
||||||
|
): Promise<ControlNetConfig> => {
|
||||||
|
const key = getModelKey(controlnetMetadataItem.control_model);
|
||||||
|
const controlNetModel = await fetchControlNetModel(key);
|
||||||
|
raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
|
||||||
|
|
||||||
|
const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
|
||||||
|
controlnetMetadataItem;
|
||||||
|
|
||||||
|
// We don't save the original image that was processed into a control image, only the processed image
|
||||||
|
const processorType = 'none';
|
||||||
|
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||||
|
|
||||||
|
const controlnet: ControlNetConfig = {
|
||||||
|
type: 'controlnet',
|
||||||
|
isEnabled: true,
|
||||||
|
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||||
|
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||||
|
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||||
|
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||||
|
controlMode: control_mode || initialControlNet.controlMode,
|
||||||
|
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||||
|
controlImage: image?.image_name || null,
|
||||||
|
processedControlImage: image?.image_name || null,
|
||||||
|
processorType,
|
||||||
|
processorNode,
|
||||||
|
shouldAutoConfig: true,
|
||||||
|
id: uuidv4(),
|
||||||
|
};
|
||||||
|
|
||||||
|
return controlnet;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const prepareT2IAdapterMetadataItem = async (
|
||||||
|
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
|
||||||
|
base?: BaseModelType
|
||||||
|
): Promise<T2IAdapterConfig> => {
|
||||||
|
const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
|
||||||
|
const t2iAdapterModel = await fetchT2IAdapterModel(key);
|
||||||
|
raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
|
||||||
|
|
||||||
|
const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
|
||||||
|
|
||||||
|
// We don't save the original image that was processed into a control image, only the processed image
|
||||||
|
const processorType = 'none';
|
||||||
|
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||||
|
|
||||||
|
const t2iAdapter: T2IAdapterConfig = {
|
||||||
|
type: 't2i_adapter',
|
||||||
|
isEnabled: true,
|
||||||
|
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||||
|
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||||
|
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
|
||||||
|
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
|
||||||
|
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
|
||||||
|
controlImage: image?.image_name || null,
|
||||||
|
processedControlImage: image?.image_name || null,
|
||||||
|
processorType,
|
||||||
|
processorNode,
|
||||||
|
shouldAutoConfig: true,
|
||||||
|
id: uuidv4(),
|
||||||
|
};
|
||||||
|
|
||||||
|
return t2iAdapter;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const prepareIPAdapterMetadataItem = async (
|
||||||
|
ipAdapterMetadataItem: IPAdapterMetadataItem,
|
||||||
|
base?: BaseModelType
|
||||||
|
): Promise<IPAdapterConfig> => {
|
||||||
|
const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
|
||||||
|
const ipAdapterModel = await fetchIPAdapterModel(key);
|
||||||
|
raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
|
||||||
|
|
||||||
|
const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
|
||||||
|
|
||||||
|
const ipAdapter: IPAdapterConfig = {
|
||||||
|
id: uuidv4(),
|
||||||
|
type: 'ip_adapter',
|
||||||
|
isEnabled: true,
|
||||||
|
controlImage: image?.image_name ?? null,
|
||||||
|
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
||||||
|
weight: weight ?? initialIPAdapter.weight,
|
||||||
|
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||||
|
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||||
|
};
|
||||||
|
|
||||||
|
return ipAdapter;
|
||||||
|
};
|
@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema';
|
|||||||
import type {
|
import type {
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ControlNetConfig,
|
ControlNetModelConfig,
|
||||||
ImportModelConfig,
|
ImportModelConfig,
|
||||||
IPAdapterConfig,
|
IPAdapterModelConfig,
|
||||||
LoRAConfig,
|
LoRAModelConfig,
|
||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
MergeModelConfig,
|
MergeModelConfig,
|
||||||
ModelType,
|
ModelType,
|
||||||
T2IAdapterConfig,
|
T2IAdapterModelConfig,
|
||||||
TextualInversionConfig,
|
TextualInversionModelConfig,
|
||||||
VAEConfig,
|
VAEModelConfig,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
|
||||||
import type { ApiTagDescription, tagTypes } from '..';
|
import type { ApiTagDescription, tagTypes } from '..';
|
||||||
@ -30,7 +30,7 @@ type UpdateMainModelArg = {
|
|||||||
type UpdateLoRAModelArg = {
|
type UpdateLoRAModelArg = {
|
||||||
base_model: BaseModelType;
|
base_model: BaseModelType;
|
||||||
model_name: string;
|
model_name: string;
|
||||||
body: LoRAConfig;
|
body: LoRAModelConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
type UpdateMainModelResponse =
|
type UpdateMainModelResponse =
|
||||||
@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
|||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
|
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
|
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
|
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
|
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
|
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap
|
|||||||
undefined,
|
undefined,
|
||||||
getSelectorsOptions
|
getSelectorsOptions
|
||||||
);
|
);
|
||||||
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
|
export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
@ -162,6 +162,8 @@ const buildTransformResponse =
|
|||||||
*/
|
*/
|
||||||
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||||
|
|
||||||
|
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
|
||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||||
@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||||
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
|
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||||
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
|
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
||||||
}),
|
}),
|
||||||
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
||||||
query: ({ base_model, model_name, body }) => {
|
query: ({ base_model, model_name, body }) => {
|
||||||
@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||||
}),
|
}),
|
||||||
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
|
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||||
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
|
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
||||||
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
|
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
||||||
}),
|
}),
|
||||||
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
|
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
||||||
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
|
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
||||||
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
|
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
||||||
}),
|
}),
|
||||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
|
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
||||||
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
|
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
||||||
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
|
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
||||||
}),
|
}),
|
||||||
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
|
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
||||||
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
|
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
||||||
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
|
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
||||||
}),
|
}),
|
||||||
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
|
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
||||||
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
|
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
||||||
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
|
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
||||||
}),
|
}),
|
||||||
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
||||||
query: (arg) => {
|
query: (arg) => {
|
||||||
|
Loading…
Reference in New Issue
Block a user