fix(ui): model metadata handlers use model identifiers, not configs

Model metadata includes the main model, VAE and refiner model.

These used full model configs, as returned by the server, as their metadata type.

LoRA and control adapter metadata only use the metadata identifier.

This created a difference in handling. After parsing a model/vae/refiner, we have its name and can display it. But for LoRAs and control adapters, we only have the model key and must query for the full model config to get the name.

This change makes main model/vae/refiner metadata only have the model key, like LoRAs and control adapters.

The render function is now async so fetching can occur within it. All metadata fields with models now only contain the identifier, and fetch the model name to render their values.
This commit is contained in:
psychedelicious 2024-02-26 23:15:01 +11:00 committed by Brandon Rising
parent 4833f9c736
commit 769ddc0024
10 changed files with 128 additions and 68 deletions

View File

@ -1,4 +1,3 @@
import { Text } from '@invoke-ai/ui-library';
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
@ -56,11 +55,18 @@ const MetadataViewControlNet = ({
handlers.recallItem(controlNet, true);
}, [handlers, controlNet]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return <Text>{handlers.renderItemValue(controlNet)}</Text>;
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(controlNet);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, controlNet]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;

View File

@ -1,4 +1,3 @@
import { Text } from '@invoke-ai/ui-library';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
@ -51,11 +50,18 @@ const MetadataViewIPAdapter = ({
handlers.recallItem(ipAdapter, true);
}, [handlers, ipAdapter]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return <Text>{handlers.renderItemValue(ipAdapter)}</Text>;
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(ipAdapter);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, ipAdapter]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;

View File

@ -1,4 +1,3 @@
import { Text } from '@invoke-ai/ui-library';
import type { LoRA } from 'features/lora/store/loraSlice';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
@ -51,11 +50,18 @@ const MetadataViewLoRA = ({
handlers.recallItem(lora, true);
}, [handlers, lora]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return <Text>{handlers.renderItemValue(lora)}</Text>;
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(lora);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, lora]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;

View File

@ -1,4 +1,3 @@
import { Text } from '@invoke-ai/ui-library';
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
@ -55,12 +54,19 @@ const MetadataViewT2IAdapter = ({
}
handlers.recallItem(t2iAdapter, true);
}, [handlers, t2iAdapter]);
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(t2iAdapter);
setRenderedValue(rendered);
};
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return <Text>{handlers.renderItemValue(t2iAdapter)}</Text>;
_renderValue();
}, [handlers, t2iAdapter]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;

View File

