fix(ui): update all components and logic to use enriched ModelIdentifierField

This commit is contained in:
psychedelicious 2024-03-09 19:51:15 +11:00
parent 4433b78e59
commit 133c90e116
19 changed files with 85 additions and 94 deletions

View File

@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
}, []);
return props.children;

View File

@ -2,7 +2,7 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -10,7 +10,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;

View File

@ -1,6 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -8,7 +8,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean;

View File

@ -1,7 +1,7 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
isLoading: boolean;
selectedModel?: ModelIdentifierWithBase | null;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
modelFilter?: (model: T) => boolean;
isModelDisabled?: (model: T) => boolean;

View File

@ -4,6 +4,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
@ -197,7 +198,7 @@ export const controlAdaptersSlice = createSlice({
return;
}
const model = { key: modelConfig.key, base: modelConfig.base };
const model = zModelIdentifierField.parse(modelConfig);
if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } });

View File

@ -1,7 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAModelConfig } from 'services/api/types';
@ -31,7 +31,7 @@ export const loraSlice = createSlice({
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const model = getModelKeyAndBase(action.payload);
const model = zModelIdentifierField.parse(action.payload);
state.loras[model.key] = { ...defaultLoRAConfig, model };
},
loraRecalled: (state, action: PayloadAction<LoRA>) => {

View File

@ -13,13 +13,13 @@ import type {
} from 'features/metadata/types';
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { t } from 'i18next';
import { parsers } from './parsers';
import { recallers } from './recallers';
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.key);
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;

View File

@ -1,5 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
}
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
key: modelConfig.key,
base: modelConfig.base,
});

View File

@ -13,12 +13,7 @@ import type {
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
import {
zControlField,
zIPAdapterField,
zModelIdentifierWithBase,
zT2IAdapterField,
} from 'features/nodes/types/common';
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
const model = await getProperty(metadata, 'model', undefined);
const key = await getModelKey(model, 'main');
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
return modelIdentifier;
};
@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
const key = await getModelKey(refiner_model, 'main');
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
return modelIdentifier;
};
@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
const vae = await getProperty(metadata, 'vae', undefined);
const key = await getModelKey(vae, 'vae');
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
return modelIdentifier;
};
@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
return {
model: zModelIdentifierWithBase.parse(loraModelConfig),
model: zModelIdentifierField.parse(loraModelConfig),
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
isEnabled: true,
};
@ -258,7 +253,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const controlNet: ControlNetConfigMetadata = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
model: zModelIdentifierField.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
@ -309,7 +304,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
model: zModelIdentifierField.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
@ -354,7 +349,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
model: zModelIdentifierField.parse(ipAdapterModel),
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

View File

