fix(ui): handle new model format for metadata

This commit is contained in:
psychedelicious 2024-02-21 19:42:49 +11:00
parent 9d9b417432
commit b59d23d608
4 changed files with 178 additions and 77 deletions

View File

@ -1,3 +1,4 @@
import { isModelIdentifier } from 'features/nodes/types/common';
import type { import type {
ControlNetMetadataItem, ControlNetMetadataItem,
CoreMetadata, CoreMetadata,
@ -6,15 +7,10 @@ import type {
T2IAdapterMetadataItem, T2IAdapterMetadataItem,
} from 'features/nodes/types/metadata'; } from 'features/nodes/types/metadata';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import {
isParameterControlNetModel,
isParameterLoRAModel,
isParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ImageMetadataItem from './ImageMetadataItem'; import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
type Props = { type Props = {
metadata?: CoreMetadata; metadata?: CoreMetadata;
@ -147,19 +143,19 @@ const ImageMetadataActions = (props: Props) => {
const validControlNets: ControlNetMetadataItem[] = useMemo(() => { const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets return metadata?.controlnets
? metadata.controlnets.filter((controlnet) => isParameterControlNetModel(controlnet.control_model)) ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
: []; : [];
}, [metadata?.controlnets]); }, [metadata?.controlnets]);
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters return metadata?.ipAdapters
? metadata.ipAdapters.filter((ipAdapter) => isParameterControlNetModel(ipAdapter.ip_adapter_model)) ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
: []; : [];
}, [metadata?.ipAdapters]); }, [metadata?.ipAdapters]);
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters return metadata?.t2iAdapters
? metadata.t2iAdapters.filter((t2iAdapter) => isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)) ? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
: []; : [];
}, [metadata?.t2iAdapters]); }, [metadata?.t2iAdapters]);
@ -209,7 +205,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} /> <ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
)} )}
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( {metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.key} onClick={handleRecallModel} /> <ModelMetadataItem label={t('metadata.model')} modelKey={metadata.model.key} onClick={handleRecallModel} />
)} )}
{metadata.width && ( {metadata.width && (
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} /> <ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
@ -220,11 +216,7 @@ const ImageMetadataActions = (props: Props) => {
{metadata.scheduler && ( {metadata.scheduler && (
<ImageMetadataItem label={t('metadata.scheduler')} value={metadata.scheduler} onClick={handleRecallScheduler} /> <ImageMetadataItem label={t('metadata.scheduler')} value={metadata.scheduler} onClick={handleRecallScheduler} />
)} )}
<ImageMetadataItem <VAEMetadataItem label={t('metadata.vae')} modelKey={metadata.vae?.key} onClick={handleRecallVaeModel} />
label={t('metadata.vae')}
value={metadata.vae?.key ?? 'Default'}
onClick={handleRecallVaeModel}
/>
{metadata.steps && ( {metadata.steps && (
<ImageMetadataItem label={t('metadata.steps')} value={metadata.steps} onClick={handleRecallSteps} /> <ImageMetadataItem label={t('metadata.steps')} value={metadata.steps} onClick={handleRecallSteps} />
)} )}
@ -264,38 +256,42 @@ const ImageMetadataActions = (props: Props) => {
)} )}
{metadata.loras && {metadata.loras &&
metadata.loras.map((lora, index) => { metadata.loras.map((lora, index) => {
if (isParameterLoRAModel(lora.lora)) { if (isModelIdentifier(lora.lora)) {
return ( return (
<ImageMetadataItem <ModelMetadataItem
key={index} key={index}
label="LoRA" label="LoRA"
value={`${lora.lora.key} - ${lora.weight}`} modelKey={lora.lora.key}
extra={` - ${lora.weight}`}
onClick={handleRecallLoRA.bind(null, lora)} onClick={handleRecallLoRA.bind(null, lora)}
/> />
); );
} }
})} })}
{validControlNets.map((controlnet, index) => ( {validControlNets.map((controlnet, index) => (
<ImageMetadataItem <ModelMetadataItem
key={index} key={index}
label="ControlNet" label="ControlNet"
value={`${controlnet.control_model?.key} - ${controlnet.control_weight}`} modelKey={controlnet.control_model?.key}
extra={` - ${controlnet.control_weight}`}
onClick={handleRecallControlNet.bind(null, controlnet)} onClick={handleRecallControlNet.bind(null, controlnet)}
/> />
))} ))}
{validIPAdapters.map((ipAdapter, index) => ( {validIPAdapters.map((ipAdapter, index) => (
<ImageMetadataItem <ModelMetadataItem
key={index} key={index}
label="IP Adapter" label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.key} - ${ipAdapter.weight}`} modelKey={ipAdapter.ip_adapter_model?.key}
extra={` - ${ipAdapter.weight}`}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)} onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
/> />
))} ))}
{validT2IAdapters.map((t2iAdapter, index) => ( {validT2IAdapters.map((t2iAdapter, index) => (
<ImageMetadataItem <ModelMetadataItem
key={index} key={index}
label="T2I Adapter" label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.key} - ${t2iAdapter.weight}`} modelKey={t2iAdapter.t2i_adapter_model?.key}
extra={` - ${t2iAdapter.weight}`}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)} onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
/> />
))} ))}