@ -3,10 +3,14 @@ import type { MetadataHandlers } from 'features/metadata/types';
import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
import { useCallback, useEffect, useMemo, useState } from 'react';
const pendingRenderedValue = <Text>Loading</Text>;
const failedRenderedValue = <Text>Parsing Failed</Text>;
export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandlers<T>) => {
const [value, setValue] = useState<T | typeof MetadataParsePendingToken | typeof MetadataParseFailedToken>(
MetadataParsePendingToken
);
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(pendingRenderedValue);
useEffect(() => {
const _parse = async () => {
@ -24,20 +28,27 @@ export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandler
const label = useMemo(() => handlers.getLabel(), [handlers]);
const renderedValue = useMemo(() => {
if (value === MetadataParsePendingToken) {
return <Text>Loading</Text>;
}
if (value === MetadataParseFailedToken) {
return <Text>Parsing Failed</Text>;
}
useEffect(() => {
const _renderValue = async () => {
if (value === MetadataParsePendingToken) {
setRenderedValue(pendingRenderedValue);
return;
}
if (value === MetadataParseFailedToken) {
setRenderedValue(failedRenderedValue);
return;
}
const rendered = handlers.renderValue(value);
const rendered = await handlers.renderValue(value);
if (typeof rendered === 'string') {
return <Text>{rendered}</Text>;
}
return rendered;
if (typeof rendered === 'string') {
setRenderedValue(<Text>{rendered}</Text>);
return;
}
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, value]);
const onRecall = useCallback(() => {

View File

@ -1,7 +1,7 @@
/**
* Renders a value of type T as a React node.
*/
export type MetadataRenderValueFunc<T> = (value: T) => React.ReactNode;
export type MetadataRenderValueFunc<T> = (value: T) => Promise<React.ReactNode>;
/**
* Gets the label of the current metadata item as a string.

View File

@ -11,18 +11,38 @@ import type {
MetadataRenderValueFunc,
MetadataValidateFunc,
} from 'features/metadata/types';
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { t } from 'i18next';
import type { AnyModelConfig } from 'services/api/types';
import { parsers } from './parsers';
import { recallers } from './recallers';
const renderModelConfigValue: MetadataRenderValueFunc<AnyModelConfig> = (value) =>
`${value.name} (${value.base.toUpperCase()}, ${value.key})`;
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = (value) => `${value.model.key} (${value.weight})`;
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = (value) =>
`${value.model?.key} (${value.weight})`;
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.key);
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
} catch {
return `${value.key} (${value.base.toUpperCase()})`;
}
};
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.model.key);
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
} catch {
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
}
};
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.model?.key ?? 'none');
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
} catch {
return `${value.model?.key} (${value.model?.base.toUpperCase()}) - ${value.weight}`;
}
};
const parameterSetToast = (parameter: string, description?: string) => {
toast({
@ -130,6 +150,8 @@ const buildRecallItem =
}
};
const resolveToString = (value: unknown) => new Promise<React.ReactNode>((resolve) => resolve(String(value)));
const buildHandlers: BuildMetadataHandlers = ({
getLabel,
parser,
@ -146,8 +168,8 @@ const buildHandlers: BuildMetadataHandlers = ({
recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
getLabel,
renderValue: renderValue ?? String,
renderItemValue: renderItemValue ?? String,
renderValue: renderValue ?? resolveToString,
renderItemValue: renderItemValue ?? resolveToString,
});
export const handlers = {

View File

@ -12,7 +12,6 @@ import type { MetadataParseFunc } from 'features/metadata/types';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
getModelKeyAndBase,
} from 'features/metadata/util/modelFetchingHelpers';
import {
zControlField,
@ -26,17 +25,20 @@ import type {
ParameterHeight,
ParameterHRFEnabled,
ParameterHRFMethod,
ParameterModel,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterScheduler,
ParameterSDXLRefinerModel,
ParameterSDXLRefinerNegativeAestheticScore,
ParameterSDXLRefinerPositiveAestheticScore,
ParameterSDXLRefinerStart,
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
@ -60,7 +62,6 @@ import {
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { get, isArray, isString } from 'lodash-es';
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
@ -163,25 +164,28 @@ const parseRefinerNegativeAestheticScore: MetadataParseFunc<ParameterSDXLRefiner
const parseRefinerStart: MetadataParseFunc<ParameterSDXLRefinerStart> = (metadata) =>
getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
const parseMainModel: MetadataParseFunc<NonRefinerMainModelConfig> = async (metadata) => {
const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
const model = await getProperty(metadata, 'model', undefined);
const key = await getModelKey(model, 'main');
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
return mainModelConfig;
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
return modelIdentifier;
};
const parseRefinerModel: MetadataParseFunc<RefinerMainModelConfig> = async (metadata) => {
const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (metadata) => {
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
const key = await getModelKey(refiner_model, 'main');
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
return refinerModelConfig;
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
return modelIdentifier;
};
const parseVAEModel: MetadataParseFunc<VAEModelConfig> = async (metadata) => {
const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) => {
const vae = await getProperty(metadata, 'vae', undefined);
const key = await getModelKey(vae, 'vae');
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
return vaeModelConfig;
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
return modelIdentifier;
};
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
@ -194,7 +198,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
return {
model: getModelKeyAndBase(loraModelConfig),
model: zModelIdentifierWithBase.parse(loraModelConfig),
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
isEnabled: true,
};

View File

@ -5,7 +5,6 @@ import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice';
import type { MetadataRecallFunc } from 'features/metadata/types';
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import {
heightRecalled,
@ -26,17 +25,20 @@ import type {
ParameterHeight,
ParameterHRFEnabled,
ParameterHRFMethod,
ParameterModel,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterScheduler,
ParameterSDXLRefinerModel,
ParameterSDXLRefinerNegativeAestheticScore,
ParameterSDXLRefinerPositiveAestheticScore,
ParameterSDXLRefinerStart,
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
@ -50,7 +52,6 @@ import {
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
getStore().dispatch(setPositivePrompt(positivePrompt));
@ -140,23 +141,20 @@ const recallRefinerStart: MetadataRecallFunc<ParameterSDXLRefinerStart> = (refin
getStore().dispatch(setRefinerStart(refinerStart));
};
const recallModel: MetadataRecallFunc<NonRefinerMainModelConfig> = (model) => {
const modelIdentifier = zModelIdentifierWithBase.parse(model);
getStore().dispatch(modelSelected(modelIdentifier));
const recallModel: MetadataRecallFunc<ParameterModel> = (model) => {
getStore().dispatch(modelSelected(model));
};
const recallRefinerModel: MetadataRecallFunc<RefinerMainModelConfig> = (refinerModel) => {
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel);
getStore().dispatch(refinerModelChanged(modelIdentifier));
const recallRefinerModel: MetadataRecallFunc<ParameterSDXLRefinerModel> = (refinerModel) => {
getStore().dispatch(refinerModelChanged(refinerModel));
};
const recallVAE: MetadataRecallFunc<VAEModelConfig | null | undefined> = (vaeModel) => {
const recallVAE: MetadataRecallFunc<ParameterVAEModel | null | undefined> = (vaeModel) => {
if (!vaeModel) {
getStore().dispatch(vaeSelected(null));
return;
}
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel);
getStore().dispatch(vaeSelected(modelIdentifier));
getStore().dispatch(vaeSelected(vaeModel));
};
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {

View File

@ -3,7 +3,8 @@ import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'featur
import type { LoRA } from 'features/lora/store/loraSlice';
import type { MetadataValidateFunc } from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { BaseModelType } from 'services/api/types';
/**
* Checks the given base model type against the currently-selected model's base type and throws an error if they are
@ -21,12 +22,12 @@ const validateBaseCompatibility = (base?: BaseModelType, message?: string) => {
}
};
const validateRefinerModel: MetadataValidateFunc<RefinerMainModelConfig> = (refinerModel) => {
const validateRefinerModel: MetadataValidateFunc<ParameterSDXLRefinerModel> = (refinerModel) => {
validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
return new Promise((resolve) => resolve(refinerModel));
};
const validateVAEModel: MetadataValidateFunc<VAEModelConfig> = (vaeModel) => {
const validateVAEModel: MetadataValidateFunc<ParameterVAEModel> = (vaeModel) => {
validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
return new Promise((resolve) => resolve(vaeModel));
};