From 769ddc0024bc7cb62efc99c85aacdd03d094378d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:15:01 +1100 Subject: [PATCH] fix(ui): model metadata handlers use model identifiers, not configs Model metadata includes the main model, VAE and refiner model. These used full model configs, as returned by the server, as their metadata type. LoRA and control adapter metadata only use the metadata identifier. This created a difference in handling. After parsing a model/vae/refiner, we have its name and can display it. But for LoRAs and control adapters, we only have the model key and must query for the full model config to get the name. This change makes main model/vae/refiner metadata only have the model key, like LoRAs and control adapters. The render function is now async so fetching can occur within it. All metadata fields with models now only contain the identifier, and fetch the model name to render their values. --- .../components/MetadataControlNets.tsx | 18 ++++++--- .../components/MetadataIPAdapters.tsx | 18 ++++++--- .../metadata/components/MetadataLoRAs.tsx | 18 ++++++--- .../components/MetadataT2IAdapters.tsx | 18 ++++++--- .../metadata/hooks/useMetadataItem.tsx | 35 +++++++++++------ .../web/src/features/metadata/types.ts | 2 +- .../src/features/metadata/util/handlers.ts | 38 +++++++++++++++---- .../web/src/features/metadata/util/parsers.ts | 22 ++++++----- .../src/features/metadata/util/recallers.ts | 20 +++++----- .../src/features/metadata/util/validators.ts | 7 ++-- 10 files changed, 128 insertions(+), 68 deletions(-) diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx index 1255e8e7cb..1e1c45632f 100644 --- a/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx @@ -1,4 +1,3 @@ -import { Text } from '@invoke-ai/ui-library'; import type { ControlNetConfig } from 'features/controlAdapters/store/types'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import type { MetadataHandlers } from 'features/metadata/types'; @@ -56,11 +55,18 @@ const MetadataViewControlNet = ({ handlers.recallItem(controlNet, true); }, [handlers, controlNet]); - const renderedValue = useMemo(() => { - if (!handlers.renderItemValue) { - return null; - } - return {handlers.renderItemValue(controlNet)}; + const [renderedValue, setRenderedValue] = useState(null); + useEffect(() => { + const _renderValue = async () => { + if (!handlers.renderItemValue) { + setRenderedValue(null); + return; + } + const rendered = await handlers.renderItemValue(controlNet); + setRenderedValue(rendered); + }; + + _renderValue(); }, [handlers, controlNet]); return ; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx index 69f9ec986d..8c3705d698 100644 --- a/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx @@ -1,4 +1,3 @@ -import { Text } from '@invoke-ai/ui-library'; import type { IPAdapterConfig } from 'features/controlAdapters/store/types'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import type { MetadataHandlers } from 'features/metadata/types'; @@ -51,11 +50,18 @@ const MetadataViewIPAdapter = ({ handlers.recallItem(ipAdapter, true); }, [handlers, ipAdapter]); - const renderedValue = useMemo(() => { - if (!handlers.renderItemValue) { - return null; - } - return {handlers.renderItemValue(ipAdapter)}; + const [renderedValue, setRenderedValue] = useState(null); + useEffect(() => { + const _renderValue = async () => { + if (!handlers.renderItemValue) { + setRenderedValue(null); + return; + } + const rendered = await handlers.renderItemValue(ipAdapter); + setRenderedValue(rendered); + }; + + _renderValue(); }, [handlers, ipAdapter]); return ; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx index 40f6bc427b..7e78985c49 100644 --- a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx @@ -1,4 +1,3 @@ -import { Text } from '@invoke-ai/ui-library'; import type { LoRA } from 'features/lora/store/loraSlice'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import type { MetadataHandlers } from 'features/metadata/types'; @@ -51,11 +50,18 @@ const MetadataViewLoRA = ({ handlers.recallItem(lora, true); }, [handlers, lora]); - const renderedValue = useMemo(() => { - if (!handlers.renderItemValue) { - return null; - } - return {handlers.renderItemValue(lora)}; + const [renderedValue, setRenderedValue] = useState(null); + useEffect(() => { + const _renderValue = async () => { + if (!handlers.renderItemValue) { + setRenderedValue(null); + return; + } + const rendered = await handlers.renderItemValue(lora); + setRenderedValue(rendered); + }; + + _renderValue(); }, [handlers, lora]); return ; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx index 209ab4f2ed..4464fbbd8f 100644 --- a/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx @@ -1,4 +1,3 @@ -import { Text } from '@invoke-ai/ui-library'; import type { T2IAdapterConfig } from 'features/controlAdapters/store/types'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import type { MetadataHandlers } from 'features/metadata/types'; @@ -55,12 +54,19 @@ const MetadataViewT2IAdapter = ({ } handlers.recallItem(t2iAdapter, true); }, [handlers, t2iAdapter]); + + const [renderedValue, setRenderedValue] = useState(null); + useEffect(() => { + const _renderValue = async () => { + if (!handlers.renderItemValue) { + setRenderedValue(null); + return; + } + const rendered = await handlers.renderItemValue(t2iAdapter); + setRenderedValue(rendered); + }; - const renderedValue = useMemo(() => { - if (!handlers.renderItemValue) { - return null; - } - return {handlers.renderItemValue(t2iAdapter)}; + _renderValue(); }, [handlers, t2iAdapter]); return ; diff --git a/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx index 178ac5155a..63d987569b 100644 --- a/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx +++ b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx @@ -3,10 +3,14 @@ import type { MetadataHandlers } from 'features/metadata/types'; import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers'; import { useCallback, useEffect, useMemo, useState } from 'react'; +const pendingRenderedValue = Loading; +const failedRenderedValue = Parsing Failed; + export const useMetadataItem = (metadata: unknown, handlers: MetadataHandlers) => { const [value, setValue] = useState( MetadataParsePendingToken ); + const [renderedValue, setRenderedValue] = useState(pendingRenderedValue); useEffect(() => { const _parse = async () => { @@ -24,20 +28,27 @@ export const useMetadataItem = (metadata: unknown, handlers: MetadataHandler const label = useMemo(() => handlers.getLabel(), [handlers]); - const renderedValue = useMemo(() => { - if (value === MetadataParsePendingToken) { - return Loading; - } - if (value === MetadataParseFailedToken) { - return Parsing Failed; - } + useEffect(() => { + const _renderValue = async () => { + if (value === MetadataParsePendingToken) { + setRenderedValue(pendingRenderedValue); + return; + } + if (value === MetadataParseFailedToken) { + setRenderedValue(failedRenderedValue); + return; + } - const rendered = handlers.renderValue(value); + const rendered = await handlers.renderValue(value); - if (typeof rendered === 'string') { - return {rendered}; - } - return rendered; + if (typeof rendered === 'string') { + setRenderedValue({rendered}); + return; + } + setRenderedValue(rendered); + }; + + _renderValue(); }, [handlers, value]); const onRecall = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/metadata/types.ts b/invokeai/frontend/web/src/features/metadata/types.ts index 366f11701e..4390a525b7 100644 --- a/invokeai/frontend/web/src/features/metadata/types.ts +++ b/invokeai/frontend/web/src/features/metadata/types.ts @@ -1,7 +1,7 @@ /** * Renders a value of type T as a React node. */ -export type MetadataRenderValueFunc = (value: T) => React.ReactNode; +export type MetadataRenderValueFunc = (value: T) => Promise; /** * Gets the label of the current metadata item as a string. diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index 49d85f4eb3..8f520a6b4e 100644 --- a/invokeai/frontend/web/src/features/metadata/util/handlers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts @@ -11,18 +11,38 @@ import type { MetadataRenderValueFunc, MetadataValidateFunc, } from 'features/metadata/types'; +import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers'; import { validators } from 'features/metadata/util/validators'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { t } from 'i18next'; -import type { AnyModelConfig } from 'services/api/types'; import { parsers } from './parsers'; import { recallers } from './recallers'; -const renderModelConfigValue: MetadataRenderValueFunc = (value) => - `${value.name} (${value.base.toUpperCase()}, ${value.key})`; -const renderLoRAValue: MetadataRenderValueFunc = (value) => `${value.model.key} (${value.weight})`; -const renderControlAdapterValue: MetadataRenderValueFunc = (value) => - `${value.model?.key} (${value.weight})`; +const renderModelConfigValue: MetadataRenderValueFunc = async (value) => { + try { + const modelConfig = await fetchModelConfig(value.key); + return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`; + } catch { + return `${value.key} (${value.base.toUpperCase()})`; + } +}; +const renderLoRAValue: MetadataRenderValueFunc = async (value) => { + try { + const modelConfig = await fetchModelConfig(value.model.key); + return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`; + } catch { + return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`; + } +}; +const renderControlAdapterValue: MetadataRenderValueFunc = async (value) => { + try { + const modelConfig = await fetchModelConfig(value.model?.key ?? 'none'); + return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`; + } catch { + return `${value.model?.key} (${value.model?.base.toUpperCase()}) - ${value.weight}`; + } +}; const parameterSetToast = (parameter: string, description?: string) => { toast({ @@ -130,6 +150,8 @@ const buildRecallItem = } }; +const resolveToString = (value: unknown) => new Promise((resolve) => resolve(String(value))); + const buildHandlers: BuildMetadataHandlers = ({ getLabel, parser, @@ -146,8 +168,8 @@ const buildHandlers: BuildMetadataHandlers = ({ recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined, recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined, getLabel, - renderValue: renderValue ?? String, - renderItemValue: renderItemValue ?? String, + renderValue: renderValue ?? resolveToString, + renderItemValue: renderItemValue ?? resolveToString, }); export const handlers = { diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index df0485a0c4..619b8cc26c 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -12,7 +12,6 @@ import type { MetadataParseFunc } from 'features/metadata/types'; import { fetchModelConfigWithTypeGuard, getModelKey, - getModelKeyAndBase, } from 'features/metadata/util/modelFetchingHelpers'; import { zControlField, @@ -26,17 +25,20 @@ import type { ParameterHeight, ParameterHRFEnabled, ParameterHRFMethod, + ParameterModel, ParameterNegativePrompt, ParameterNegativeStylePromptSDXL, ParameterPositivePrompt, ParameterPositiveStylePromptSDXL, ParameterScheduler, + ParameterSDXLRefinerModel, ParameterSDXLRefinerNegativeAestheticScore, ParameterSDXLRefinerPositiveAestheticScore, ParameterSDXLRefinerStart, ParameterSeed, ParameterSteps, ParameterStrength, + ParameterVAEModel, ParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { @@ -60,7 +62,6 @@ import { isParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { get, isArray, isString } from 'lodash-es'; -import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; import { isControlNetModelConfig, isIPAdapterModelConfig, @@ -163,25 +164,28 @@ const parseRefinerNegativeAestheticScore: MetadataParseFunc = (metadata) => getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart); -const parseMainModel: MetadataParseFunc = async (metadata) => { +const parseMainModel: MetadataParseFunc = async (metadata) => { const model = await getProperty(metadata, 'model', undefined); const key = await getModelKey(model, 'main'); const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); - return mainModelConfig; + const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig); + return modelIdentifier; }; -const parseRefinerModel: MetadataParseFunc = async (metadata) => { +const parseRefinerModel: MetadataParseFunc = async (metadata) => { const refiner_model = await getProperty(metadata, 'refiner_model', undefined); const key = await getModelKey(refiner_model, 'main'); const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); - return refinerModelConfig; + const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig); + return modelIdentifier; }; -const parseVAEModel: MetadataParseFunc = async (metadata) => { +const parseVAEModel: MetadataParseFunc = async (metadata) => { const vae = await getProperty(metadata, 'vae', undefined); const key = await getModelKey(vae, 'vae'); const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig); - return vaeModelConfig; + const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig); + return modelIdentifier; }; const parseLoRA: MetadataParseFunc = async (metadataItem) => { @@ -194,7 +198,7 @@ const parseLoRA: MetadataParseFunc = async (metadataItem) => { const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); return { - model: getModelKeyAndBase(loraModelConfig), + model: zModelIdentifierWithBase.parse(loraModelConfig), weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, isEnabled: true, }; diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts index 5e156dc035..cc867933b1 100644 --- a/invokeai/frontend/web/src/features/metadata/util/recallers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -5,7 +5,6 @@ import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/ import type { LoRA } from 'features/lora/store/loraSlice'; import { loraRecalled } from 'features/lora/store/loraSlice'; import type { MetadataRecallFunc } from 'features/metadata/types'; -import { zModelIdentifierWithBase } from 'features/nodes/types/common'; import { modelSelected } from 'features/parameters/store/actions'; import { heightRecalled, @@ -26,17 +25,20 @@ import type { ParameterHeight, ParameterHRFEnabled, ParameterHRFMethod, + ParameterModel, ParameterNegativePrompt, ParameterNegativeStylePromptSDXL, ParameterPositivePrompt, ParameterPositiveStylePromptSDXL, ParameterScheduler, + ParameterSDXLRefinerModel, ParameterSDXLRefinerNegativeAestheticScore, ParameterSDXLRefinerPositiveAestheticScore, ParameterSDXLRefinerStart, ParameterSeed, ParameterSteps, ParameterStrength, + ParameterVAEModel, ParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { @@ -50,7 +52,6 @@ import { setRefinerStart, setRefinerSteps, } from 'features/sdxl/store/sdxlSlice'; -import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; const recallPositivePrompt: MetadataRecallFunc = (positivePrompt) => { getStore().dispatch(setPositivePrompt(positivePrompt)); @@ -140,23 +141,20 @@ const recallRefinerStart: MetadataRecallFunc = (refin getStore().dispatch(setRefinerStart(refinerStart)); }; -const recallModel: MetadataRecallFunc = (model) => { - const modelIdentifier = zModelIdentifierWithBase.parse(model); - getStore().dispatch(modelSelected(modelIdentifier)); +const recallModel: MetadataRecallFunc = (model) => { + getStore().dispatch(modelSelected(model)); }; -const recallRefinerModel: MetadataRecallFunc = (refinerModel) => { - const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel); - getStore().dispatch(refinerModelChanged(modelIdentifier)); +const recallRefinerModel: MetadataRecallFunc = (refinerModel) => { + getStore().dispatch(refinerModelChanged(refinerModel)); }; -const recallVAE: MetadataRecallFunc = (vaeModel) => { +const recallVAE: MetadataRecallFunc = (vaeModel) => { if (!vaeModel) { getStore().dispatch(vaeSelected(null)); return; } - const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel); - getStore().dispatch(vaeSelected(modelIdentifier)); + getStore().dispatch(vaeSelected(vaeModel)); }; const recallLoRA: MetadataRecallFunc = (lora) => { diff --git a/invokeai/frontend/web/src/features/metadata/util/validators.ts b/invokeai/frontend/web/src/features/metadata/util/validators.ts index 1f34e38651..9a9ffae723 100644 --- a/invokeai/frontend/web/src/features/metadata/util/validators.ts +++ b/invokeai/frontend/web/src/features/metadata/util/validators.ts @@ -3,7 +3,8 @@ import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'featur import type { LoRA } from 'features/lora/store/loraSlice'; import type { MetadataValidateFunc } from 'features/metadata/types'; import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; -import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; +import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas'; +import type { BaseModelType } from 'services/api/types'; /** * Checks the given base model type against the currently-selected model's base type and throws an error if they are @@ -21,12 +22,12 @@ const validateBaseCompatibility = (base?: BaseModelType, message?: string) => { } }; -const validateRefinerModel: MetadataValidateFunc = (refinerModel) => { +const validateRefinerModel: MetadataValidateFunc = (refinerModel) => { validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model'); return new Promise((resolve) => resolve(refinerModel)); }; -const validateVAEModel: MetadataValidateFunc = (vaeModel) => { +const validateVAEModel: MetadataValidateFunc = (vaeModel) => { validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model'); return new Promise((resolve) => resolve(vaeModel)); };