View File

@ -1,8 +1,10 @@
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { memo, useCallback } from 'react'; import { skipToken } from '@reduxjs/toolkit/query';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { PiCopyBold } from 'react-icons/pi'; import { PiCopyBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type MetadataItemProps = { type MetadataItemProps = {
isLink?: boolean; isLink?: boolean;
@ -18,8 +20,9 @@ type MetadataItemProps = {
*/ */
const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => { const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const handleCopy = useCallback(() => {
const handleCopy = useCallback(() => navigator.clipboard.writeText(value.toString()), [value]); navigator.clipboard.writeText(value?.toString());
}, [value]);
if (!value) { if (!value) {
return null; return null;
@ -68,3 +71,40 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
}; };
export default memo(ImageMetadataItem); export default memo(ImageMetadataItem);
type VAEMetadataItemProps = {
label: string;
modelKey?: string;
onClick: () => void;
};
export const VAEMetadataItem = memo(({ label, modelKey, onClick }: VAEMetadataItemProps) => {
const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
return (
<ImageMetadataItem label={label} value={modelKey ? modelConfig?.name ?? modelKey : 'Default'} onClick={onClick} />
);
});
VAEMetadataItem.displayName = 'VAEMetadataItem';
type ModelMetadataItemProps = {
label: string;
modelKey?: string;
extra?: string;
onClick: () => void;
};
export const ModelMetadataItem = memo(({ label, modelKey, extra, onClick }: ModelMetadataItemProps) => {
const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
const value = useMemo(() => {
if (modelConfig) {
return `${modelConfig.name}${extra ?? ''}`;
}
return `${modelKey}${extra ?? ''}`;
}, [extra, modelConfig, modelKey]);
return <ImageMetadataItem label={label} value={value} onClick={onClick} />;
});
ModelMetadataItem.displayName = 'ModelMetadataItem';

View File

@ -3,8 +3,8 @@ import { z } from 'zod';
import { import {
zControlField, zControlField,
zIPAdapterField, zIPAdapterField,
zLoRAModelField,
zMainModelField, zMainModelField,
zModelFieldBase,
zSDXLRefinerModelField, zSDXLRefinerModelField,
zT2IAdapterField, zT2IAdapterField,
zVAEModelField, zVAEModelField,
@ -15,7 +15,7 @@ import {
// - https://github.com/colinhacks/zod/issues/2106 // - https://github.com/colinhacks/zod/issues/2106
// - https://github.com/colinhacks/zod/issues/2854 // - https://github.com/colinhacks/zod/issues/2854
export const zLoRAMetadataItem = z.object({ export const zLoRAMetadataItem = z.object({
lora: zLoRAModelField.deepPartial(), lora: zModelFieldBase.deepPartial(),
weight: z.number(), weight: z.number(),
}); });
const zControlNetMetadataItem = zControlField.deepPartial(); const zControlNetMetadataItem = zControlField.deepPartial();

View File

@ -11,6 +11,8 @@ import {
} from 'features/controlAdapters/util/buildControlAdapter'; } from 'features/controlAdapters/util/buildControlAdapter';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
import type { ModelIdentifier } from 'features/nodes/types/common';
import { isModelIdentifier } from 'features/nodes/types/common';
import type { import type {
ControlNetMetadataItem, ControlNetMetadataItem,
CoreMetadata, CoreMetadata,
@ -37,13 +39,9 @@ import type { ParameterModel } from 'features/parameters/types/parameterSchemas'
import { import {
isParameterCFGRescaleMultiplier, isParameterCFGRescaleMultiplier,
isParameterCFGScale, isParameterCFGScale,
isParameterControlNetModel,
isParameterHeight, isParameterHeight,
isParameterHRFEnabled, isParameterHRFEnabled,
isParameterHRFMethod, isParameterHRFMethod,
isParameterIPAdapterModel,
isParameterLoRAModel,
isParameterModel,
isParameterNegativePrompt, isParameterNegativePrompt,
isParameterNegativeStylePromptSDXL, isParameterNegativeStylePromptSDXL,
isParameterPositivePrompt, isParameterPositivePrompt,
@ -56,7 +54,6 @@ import {
isParameterSeed, isParameterSeed,
isParameterSteps, isParameterSteps,
isParameterStrength, isParameterStrength,
isParameterVAEModel,
isParameterWidth, isParameterWidth,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { import {
@ -73,15 +70,20 @@ import {
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { import {
controlNetModelsAdapterSelectors, controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors, ipAdapterModelsAdapterSelectors,
loraModelsAdapterSelectors, loraModelsAdapterSelectors,
mainModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors, t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery, useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery, useGetT2IAdapterModelsQuery,
useGetVaeModelsQuery,
vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -278,21 +280,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall model with toast
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isParameterModel(model)) {
parameterNotSetToast();
return;
}
dispatch(modelSelected(model));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall scheduler with toast * Recall scheduler with toast
*/ */
@ -308,25 +295,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall vae model
*/
const recallVaeModel = useCallback(
(vae: unknown) => {
if (!isParameterVAEModel(vae) && !isNil(vae)) {
parameterNotSetToast();
return;
}
if (isNil(vae)) {
dispatch(vaeSelected(null));
} else {
dispatch(vaeSelected(vae));
}
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall steps with toast * Recall steps with toast
*/ */
@ -452,6 +420,95 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const prepareMainModelMetadataItem = useCallback(
(model: ModelIdentifier) => {
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
if (!matchingModel) {
return { model: null, error: 'Model is not installed' };
}
return { model: matchingModel, error: null };
},
[mainModels]
);
/**
* Recall model with toast
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isModelIdentifier(model)) {
parameterNotSetToast();
return;
}
const result = prepareMainModelMetadataItem(model);
if (!result.model) {
parameterNotSetToast(result.error);
return;
}
dispatch(modelSelected(result.model));
parameterSetToast();
},
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
const { data: vaeModels } = useGetVaeModelsQuery();
const prepareVAEMetadataItem = useCallback(
(vae: ModelIdentifier, newModel?: ParameterModel) => {
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
if (!matchingModel) {
return { vae: null, error: 'VAE model is not installed' };
}
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
vae: null,
error: 'VAE incompatible with currently-selected model',
};
}
return { vae: matchingModel, error: null };
},
[model, vaeModels]
);
/**
* Recall vae model
*/
const recallVaeModel = useCallback(
(vae: unknown) => {
if (!isModelIdentifier(vae) && !isNil(vae)) {
parameterNotSetToast();
return;
}
if (isNil(vae)) {
dispatch(vaeSelected(null));
parameterSetToast();
return;
}
const result = prepareVAEMetadataItem(vae);
if (!result.vae) {
parameterNotSetToast(result.error);
return;
}
dispatch(vaeSelected(result.vae));
parameterSetToast();
},
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall LoRA with toast * Recall LoRA with toast
*/ */
@ -460,7 +517,7 @@ export const useRecallParameters = () => {
const prepareLoRAMetadataItem = useCallback( const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => { (loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
if (!isParameterLoRAModel(loraMetadataItem.lora)) { if (!isModelIdentifier(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' }; return { lora: null, error: 'Invalid LoRA model' };
} }
@ -510,7 +567,7 @@ export const useRecallParameters = () => {
const prepareControlNetMetadataItem = useCallback( const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => { (controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) { if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' }; return { controlnet: null, error: 'Invalid ControlNet model' };
} }
@ -584,7 +641,7 @@ export const useRecallParameters = () => {
const prepareT2IAdapterMetadataItem = useCallback( const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => { (t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) { if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' }; return { controlnet: null, error: 'Invalid ControlNet model' };
} }
@ -657,7 +714,7 @@ export const useRecallParameters = () => {
const prepareIPAdapterMetadataItem = useCallback( const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => { (ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' }; return { ipAdapter: null, error: 'Invalid IP Adapter model' };
} }
@ -762,9 +819,12 @@ export const useRecallParameters = () => {
let newModel: ParameterModel | undefined = undefined; let newModel: ParameterModel | undefined = undefined;
if (isParameterModel(model)) { if (isModelIdentifier(model)) {
newModel = model; const result = prepareMainModelMetadataItem(model);
dispatch(modelSelected(model)); if (result.model) {
dispatch(modelSelected(result.model));
newModel = result.model;
}
} }
if (isParameterCFGScale(cfg_scale)) { if (isParameterCFGScale(cfg_scale)) {
@ -786,11 +846,14 @@ export const useRecallParameters = () => {
if (isParameterScheduler(scheduler)) { if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler)); dispatch(setScheduler(scheduler));
} }
if (isParameterVAEModel(vae) || isNil(vae)) { if (isModelIdentifier(vae) || isNil(vae)) {
if (isNil(vae)) { if (isNil(vae)) {
dispatch(vaeSelected(null)); dispatch(vaeSelected(null));
} else { } else {
dispatch(vaeSelected(vae)); const result = prepareVAEMetadataItem(vae, newModel);
if (result.vae) {
dispatch(vaeSelected(result.vae));
}
} }
} }
@ -898,6 +961,8 @@ export const useRecallParameters = () => {
dispatch, dispatch,
allParameterSetToast, allParameterSetToast,
allParameterNotSetToast, allParameterNotSetToast,
prepareMainModelMetadataItem,
prepareVAEMetadataItem,
prepareLoRAMetadataItem, prepareLoRAMetadataItem,
prepareControlNetMetadataItem, prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem, prepareIPAdapterMetadataItem,