mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): update all components and logic to use enriched ModelIdentifierField
This commit is contained in:
@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
useGlobalModifiersInit();
|
useGlobalModifiersInit();
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
|
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return props.children;
|
return props.children;
|
||||||
|
@ -2,7 +2,7 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { GroupBase } from 'chakra-react-select';
|
import type { GroupBase } from 'chakra-react-select';
|
||||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { groupBy, map, reduce } from 'lodash-es';
|
import { groupBy, map, reduce } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -10,7 +10,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
|
|
||||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelEntities: EntityState<T, string> | undefined;
|
||||||
selectedModel?: ModelIdentifierWithBase | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
isLoading?: boolean;
|
isLoading?: boolean;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -8,7 +8,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
|
|
||||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelEntities: EntityState<T, string> | undefined;
|
||||||
selectedModel?: ModelIdentifierWithBase | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
optionsFilter?: (model: T) => boolean;
|
optionsFilter?: (model: T) => boolean;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { Item } from '@invoke-ai/ui-library';
|
import type { Item } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||||
import { filter } from 'lodash-es';
|
import { filter } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||||
data: EntityState<T, string> | undefined;
|
data: EntityState<T, string> | undefined;
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
selectedModel?: ModelIdentifierWithBase | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
modelFilter?: (model: T) => boolean;
|
modelFilter?: (model: T) => boolean;
|
||||||
isModelDisabled?: (model: T) => boolean;
|
isModelDisabled?: (model: T) => boolean;
|
||||||
|
@ -4,6 +4,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
|||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||||
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
import { socketInvocationError } from 'services/events/actions';
|
import { socketInvocationError } from 'services/events/actions';
|
||||||
@ -197,7 +198,7 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = { key: modelConfig.key, base: modelConfig.base };
|
const model = zModelIdentifierField.parse(modelConfig);
|
||||||
|
|
||||||
if (!isControlNetOrT2IAdapter(cn)) {
|
if (!isControlNetOrT2IAdapter(cn)) {
|
||||||
caAdapter.updateOne(state, { id, changes: { model } });
|
caAdapter.updateOne(state, { id, changes: { model } });
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ export const loraSlice = createSlice({
|
|||||||
initialState: initialLoraState,
|
initialState: initialLoraState,
|
||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||||
const model = getModelKeyAndBase(action.payload);
|
const model = zModelIdentifierField.parse(action.payload);
|
||||||
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
||||||
},
|
},
|
||||||
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||||
|
@ -13,13 +13,13 @@ import type {
|
|||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { validators } from 'features/metadata/util/validators';
|
import { validators } from 'features/metadata/util/validators';
|
||||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
|
|
||||||
import { parsers } from './parsers';
|
import { parsers } from './parsers';
|
||||||
import { recallers } from './recallers';
|
import { recallers } from './recallers';
|
||||||
|
|
||||||
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
|
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
|
||||||
try {
|
try {
|
||||||
const modelConfig = await fetchModelConfig(value.key);
|
const modelConfig = await fetchModelConfig(value.key);
|
||||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
|
||||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||||
@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
|
|||||||
}
|
}
|
||||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
|
|
||||||
key: modelConfig.key,
|
|
||||||
base: modelConfig.base,
|
|
||||||
});
|
|
||||||
|
@ -13,12 +13,7 @@ import type {
|
|||||||
T2IAdapterConfigMetadata,
|
T2IAdapterConfigMetadata,
|
||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import {
|
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
|
||||||
zControlField,
|
|
||||||
zIPAdapterField,
|
|
||||||
zModelIdentifierWithBase,
|
|
||||||
zT2IAdapterField,
|
|
||||||
} from 'features/nodes/types/common';
|
|
||||||
import type {
|
import type {
|
||||||
ParameterCFGRescaleMultiplier,
|
ParameterCFGRescaleMultiplier,
|
||||||
ParameterCFGScale,
|
ParameterCFGScale,
|
||||||
@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
|
|||||||
const model = await getProperty(metadata, 'model', undefined);
|
const model = await getProperty(metadata, 'model', undefined);
|
||||||
const key = await getModelKey(model, 'main');
|
const key = await getModelKey(model, 'main');
|
||||||
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||||
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
|
const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
|
||||||
return modelIdentifier;
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
|
|||||||
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
||||||
const key = await getModelKey(refiner_model, 'main');
|
const key = await getModelKey(refiner_model, 'main');
|
||||||
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||||
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
|
const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
|
||||||
return modelIdentifier;
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
|
|||||||
const vae = await getProperty(metadata, 'vae', undefined);
|
const vae = await getProperty(metadata, 'vae', undefined);
|
||||||
const key = await getModelKey(vae, 'vae');
|
const key = await getModelKey(vae, 'vae');
|
||||||
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||||
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
|
const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
|
||||||
return modelIdentifier;
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
|||||||
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
model: zModelIdentifierWithBase.parse(loraModelConfig),
|
model: zModelIdentifierField.parse(loraModelConfig),
|
||||||
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
};
|
};
|
||||||
@ -258,7 +253,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
|||||||
const controlNet: ControlNetConfigMetadata = {
|
const controlNet: ControlNetConfigMetadata = {
|
||||||
type: 'controlnet',
|
type: 'controlnet',
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: zModelIdentifierWithBase.parse(controlNetModel),
|
model: zModelIdentifierField.parse(controlNetModel),
|
||||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||||
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
|
||||||
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
|
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
|
||||||
@ -309,7 +304,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
|||||||
const t2iAdapter: T2IAdapterConfigMetadata = {
|
const t2iAdapter: T2IAdapterConfigMetadata = {
|
||||||
type: 't2i_adapter',
|
type: 't2i_adapter',
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
model: zModelIdentifierField.parse(t2iAdapterModel),
|
||||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||||
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
|
||||||
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
||||||
@ -354,7 +349,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
|||||||
id: uuidv4(),
|
id: uuidv4(),
|
||||||
type: 'ip_adapter',
|
type: 'ip_adapter',
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
model: zModelIdentifierField.parse(ipAdapterModel),
|
||||||
controlImage: image?.image_name ?? null,
|
controlImage: image?.image_name ?? null,
|
||||||
weight: weight ?? initialIPAdapter.weight,
|
weight: weight ?? initialIPAdapter.weight,
|
||||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import type {
|
import type {
|
||||||
BaseModel,
|
|
||||||
BoardField,
|
BoardField,
|
||||||
Classification,
|
Classification,
|
||||||
ColorField,
|
ColorField,
|
||||||
@ -7,6 +6,7 @@ import type {
|
|||||||
ImageField,
|
ImageField,
|
||||||
ImageOutput,
|
ImageOutput,
|
||||||
IPAdapterField,
|
IPAdapterField,
|
||||||
|
ModelIdentifierField,
|
||||||
ProgressImage,
|
ProgressImage,
|
||||||
SchedulerField,
|
SchedulerField,
|
||||||
T2IAdapterField,
|
T2IAdapterField,
|
||||||
@ -33,10 +33,9 @@ describe('Common types', () => {
|
|||||||
test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>());
|
test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>());
|
||||||
|
|
||||||
// Model component types
|
// Model component types
|
||||||
test('BaseModel', () => assert<Equals<BaseModel, S['BaseModelType']>>());
|
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
|
||||||
|
|
||||||
// Misc types
|
// Misc types
|
||||||
// @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet
|
|
||||||
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
|
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
|
||||||
test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>());
|
test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>());
|
||||||
test('Classification', () => assert<Equals<Classification, S['Classification']>>());
|
test('Classification', () => assert<Equals<Classification, S['Classification']>>());
|
||||||
|
@ -55,6 +55,17 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
|||||||
|
|
||||||
// #region Model-related schemas
|
// #region Model-related schemas
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||||
|
const zModelType = z.enum([
|
||||||
|
'main',
|
||||||
|
'vae',
|
||||||
|
'lora',
|
||||||
|
'controlnet',
|
||||||
|
't2i_adapter',
|
||||||
|
'ip_adapter',
|
||||||
|
'embedding',
|
||||||
|
'onnx',
|
||||||
|
'clip_vision',
|
||||||
|
]);
|
||||||
const zSubModelType = z.enum([
|
const zSubModelType = z.enum([
|
||||||
'unet',
|
'unet',
|
||||||
'text_encoder',
|
'text_encoder',
|
||||||
@ -67,26 +78,25 @@ const zSubModelType = z.enum([
|
|||||||
'scheduler',
|
'scheduler',
|
||||||
'safety_checker',
|
'safety_checker',
|
||||||
]);
|
]);
|
||||||
|
export const zModelIdentifierField = z.object({
|
||||||
const zModelIdentifier = z.object({
|
|
||||||
key: z.string().min(1),
|
key: z.string().min(1),
|
||||||
|
hash: z.string().min(1),
|
||||||
|
name: z.string().min(1),
|
||||||
|
base: zBaseModel,
|
||||||
|
type: zModelType,
|
||||||
submodel_type: zSubModelType.nullish(),
|
submodel_type: zSubModelType.nullish(),
|
||||||
});
|
});
|
||||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
export const isModelIdentifier = (field: unknown): field is ModelIdentifierField =>
|
||||||
zModelIdentifier.safeParse(field).success;
|
zModelIdentifierField.safeParse(field).success;
|
||||||
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
|
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
|
||||||
zModelIdentifierV2.safeParse(field).success;
|
zModelIdentifierV2.safeParse(field).success;
|
||||||
const zModelFieldBase = zModelIdentifier;
|
export type ModelIdentifierField = z.infer<typeof zModelIdentifierField>;
|
||||||
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
|
||||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
|
||||||
type ModelIdentifier = z.infer<typeof zModelIdentifier>;
|
|
||||||
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
|
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Control Adapters
|
// #region Control Adapters
|
||||||
export const zControlField = z.object({
|
export const zControlField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
control_model: zModelFieldBase,
|
control_model: zModelIdentifierField,
|
||||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||||
begin_step_percent: z.number().optional(),
|
begin_step_percent: z.number().optional(),
|
||||||
end_step_percent: z.number().optional(),
|
end_step_percent: z.number().optional(),
|
||||||
@ -97,7 +107,7 @@ export type ControlField = z.infer<typeof zControlField>;
|
|||||||
|
|
||||||
export const zIPAdapterField = z.object({
|
export const zIPAdapterField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
ip_adapter_model: zModelFieldBase,
|
ip_adapter_model: zModelIdentifierField,
|
||||||
weight: z.number(),
|
weight: z.number(),
|
||||||
begin_step_percent: z.number().optional(),
|
begin_step_percent: z.number().optional(),
|
||||||
end_step_percent: z.number().optional(),
|
end_step_percent: z.number().optional(),
|
||||||
@ -106,7 +116,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
|||||||
|
|
||||||
export const zT2IAdapterField = z.object({
|
export const zT2IAdapterField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
t2i_adapter_model: zModelFieldBase,
|
t2i_adapter_model: zModelIdentifierField,
|
||||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||||
begin_step_percent: z.number().optional(),
|
begin_step_percent: z.number().optional(),
|
||||||
end_step_percent: z.number().optional(),
|
end_step_percent: z.number().optional(),
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
import { zBoardField, zColorField, zImageField, zModelIdentifierWithBase, zSchedulerField } from './common';
|
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* zod schemas & inferred types for fields.
|
* zod schemas & inferred types for fields.
|
||||||
@ -277,7 +277,7 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
|
|||||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('MainModelField'),
|
name: z.literal('MainModelField'),
|
||||||
});
|
});
|
||||||
export const zMainModelFieldValue = zModelIdentifierWithBase.optional();
|
export const zMainModelFieldValue = zModelIdentifierField.optional();
|
||||||
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
value: zMainModelFieldValue,
|
value: zMainModelFieldValue,
|
||||||
});
|
});
|
||||||
@ -348,7 +348,7 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
|
|||||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('VAEModelField'),
|
name: z.literal('VAEModelField'),
|
||||||
});
|
});
|
||||||
export const zVAEModelFieldValue = zModelIdentifierWithBase.optional();
|
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||||
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
value: zVAEModelFieldValue,
|
value: zVAEModelFieldValue,
|
||||||
});
|
});
|
||||||
@ -372,7 +372,7 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
|
|||||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('LoRAModelField'),
|
name: z.literal('LoRAModelField'),
|
||||||
});
|
});
|
||||||
export const zLoRAModelFieldValue = zModelIdentifierWithBase.optional();
|
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||||
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
value: zLoRAModelFieldValue,
|
value: zLoRAModelFieldValue,
|
||||||
});
|
});
|
||||||
@ -396,7 +396,7 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
|
|||||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('ControlNetModelField'),
|
name: z.literal('ControlNetModelField'),
|
||||||
});
|
});
|
||||||
export const zControlNetModelFieldValue = zModelIdentifierWithBase.optional();
|
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
value: zControlNetModelFieldValue,
|
value: zControlNetModelFieldValue,
|
||||||
});
|
});
|
||||||
@ -420,7 +420,7 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
|
|||||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('IPAdapterModelField'),
|
name: z.literal('IPAdapterModelField'),
|
||||||
});
|
});
|
||||||
export const zIPAdapterModelFieldValue = zModelIdentifierWithBase.optional();
|
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||||
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
value: zIPAdapterModelFieldValue,
|
value: zIPAdapterModelFieldValue,
|
||||||
});
|
});
|
||||||
@ -444,7 +444,7 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
|
|||||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('T2IAdapterModelField'),
|
name: z.literal('T2IAdapterModelField'),
|
||||||
});
|
});
|
||||||
export const zT2IAdapterModelFieldValue = zModelIdentifierWithBase.optional();
|
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||||
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
value: zT2IAdapterModelFieldValue,
|
value: zT2IAdapterModelFieldValue,
|
||||||
});
|
});
|
||||||
|
@ -1,15 +1,10 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import {
|
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||||
type CoreMetadataInvocation,
|
|
||||||
isLoRAModelConfig,
|
|
||||||
type LoRALoaderInvocation,
|
|
||||||
type NonNullableGraph,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
||||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addLoRAsToGraph = async (
|
export const addLoRAsToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -49,19 +44,18 @@ export const addLoRAsToGraph = async (
|
|||||||
const { weight } = lora;
|
const { weight } = lora;
|
||||||
const { key } = lora.model;
|
const { key } = lora.model;
|
||||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||||
|
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||||
|
|
||||||
const loraLoaderNode: LoRALoaderInvocation = {
|
const loraLoaderNode: LoRALoaderInvocation = {
|
||||||
type: 'lora_loader',
|
type: 'lora_loader',
|
||||||
id: currentLoraNodeId,
|
id: currentLoraNodeId,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
lora: { key },
|
lora: parsedModel,
|
||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
|
||||||
|
|
||||||
loraMetadata.push({
|
loraMetadata.push({
|
||||||
model: getModelMetadataField(modelConfig),
|
model: parsedModel,
|
||||||
weight,
|
weight,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,12 +1,7 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import {
|
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
|
||||||
type CoreMetadataInvocation,
|
|
||||||
isLoRAModelConfig,
|
|
||||||
type NonNullableGraph,
|
|
||||||
type SDXLLoRALoaderInvocation,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
@ -16,7 +11,7 @@ import {
|
|||||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addSDXLLoRAsToGraph = async (
|
export const addSDXLLoRAsToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -63,20 +58,18 @@ export const addSDXLLoRAsToGraph = async (
|
|||||||
|
|
||||||
enabledLoRAs.forEach(async (lora) => {
|
enabledLoRAs.forEach(async (lora) => {
|
||||||
const { weight } = lora;
|
const { weight } = lora;
|
||||||
const { key } = lora.model;
|
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
|
||||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||||
|
|
||||||
const loraLoaderNode: SDXLLoRALoaderInvocation = {
|
const loraLoaderNode: SDXLLoRALoaderInvocation = {
|
||||||
type: 'sdxl_lora_loader',
|
type: 'sdxl_lora_loader',
|
||||||
id: currentLoraNodeId,
|
id: currentLoraNodeId,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
lora: { key },
|
lora: parsedModel,
|
||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
loraMetadata.push({ model: parsedModel, weight });
|
||||||
|
|
||||||
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
|
|
||||||
|
|
||||||
// add to graph
|
// add to graph
|
||||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||||
|
@ -3,6 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||||
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -24,7 +25,11 @@ const ParamMainModelSelect = () => {
|
|||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(modelSelected({ key: model.key, base: model.base }));
|
try {
|
||||||
|
dispatch(modelSelected(zModelIdentifierField.parse(model)));
|
||||||
|
} catch {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -30,7 +30,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
);
|
);
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(vae: VAEModelConfig | null) => {
|
(vae: VAEModelConfig | null) => {
|
||||||
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null));
|
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||||
import { zModelIdentifierWithBase, zSchedulerField } from 'features/nodes/types/common';
|
import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -92,37 +92,37 @@ export const isParameterHeight = (val: unknown): val is ParameterHeight => zPara
|
|||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Model
|
// #region Model
|
||||||
export const zParameterModel = zModelIdentifierWithBase;
|
export const zParameterModel = zModelIdentifierField;
|
||||||
export type ParameterModel = z.infer<typeof zParameterModel>;
|
export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region SDXL Refiner Model
|
// #region SDXL Refiner Model
|
||||||
const zParameterSDXLRefinerModel = zModelIdentifierWithBase;
|
const zParameterSDXLRefinerModel = zModelIdentifierField;
|
||||||
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
|
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region VAE Model
|
// #region VAE Model
|
||||||
export const zParameterVAEModel = zModelIdentifierWithBase;
|
export const zParameterVAEModel = zModelIdentifierField;
|
||||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region LoRA Model
|
// #region LoRA Model
|
||||||
const zParameterLoRAModel = zModelIdentifierWithBase;
|
const zParameterLoRAModel = zModelIdentifierField;
|
||||||
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region ControlNet Model
|
// #region ControlNet Model
|
||||||
const zParameterControlNetModel = zModelIdentifierWithBase;
|
const zParameterControlNetModel = zModelIdentifierField;
|
||||||
export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>;
|
export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region IP Adapter Model
|
// #region IP Adapter Model
|
||||||
const zParameterIPAdapterModel = zModelIdentifierWithBase;
|
const zParameterIPAdapterModel = zModelIdentifierField;
|
||||||
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
|
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region T2I Adapter Model
|
// #region T2I Adapter Model
|
||||||
const zParameterT2IAdapterModel = zModelIdentifierWithBase;
|
const zParameterT2IAdapterModel = zModelIdentifierField;
|
||||||
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
|
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the optimal dimension for a givel model, based on the model's base_model
|
* Gets the optimal dimension for a givel model, based on the model's base_model
|
||||||
* @param model The model identifier
|
* @param model The model identifier
|
||||||
* @returns The optimal dimension for the model
|
* @returns The optimal dimension for the model
|
||||||
*/
|
*/
|
||||||
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number =>
|
export const getOptimalDimension = (model?: ModelIdentifierField | null): number =>
|
||||||
model?.base === 'sdxl' ? 1024 : 512;
|
model?.base === 'sdxl' ? 1024 : 512;
|
||||||
|
|
||||||
const MIN_AREA_FACTOR = 0.8;
|
const MIN_AREA_FACTOR = 0.8;
|
||||||
|
@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -26,7 +26,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(refinerModelChanged(getModelKeyAndBase(model)));
|
dispatch(refinerModelChanged(zModelIdentifierField.parse(model)));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
Reference in New Issue
Block a user