mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): model metadata handlers use model identifiers, not configs
Model metadata includes the main model, VAE and refiner model. These used full model configs, as returned by the server, as their metadata type. LoRA and control adapter metadata only use the metadata identifier. This created a difference in handling. After parsing a model/vae/refiner, we have its name and can display it. But for LoRAs and control adapters, we only have the model key and must query for the full model config to get the name. This change makes main model/vae/refiner metadata only have the model key, like LoRAs and control adapters. The render function is now async so fetching can occur within it. All metadata fields with models now only contain the identifier, and fetch the model name to render their values.
This commit is contained in:
parent
7b4ef5926d
commit
9abfb02bf0
@ -1,4 +1,3 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
@ -56,11 +55,18 @@ const MetadataViewControlNet = ({
|
||||
handlers.recallItem(controlNet, true);
|
||||
}, [handlers, controlNet]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
return <Text>{handlers.renderItemValue(controlNet)}</Text>;
|
||||
const rendered = await handlers.renderItemValue(controlNet);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, controlNet]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
@ -51,11 +50,18 @@ const MetadataViewIPAdapter = ({
|
||||
handlers.recallItem(ipAdapter, true);
|
||||
}, [handlers, ipAdapter]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
return <Text>{handlers.renderItemValue(ipAdapter)}</Text>;
|
||||
const rendered = await handlers.renderItemValue(ipAdapter);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, ipAdapter]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
@ -51,11 +50,18 @@ const MetadataViewLoRA = ({
|
||||
handlers.recallItem(lora, true);
|
||||
}, [handlers, lora]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
return <Text>{handlers.renderItemValue(lora)}</Text>;
|
||||
const rendered = await handlers.renderItemValue(lora);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, lora]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
@ -56,11 +55,18 @@ const MetadataViewT2IAdapter = ({
|
||||
handlers.recallItem(t2iAdapter, true);
|
||||
}, [handlers, t2iAdapter]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
return <Text>{handlers.renderItemValue(t2iAdapter)}</Text>;
|
||||
const rendered = await handlers.renderItemValue(t2iAdapter);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, t2iAdapter]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
|
@ -3,10 +3,14 @@ import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
const pendingRenderedValue = <Text>Loading</Text>;
|
||||
const failedRenderedValue = <Text>Parsing Failed</Text>;
|
||||
|
||||
export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandlers<T>) => {
|
||||
const [value, setValue] = useState<T | typeof MetadataParsePendingToken | typeof MetadataParseFailedToken>(
|
||||
MetadataParsePendingToken
|
||||
);
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(pendingRenderedValue);
|
||||
|
||||
useEffect(() => {
|
||||
const _parse = async () => {
|
||||
@ -24,20 +28,27 @@ export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandler
|
||||
|
||||
const label = useMemo(() => handlers.getLabel(), [handlers]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (value === MetadataParsePendingToken) {
|
||||
return <Text>Loading</Text>;
|
||||
setRenderedValue(pendingRenderedValue);
|
||||
return;
|
||||
}
|
||||
if (value === MetadataParseFailedToken) {
|
||||
return <Text>Parsing Failed</Text>;
|
||||
setRenderedValue(failedRenderedValue);
|
||||
return;
|
||||
}
|
||||
|
||||
const rendered = handlers.renderValue(value);
|
||||
const rendered = await handlers.renderValue(value);
|
||||
|
||||
if (typeof rendered === 'string') {
|
||||
return <Text>{rendered}</Text>;
|
||||
setRenderedValue(<Text>{rendered}</Text>);
|
||||
return;
|
||||
}
|
||||
return rendered;
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, value]);
|
||||
|
||||
const onRecall = useCallback(() => {
|
||||
|
@ -1,7 +1,7 @@
|
||||
/**
|
||||
* Renders a value of type T as a React node.
|
||||
*/
|
||||
export type MetadataRenderValueFunc<T> = (value: T) => React.ReactNode;
|
||||
export type MetadataRenderValueFunc<T> = (value: T) => Promise<React.ReactNode>;
|
||||
|
||||
/**
|
||||
* Gets the label of the current metadata item as a string.
|
||||
|
@ -11,18 +11,38 @@ import type {
|
||||
MetadataRenderValueFunc,
|
||||
MetadataValidateFunc,
|
||||
} from 'features/metadata/types';
|
||||
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { validators } from 'features/metadata/util/validators';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { t } from 'i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { parsers } from './parsers';
|
||||
import { recallers } from './recallers';
|
||||
|
||||
const renderModelConfigValue: MetadataRenderValueFunc<AnyModelConfig> = (value) =>
|
||||
`${value.name} (${value.base.toUpperCase()}, ${value.key})`;
|
||||
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = (value) => `${value.model.key} (${value.weight})`;
|
||||
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = (value) =>
|
||||
`${value.model?.key} (${value.weight})`;
|
||||
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
|
||||
try {
|
||||
const modelConfig = await fetchModelConfig(value.key);
|
||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
||||
} catch {
|
||||
return `${value.key} (${value.base.toUpperCase()})`;
|
||||
}
|
||||
};
|
||||
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = async (value) => {
|
||||
try {
|
||||
const modelConfig = await fetchModelConfig(value.model.key);
|
||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
|
||||
} catch {
|
||||
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
|
||||
}
|
||||
};
|
||||
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = async (value) => {
|
||||
try {
|
||||
const modelConfig = await fetchModelConfig(value.model?.key ?? 'none');
|
||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
|
||||
} catch {
|
||||
return `${value.model?.key} (${value.model?.base.toUpperCase()}) - ${value.weight}`;
|
||||
}
|
||||
};
|
||||
|
||||
const parameterSetToast = (parameter: string, description?: string) => {
|
||||
toast({
|
||||
@ -130,6 +150,8 @@ const buildRecallItem =
|
||||
}
|
||||
};
|
||||
|
||||
const resolveToString = (value: unknown) => new Promise<React.ReactNode>((resolve) => resolve(String(value)));
|
||||
|
||||
const buildHandlers: BuildMetadataHandlers = ({
|
||||
getLabel,
|
||||
parser,
|
||||
@ -146,8 +168,8 @@ const buildHandlers: BuildMetadataHandlers = ({
|
||||
recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
|
||||
recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
|
||||
getLabel,
|
||||
renderValue: renderValue ?? String,
|
||||
renderItemValue: renderItemValue ?? String,
|
||||
renderValue: renderValue ?? resolveToString,
|
||||
renderItemValue: renderItemValue ?? resolveToString,
|
||||
});
|
||||
|
||||
export const handlers = {
|
||||
|
@ -12,7 +12,6 @@ import type { MetadataParseFunc } from 'features/metadata/types';
|
||||
import {
|
||||
fetchModelConfigWithTypeGuard,
|
||||
getModelKey,
|
||||
getModelKeyAndBase,
|
||||
} from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
zControlField,
|
||||
@ -26,17 +25,20 @@ import type {
|
||||
ParameterHeight,
|
||||
ParameterHRFEnabled,
|
||||
ParameterHRFMethod,
|
||||
ParameterModel,
|
||||
ParameterNegativePrompt,
|
||||
ParameterNegativeStylePromptSDXL,
|
||||
ParameterPositivePrompt,
|
||||
ParameterPositiveStylePromptSDXL,
|
||||
ParameterScheduler,
|
||||
ParameterSDXLRefinerModel,
|
||||
ParameterSDXLRefinerNegativeAestheticScore,
|
||||
ParameterSDXLRefinerPositiveAestheticScore,
|
||||
ParameterSDXLRefinerStart,
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
@ -60,7 +62,6 @@ import {
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { get, isArray, isString } from 'lodash-es';
|
||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
@ -163,25 +164,28 @@ const parseRefinerNegativeAestheticScore: MetadataParseFunc<ParameterSDXLRefiner
|
||||
const parseRefinerStart: MetadataParseFunc<ParameterSDXLRefinerStart> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
|
||||
|
||||
const parseMainModel: MetadataParseFunc<NonRefinerMainModelConfig> = async (metadata) => {
|
||||
const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
|
||||
const model = await getProperty(metadata, 'model', undefined);
|
||||
const key = await getModelKey(model, 'main');
|
||||
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||
return mainModelConfig;
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
const parseRefinerModel: MetadataParseFunc<RefinerMainModelConfig> = async (metadata) => {
|
||||
const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (metadata) => {
|
||||
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
||||
const key = await getModelKey(refiner_model, 'main');
|
||||
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||
return refinerModelConfig;
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
const parseVAEModel: MetadataParseFunc<VAEModelConfig> = async (metadata) => {
|
||||
const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) => {
|
||||
const vae = await getProperty(metadata, 'vae', undefined);
|
||||
const key = await getModelKey(vae, 'vae');
|
||||
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||
return vaeModelConfig;
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||
@ -194,7 +198,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
return {
|
||||
model: getModelKeyAndBase(loraModelConfig),
|
||||
model: zModelIdentifierWithBase.parse(loraModelConfig),
|
||||
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
@ -5,7 +5,6 @@ import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraRecalled } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataRecallFunc } from 'features/metadata/types';
|
||||
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
heightRecalled,
|
||||
@ -26,17 +25,20 @@ import type {
|
||||
ParameterHeight,
|
||||
ParameterHRFEnabled,
|
||||
ParameterHRFMethod,
|
||||
ParameterModel,
|
||||
ParameterNegativePrompt,
|
||||
ParameterNegativeStylePromptSDXL,
|
||||
ParameterPositivePrompt,
|
||||
ParameterPositiveStylePromptSDXL,
|
||||
ParameterScheduler,
|
||||
ParameterSDXLRefinerModel,
|
||||
ParameterSDXLRefinerNegativeAestheticScore,
|
||||
ParameterSDXLRefinerPositiveAestheticScore,
|
||||
ParameterSDXLRefinerStart,
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
@ -50,7 +52,6 @@ import {
|
||||
setRefinerStart,
|
||||
setRefinerSteps,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
||||
getStore().dispatch(setPositivePrompt(positivePrompt));
|
||||
@ -140,23 +141,20 @@ const recallRefinerStart: MetadataRecallFunc<ParameterSDXLRefinerStart> = (refin
|
||||
getStore().dispatch(setRefinerStart(refinerStart));
|
||||
};
|
||||
|
||||
const recallModel: MetadataRecallFunc<NonRefinerMainModelConfig> = (model) => {
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(model);
|
||||
getStore().dispatch(modelSelected(modelIdentifier));
|
||||
const recallModel: MetadataRecallFunc<ParameterModel> = (model) => {
|
||||
getStore().dispatch(modelSelected(model));
|
||||
};
|
||||
|
||||
const recallRefinerModel: MetadataRecallFunc<RefinerMainModelConfig> = (refinerModel) => {
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel);
|
||||
getStore().dispatch(refinerModelChanged(modelIdentifier));
|
||||
const recallRefinerModel: MetadataRecallFunc<ParameterSDXLRefinerModel> = (refinerModel) => {
|
||||
getStore().dispatch(refinerModelChanged(refinerModel));
|
||||
};
|
||||
|
||||
const recallVAE: MetadataRecallFunc<VAEModelConfig | null | undefined> = (vaeModel) => {
|
||||
const recallVAE: MetadataRecallFunc<ParameterVAEModel | null | undefined> = (vaeModel) => {
|
||||
if (!vaeModel) {
|
||||
getStore().dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel);
|
||||
getStore().dispatch(vaeSelected(modelIdentifier));
|
||||
getStore().dispatch(vaeSelected(vaeModel));
|
||||
};
|
||||
|
||||
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
||||
|
@ -3,7 +3,8 @@ import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'featur
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataValidateFunc } from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { BaseModelType } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Checks the given base model type against the currently-selected model's base type and throws an error if they are
|
||||
@ -21,12 +22,12 @@ const validateBaseCompatibility = (base?: BaseModelType, message?: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
const validateRefinerModel: MetadataValidateFunc<RefinerMainModelConfig> = (refinerModel) => {
|
||||
const validateRefinerModel: MetadataValidateFunc<ParameterSDXLRefinerModel> = (refinerModel) => {
|
||||
validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(refinerModel));
|
||||
};
|
||||
|
||||
const validateVAEModel: MetadataValidateFunc<VAEModelConfig> = (vaeModel) => {
|
||||
const validateVAEModel: MetadataValidateFunc<ParameterVAEModel> = (vaeModel) => {
|
||||
validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(vaeModel));
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user