fix(ui): update all components and logic to use enriched ModelIdentifierField

This commit is contained in:
psychedelicious
2024-03-09 19:51:15 +11:00
parent 4433b78e59
commit 133c90e116
19 changed files with 85 additions and 94 deletions

View File

@ -13,13 +13,13 @@ import type {
} from 'features/metadata/types';
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { t } from 'i18next';
import { parsers } from './parsers';
import { recallers } from './recallers';
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.key);
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;

View File

@ -1,5 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
}
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
key: modelConfig.key,
base: modelConfig.base,
});

View File

@ -13,12 +13,7 @@ import type {
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
import {
zControlField,
zIPAdapterField,
zModelIdentifierWithBase,
zT2IAdapterField,
} from 'features/nodes/types/common';
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
const model = await getProperty(metadata, 'model', undefined);
const key = await getModelKey(model, 'main');
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
return modelIdentifier;
};
@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
const key = await getModelKey(refiner_model, 'main');
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
return modelIdentifier;
};
@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
const vae = await getProperty(metadata, 'vae', undefined);
const key = await getModelKey(vae, 'vae');
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
return modelIdentifier;
};
@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
return {
model: zModelIdentifierWithBase.parse(loraModelConfig),
model: zModelIdentifierField.parse(loraModelConfig),
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
isEnabled: true,
};
@ -258,7 +253,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const controlNet: ControlNetConfigMetadata = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
model: zModelIdentifierField.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
@ -309,7 +304,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
model: zModelIdentifierField.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
@ -354,7 +349,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
model: zModelIdentifierField.parse(ipAdapterModel),
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,