fix(ui): update all components and logic to use enriched ModelIdentifierField

This commit is contained in:
psychedelicious
2024-03-09 19:51:15 +11:00
parent 4433b78e59
commit 133c90e116
19 changed files with 85 additions and 94 deletions

View File

@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
useGlobalModifiersInit(); useGlobalModifiersInit();
useEffect(() => { useEffect(() => {
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' })); dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
}, []); }, []);
return props.children; return props.children;

View File

@ -2,7 +2,7 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select'; import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es'; import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -10,7 +10,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = { type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined; modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null; selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
isLoading?: boolean; isLoading?: boolean;

View File

@ -1,6 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -8,7 +8,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = { type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined; modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null; selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean; optionsFilter?: (model: T) => boolean;

View File

@ -1,7 +1,7 @@
import type { Item } from '@invoke-ai/ui-library'; import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants'; import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es'; import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = { type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined; data: EntityState<T, string> | undefined;
isLoading: boolean; isLoading: boolean;
selectedModel?: ModelIdentifierWithBase | null; selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
modelFilter?: (model: T) => boolean; modelFilter?: (model: T) => boolean;
isModelDisabled?: (model: T) => boolean; isModelDisabled?: (model: T) => boolean;

View File

@ -4,6 +4,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { cloneDeep, merge, uniq } from 'lodash-es'; import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions'; import { socketInvocationError } from 'services/events/actions';
@ -197,7 +198,7 @@ export const controlAdaptersSlice = createSlice({
return; return;
} }
const model = { key: modelConfig.key, base: modelConfig.base }; const model = zModelIdentifierField.parse(modelConfig);
if (!isControlNetOrT2IAdapter(cn)) { if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } }); caAdapter.updateOne(state, { id, changes: { model } });

View File

@ -1,7 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
@ -31,7 +31,7 @@ export const loraSlice = createSlice({
initialState: initialLoraState, initialState: initialLoraState,
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const model = getModelKeyAndBase(action.payload); const model = zModelIdentifierField.parse(action.payload);
state.loras[model.key] = { ...defaultLoRAConfig, model }; state.loras[model.key] = { ...defaultLoRAConfig, model };
}, },
loraRecalled: (state, action: PayloadAction<LoRA>) => { loraRecalled: (state, action: PayloadAction<LoRA>) => {

View File

@ -13,13 +13,13 @@ import type {
} from 'features/metadata/types'; } from 'features/metadata/types';
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators'; import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { t } from 'i18next'; import { t } from 'i18next';
import { parsers } from './parsers'; import { parsers } from './parsers';
import { recallers } from './recallers'; import { recallers } from './recallers';
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => { const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
try { try {
const modelConfig = await fetchModelConfig(value.key); const modelConfig = await fetchModelConfig(value.key);
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`; return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;

View File

@ -1,5 +1,4 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
} }
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
}; };
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
key: modelConfig.key,
base: modelConfig.base,
});

View File

@ -13,12 +13,7 @@ import type {
T2IAdapterConfigMetadata, T2IAdapterConfigMetadata,
} from 'features/metadata/types'; } from 'features/metadata/types';
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
import { import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
zControlField,
zIPAdapterField,
zModelIdentifierWithBase,
zT2IAdapterField,
} from 'features/nodes/types/common';
import type { import type {
ParameterCFGRescaleMultiplier, ParameterCFGRescaleMultiplier,
ParameterCFGScale, ParameterCFGScale,
@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
const model = await getProperty(metadata, 'model', undefined); const model = await getProperty(metadata, 'model', undefined);
const key = await getModelKey(model, 'main'); const key = await getModelKey(model, 'main');
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig); const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
return modelIdentifier; return modelIdentifier;
}; };
@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
const refiner_model = await getProperty(metadata, 'refiner_model', undefined); const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
const key = await getModelKey(refiner_model, 'main'); const key = await getModelKey(refiner_model, 'main');
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig); const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
return modelIdentifier; return modelIdentifier;
}; };
@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
const vae = await getProperty(metadata, 'vae', undefined); const vae = await getProperty(metadata, 'vae', undefined);
const key = await getModelKey(vae, 'vae'); const key = await getModelKey(vae, 'vae');
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig); const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig); const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
return modelIdentifier; return modelIdentifier;
}; };
@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
return { return {
model: zModelIdentifierWithBase.parse(loraModelConfig), model: zModelIdentifierField.parse(loraModelConfig),
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
isEnabled: true, isEnabled: true,
}; };
@ -258,7 +253,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const controlNet: ControlNetConfigMetadata = { const controlNet: ControlNetConfigMetadata = {
type: 'controlnet', type: 'controlnet',
isEnabled: true, isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel), model: zModelIdentifierField.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct, beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
endStepPct: end_step_percent ?? initialControlNet.endStepPct, endStepPct: end_step_percent ?? initialControlNet.endStepPct,
@ -309,7 +304,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const t2iAdapter: T2IAdapterConfigMetadata = { const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter', type: 't2i_adapter',
isEnabled: true, isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel), model: zModelIdentifierField.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct, beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
@ -354,7 +349,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
id: uuidv4(), id: uuidv4(),
type: 'ip_adapter', type: 'ip_adapter',
isEnabled: true, isEnabled: true,
model: zModelIdentifierWithBase.parse(ipAdapterModel), model: zModelIdentifierField.parse(ipAdapterModel),
controlImage: image?.image_name ?? null, controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight, weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

View File

@ -1,5 +1,4 @@
import type { import type {
BaseModel,
BoardField, BoardField,
Classification, Classification,
ColorField, ColorField,
@ -7,6 +6,7 @@ import type {
ImageField, ImageField,
ImageOutput, ImageOutput,
IPAdapterField, IPAdapterField,
ModelIdentifierField,
ProgressImage, ProgressImage,
SchedulerField, SchedulerField,
T2IAdapterField, T2IAdapterField,
@ -33,10 +33,9 @@ describe('Common types', () => {
test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>()); test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>());
// Model component types // Model component types
test('BaseModel', () => assert<Equals<BaseModel, S['BaseModelType']>>()); test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
// Misc types // Misc types
// @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>()); test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>()); test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>());
test('Classification', () => assert<Equals<Classification, S['Classification']>>()); test('Classification', () => assert<Equals<Classification, S['Classification']>>());

View File

@ -55,6 +55,17 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #region Model-related schemas // #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zModelType = z.enum([
'main',
'vae',
'lora',
'controlnet',
't2i_adapter',
'ip_adapter',
'embedding',
'onnx',
'clip_vision',
]);
const zSubModelType = z.enum([ const zSubModelType = z.enum([
'unet', 'unet',
'text_encoder', 'text_encoder',
@ -67,26 +78,25 @@ const zSubModelType = z.enum([
'scheduler', 'scheduler',
'safety_checker', 'safety_checker',
]); ]);
export const zModelIdentifierField = z.object({
const zModelIdentifier = z.object({
key: z.string().min(1), key: z.string().min(1),
hash: z.string().min(1),
name: z.string().min(1),
base: zBaseModel,
type: zModelType,
submodel_type: zSubModelType.nullish(), submodel_type: zSubModelType.nullish(),
}); });
export const isModelIdentifier = (field: unknown): field is ModelIdentifier => export const isModelIdentifier = (field: unknown): field is ModelIdentifierField =>
zModelIdentifier.safeParse(field).success; zModelIdentifierField.safeParse(field).success;
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 => export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
zModelIdentifierV2.safeParse(field).success; zModelIdentifierV2.safeParse(field).success;
const zModelFieldBase = zModelIdentifier; export type ModelIdentifierField = z.infer<typeof zModelIdentifierField>;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>;
type ModelIdentifier = z.infer<typeof zModelIdentifier>;
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
// #endregion // #endregion
// #region Control Adapters // #region Control Adapters
export const zControlField = z.object({ export const zControlField = z.object({
image: zImageField, image: zImageField,
control_model: zModelFieldBase, control_model: zModelIdentifierField,
control_weight: z.union([z.number(), z.array(z.number())]).optional(), control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),
@ -97,7 +107,7 @@ export type ControlField = z.infer<typeof zControlField>;
export const zIPAdapterField = z.object({ export const zIPAdapterField = z.object({
image: zImageField, image: zImageField,
ip_adapter_model: zModelFieldBase, ip_adapter_model: zModelIdentifierField,
weight: z.number(), weight: z.number(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),
@ -106,7 +116,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zT2IAdapterField = z.object({ export const zT2IAdapterField = z.object({
image: zImageField, image: zImageField,
t2i_adapter_model: zModelFieldBase, t2i_adapter_model: zModelIdentifierField,
weight: z.union([z.number(), z.array(z.number())]).optional(), weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),

View File

@ -1,6 +1,6 @@
import { z } from 'zod'; import { z } from 'zod';
import { zBoardField, zColorField, zImageField, zModelIdentifierWithBase, zSchedulerField } from './common'; import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
/** /**
* zod schemas & inferred types for fields. * zod schemas & inferred types for fields.
@ -277,7 +277,7 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
const zMainModelFieldType = zFieldTypeBase.extend({ const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'), name: z.literal('MainModelField'),
}); });
export const zMainModelFieldValue = zModelIdentifierWithBase.optional(); export const zMainModelFieldValue = zModelIdentifierField.optional();
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zMainModelFieldValue, value: zMainModelFieldValue,
}); });
@ -348,7 +348,7 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
const zVAEModelFieldType = zFieldTypeBase.extend({ const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'), name: z.literal('VAEModelField'),
}); });
export const zVAEModelFieldValue = zModelIdentifierWithBase.optional(); export const zVAEModelFieldValue = zModelIdentifierField.optional();
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zVAEModelFieldValue, value: zVAEModelFieldValue,
}); });
@ -372,7 +372,7 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
const zLoRAModelFieldType = zFieldTypeBase.extend({ const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'), name: z.literal('LoRAModelField'),
}); });
export const zLoRAModelFieldValue = zModelIdentifierWithBase.optional(); export const zLoRAModelFieldValue = zModelIdentifierField.optional();
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLoRAModelFieldValue, value: zLoRAModelFieldValue,
}); });
@ -396,7 +396,7 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
const zControlNetModelFieldType = zFieldTypeBase.extend({ const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'), name: z.literal('ControlNetModelField'),
}); });
export const zControlNetModelFieldValue = zModelIdentifierWithBase.optional(); export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlNetModelFieldValue, value: zControlNetModelFieldValue,
}); });
@ -420,7 +420,7 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
const zIPAdapterModelFieldType = zFieldTypeBase.extend({ const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'), name: z.literal('IPAdapterModelField'),
}); });
export const zIPAdapterModelFieldValue = zModelIdentifierWithBase.optional(); export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIPAdapterModelFieldValue, value: zIPAdapterModelFieldValue,
}); });
@ -444,7 +444,7 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'), name: z.literal('T2IAdapterModelField'),
}); });
export const zT2IAdapterModelFieldValue = zModelIdentifierWithBase.optional(); export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT2IAdapterModelFieldValue, value: zT2IAdapterModelFieldValue,
}); });

View File

@ -1,15 +1,10 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { filter, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import { import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
type CoreMetadataInvocation,
isLoRAModelConfig,
type LoRALoaderInvocation,
type NonNullableGraph,
} from 'services/api/types';
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants'; import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata'; import { upsertMetadata } from './metadata';
export const addLoRAsToGraph = async ( export const addLoRAsToGraph = async (
state: RootState, state: RootState,
@ -49,19 +44,18 @@ export const addLoRAsToGraph = async (
const { weight } = lora; const { weight } = lora;
const { key } = lora.model; const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`; const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: LoRALoaderInvocation = { const loraLoaderNode: LoRALoaderInvocation = {
type: 'lora_loader', type: 'lora_loader',
id: currentLoraNodeId, id: currentLoraNodeId,
is_intermediate: true, is_intermediate: true,
lora: { key }, lora: parsedModel,
weight, weight,
}; };
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({ loraMetadata.push({
model: getModelMetadataField(modelConfig), model: parsedModel,
weight, weight,
}); });

View File

@ -1,12 +1,7 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { filter, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import { import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
type CoreMetadataInvocation,
isLoRAModelConfig,
type NonNullableGraph,
type SDXLLoRALoaderInvocation,
} from 'services/api/types';
import { import {
LORA_LOADER, LORA_LOADER,
@ -16,7 +11,7 @@ import {
SDXL_REFINER_INPAINT_CREATE_MASK, SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata'; import { upsertMetadata } from './metadata';
export const addSDXLLoRAsToGraph = async ( export const addSDXLLoRAsToGraph = async (
state: RootState, state: RootState,
@ -63,20 +58,18 @@ export const addSDXLLoRAsToGraph = async (
enabledLoRAs.forEach(async (lora) => { enabledLoRAs.forEach(async (lora) => {
const { weight } = lora; const { weight } = lora;
const { key } = lora.model; const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
const currentLoraNodeId = `${LORA_LOADER}_${key}`; const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: SDXLLoRALoaderInvocation = { const loraLoaderNode: SDXLLoRALoaderInvocation = {
type: 'sdxl_lora_loader', type: 'sdxl_lora_loader',
id: currentLoraNodeId, id: currentLoraNodeId,
is_intermediate: true, is_intermediate: true,
lora: { key }, lora: parsedModel,
weight, weight,
}; };
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); loraMetadata.push({ model: parsedModel, weight });
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
// add to graph // add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode; graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -3,6 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -24,7 +25,11 @@ const ParamMainModelSelect = () => {
if (!model) { if (!model) {
return; return;
} }
dispatch(modelSelected({ key: model.key, base: model.base })); try {
dispatch(modelSelected(zModelIdentifierField.parse(model)));
} catch {
// no-op
}
}, },
[dispatch] [dispatch]
); );

View File

@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -30,7 +30,7 @@ const ParamVAEModelSelect = () => {
); );
const _onChange = useCallback( const _onChange = useCallback(
(vae: VAEModelConfig | null) => { (vae: VAEModelConfig | null) => {
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null)); dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
}, },
[dispatch] [dispatch]
); );

View File

@ -1,6 +1,6 @@
import { NUMPY_RAND_MAX } from 'app/constants'; import { NUMPY_RAND_MAX } from 'app/constants';
import { roundToMultiple } from 'common/util/roundDownToMultiple'; import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { zModelIdentifierWithBase, zSchedulerField } from 'features/nodes/types/common'; import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
import { z } from 'zod'; import { z } from 'zod';
/** /**
@ -92,37 +92,37 @@ export const isParameterHeight = (val: unknown): val is ParameterHeight => zPara
// #endregion // #endregion
// #region Model // #region Model
export const zParameterModel = zModelIdentifierWithBase; export const zParameterModel = zModelIdentifierField;
export type ParameterModel = z.infer<typeof zParameterModel>; export type ParameterModel = z.infer<typeof zParameterModel>;
// #endregion // #endregion
// #region SDXL Refiner Model // #region SDXL Refiner Model
const zParameterSDXLRefinerModel = zModelIdentifierWithBase; const zParameterSDXLRefinerModel = zModelIdentifierField;
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>; export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
// #endregion // #endregion
// #region VAE Model // #region VAE Model
export const zParameterVAEModel = zModelIdentifierWithBase; export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>; export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion // #endregion
// #region LoRA Model // #region LoRA Model
const zParameterLoRAModel = zModelIdentifierWithBase; const zParameterLoRAModel = zModelIdentifierField;
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>; export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
// #endregion // #endregion
// #region ControlNet Model // #region ControlNet Model
const zParameterControlNetModel = zModelIdentifierWithBase; const zParameterControlNetModel = zModelIdentifierField;
export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>; export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>;
// #endregion // #endregion
// #region IP Adapter Model // #region IP Adapter Model
const zParameterIPAdapterModel = zModelIdentifierWithBase; const zParameterIPAdapterModel = zModelIdentifierField;
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>; export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
// #endregion // #endregion
// #region T2I Adapter Model // #region T2I Adapter Model
const zParameterT2IAdapterModel = zModelIdentifierWithBase; const zParameterT2IAdapterModel = zModelIdentifierField;
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>; export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
// #endregion // #endregion

View File

@ -1,11 +1,11 @@
import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
/** /**
* Gets the optimal dimension for a givel model, based on the model's base_model * Gets the optimal dimension for a givel model, based on the model's base_model
* @param model The model identifier * @param model The model identifier
* @returns The optimal dimension for the model * @returns The optimal dimension for the model
*/ */
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number => export const getOptimalDimension = (model?: ModelIdentifierField | null): number =>
model?.base === 'sdxl' ? 1024 : 512; model?.base === 'sdxl' ? 1024 : 512;
const MIN_AREA_FACTOR = 0.8; const MIN_AREA_FACTOR = 0.8;

View File

@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useModelCombobox } from 'common/hooks/useModelCombobox'; import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice'; import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -26,7 +26,7 @@ const ParamSDXLRefinerModelSelect = () => {
dispatch(refinerModelChanged(null)); dispatch(refinerModelChanged(null));
return; return;
} }
dispatch(refinerModelChanged(getModelKeyAndBase(model))); dispatch(refinerModelChanged(zModelIdentifierField.parse(model)));
}, },
[dispatch] [dispatch]
); );