@ -1,5 +1,4 @@
import type {
BaseModel,
BoardField,
Classification,
ColorField,
@ -7,6 +6,7 @@ import type {
ImageField,
ImageOutput,
IPAdapterField,
ModelIdentifierField,
ProgressImage,
SchedulerField,
T2IAdapterField,
@ -33,10 +33,9 @@ describe('Common types', () => {
test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>());
// Model component types
test('BaseModel', () => assert<Equals<BaseModel, S['BaseModelType']>>());
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
// Misc types
// @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>());
test('Classification', () => assert<Equals<Classification, S['Classification']>>());

View File

@ -55,6 +55,17 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zModelType = z.enum([
'main',
'vae',
'lora',
'controlnet',
't2i_adapter',
'ip_adapter',
'embedding',
'onnx',
'clip_vision',
]);
const zSubModelType = z.enum([
'unet',
'text_encoder',
@ -67,26 +78,25 @@ const zSubModelType = z.enum([
'scheduler',
'safety_checker',
]);
const zModelIdentifier = z.object({
export const zModelIdentifierField = z.object({
key: z.string().min(1),
hash: z.string().min(1),
name: z.string().min(1),
base: zBaseModel,
type: zModelType,
submodel_type: zSubModelType.nullish(),
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;
export const isModelIdentifier = (field: unknown): field is ModelIdentifierField =>
zModelIdentifierField.safeParse(field).success;
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
zModelIdentifierV2.safeParse(field).success;
const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>;
type ModelIdentifier = z.infer<typeof zModelIdentifier>;
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
export type ModelIdentifierField = z.infer<typeof zModelIdentifierField>;
// #endregion
// #region Control Adapters
export const zControlField = z.object({
image: zImageField,
control_model: zModelFieldBase,
control_model: zModelIdentifierField,
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
@ -97,7 +107,7 @@ export type ControlField = z.infer<typeof zControlField>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zModelFieldBase,
ip_adapter_model: zModelIdentifierField,
weight: z.number(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
@ -106,7 +116,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zT2IAdapterField = z.object({
image: zImageField,
t2i_adapter_model: zModelFieldBase,
t2i_adapter_model: zModelIdentifierField,
weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),

View File

@ -1,6 +1,6 @@
import { z } from 'zod';
import { zBoardField, zColorField, zImageField, zModelIdentifierWithBase, zSchedulerField } from './common';
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
/**
* zod schemas & inferred types for fields.
@ -277,7 +277,7 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
});
export const zMainModelFieldValue = zModelIdentifierWithBase.optional();
export const zMainModelFieldValue = zModelIdentifierField.optional();
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zMainModelFieldValue,
});
@ -348,7 +348,7 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
});
export const zVAEModelFieldValue = zModelIdentifierWithBase.optional();
export const zVAEModelFieldValue = zModelIdentifierField.optional();
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zVAEModelFieldValue,
});
@ -372,7 +372,7 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
});
export const zLoRAModelFieldValue = zModelIdentifierWithBase.optional();
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLoRAModelFieldValue,
});
@ -396,7 +396,7 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
});
export const zControlNetModelFieldValue = zModelIdentifierWithBase.optional();
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlNetModelFieldValue,
});
@ -420,7 +420,7 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'),
});
export const zIPAdapterModelFieldValue = zModelIdentifierWithBase.optional();
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIPAdapterModelFieldValue,
});
@ -444,7 +444,7 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
});
export const zT2IAdapterModelFieldValue = zModelIdentifierWithBase.optional();
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT2IAdapterModelFieldValue,
});

View File

@ -1,15 +1,10 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { filter, size } from 'lodash-es';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type LoRALoaderInvocation,
type NonNullableGraph,
} from 'services/api/types';
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addLoRAsToGraph = async (
state: RootState,
@ -49,19 +44,18 @@ export const addLoRAsToGraph = async (
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: LoRALoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { key },
lora: parsedModel,
weight,
};
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({
model: getModelMetadataField(modelConfig),
model: parsedModel,
weight,
});

View File

@ -1,12 +1,7 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { filter, size } from 'lodash-es';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type NonNullableGraph,
type SDXLLoRALoaderInvocation,
} from 'services/api/types';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
import {
LORA_LOADER,
@ -16,7 +11,7 @@ import {
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addSDXLLoRAsToGraph = async (
state: RootState,
@ -63,20 +58,18 @@ export const addSDXLLoRAsToGraph = async (
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: SDXLLoRALoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { key },
lora: parsedModel,
weight,
};
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
loraMetadata.push({ model: parsedModel, weight });
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -3,6 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
@ -24,7 +25,11 @@ const ParamMainModelSelect = () => {
if (!model) {
return;
}
dispatch(modelSelected({ key: model.key, base: model.base }));
try {
dispatch(modelSelected(zModelIdentifierField.parse(model)));
} catch {
// no-op
}
},
[dispatch]
);

View File

@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -30,7 +30,7 @@ const ParamVAEModelSelect = () => {
);
const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null));
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
},
[dispatch]
);

View File

@ -1,6 +1,6 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { zModelIdentifierWithBase, zSchedulerField } from 'features/nodes/types/common';
import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
import { z } from 'zod';
/**
@ -92,37 +92,37 @@ export const isParameterHeight = (val: unknown): val is ParameterHeight => zPara
// #endregion
// #region Model
export const zParameterModel = zModelIdentifierWithBase;
export const zParameterModel = zModelIdentifierField;
export type ParameterModel = z.infer<typeof zParameterModel>;
// #endregion
// #region SDXL Refiner Model
const zParameterSDXLRefinerModel = zModelIdentifierWithBase;
const zParameterSDXLRefinerModel = zModelIdentifierField;
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
// #endregion
// #region VAE Model
export const zParameterVAEModel = zModelIdentifierWithBase;
export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion
// #region LoRA Model
const zParameterLoRAModel = zModelIdentifierWithBase;
const zParameterLoRAModel = zModelIdentifierField;
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
// #endregion
// #region ControlNet Model
const zParameterControlNetModel = zModelIdentifierWithBase;
const zParameterControlNetModel = zModelIdentifierField;
export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>;
// #endregion
// #region IP Adapter Model
const zParameterIPAdapterModel = zModelIdentifierWithBase;
const zParameterIPAdapterModel = zModelIdentifierField;
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
// #endregion
// #region T2I Adapter Model
const zParameterT2IAdapterModel = zModelIdentifierWithBase;
const zParameterT2IAdapterModel = zModelIdentifierField;
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
// #endregion

View File

@ -1,11 +1,11 @@
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
/**
* Gets the optimal dimension for a givel model, based on the model's base_model
* @param model The model identifier
* @returns The optimal dimension for the model
*/
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number =>
export const getOptimalDimension = (model?: ModelIdentifierField | null): number =>
model?.base === 'sdxl' ? 1024 : 512;
const MIN_AREA_FACTOR = 0.8;

View File

@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -26,7 +26,7 @@ const ParamSDXLRefinerModelSelect = () => {
dispatch(refinerModelChanged(null));
return;
}
dispatch(refinerModelChanged(getModelKeyAndBase(model)));
dispatch(refinerModelChanged(zModelIdentifierField.parse(model)));
},
[dispatch]
);