fix(ui): handle new model format for metadata

This commit is contained in:
psychedelicious 2024-02-21 19:42:49 +11:00 committed by Brandon Rising
parent 1ced80d492
commit 64c1ce895c
4 changed files with 178 additions and 77 deletions

View File

@ -1,3 +1,4 @@
import { isModelIdentifier } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
CoreMetadata,
@ -6,15 +7,10 @@ import type {
T2IAdapterMetadataItem,
} from 'features/nodes/types/metadata';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import {
isParameterControlNetModel,
isParameterLoRAModel,
isParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import ImageMetadataItem from './ImageMetadataItem';
import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
type Props = {
metadata?: CoreMetadata;
@ -147,19 +143,19 @@ const ImageMetadataActions = (props: Props) => {
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) => isParameterControlNetModel(controlnet.control_model))
? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
: [];
}, [metadata?.controlnets]);
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters
? metadata.ipAdapters.filter((ipAdapter) => isParameterControlNetModel(ipAdapter.ip_adapter_model))
? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
: [];
}, [metadata?.ipAdapters]);
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters
? metadata.t2iAdapters.filter((t2iAdapter) => isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model))
? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
: [];
}, [metadata?.t2iAdapters]);
@ -209,7 +205,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
)}
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.key} onClick={handleRecallModel} />
<ModelMetadataItem label={t('metadata.model')} modelKey={metadata.model.key} onClick={handleRecallModel} />
)}
{metadata.width && (
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
@ -220,11 +216,7 @@ const ImageMetadataActions = (props: Props) => {
{metadata.scheduler && (
<ImageMetadataItem label={t('metadata.scheduler')} value={metadata.scheduler} onClick={handleRecallScheduler} />
)}
<ImageMetadataItem
label={t('metadata.vae')}
value={metadata.vae?.key ?? 'Default'}
onClick={handleRecallVaeModel}
/>
<VAEMetadataItem label={t('metadata.vae')} modelKey={metadata.vae?.key} onClick={handleRecallVaeModel} />
{metadata.steps && (
<ImageMetadataItem label={t('metadata.steps')} value={metadata.steps} onClick={handleRecallSteps} />
)}
@ -264,38 +256,42 @@ const ImageMetadataActions = (props: Props) => {
)}
{metadata.loras &&
metadata.loras.map((lora, index) => {
if (isParameterLoRAModel(lora.lora)) {
if (isModelIdentifier(lora.lora)) {
return (
<ImageMetadataItem
<ModelMetadataItem
key={index}
label="LoRA"
value={`${lora.lora.key} - ${lora.weight}`}
modelKey={lora.lora.key}
extra={` - ${lora.weight}`}
onClick={handleRecallLoRA.bind(null, lora)}
/>
);
}
})}
{validControlNets.map((controlnet, index) => (
<ImageMetadataItem
<ModelMetadataItem
key={index}
label="ControlNet"
value={`${controlnet.control_model?.key} - ${controlnet.control_weight}`}
modelKey={controlnet.control_model?.key}
extra={` - ${controlnet.control_weight}`}
onClick={handleRecallControlNet.bind(null, controlnet)}
/>
))}
{validIPAdapters.map((ipAdapter, index) => (
<ImageMetadataItem
<ModelMetadataItem
key={index}
label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.key} - ${ipAdapter.weight}`}
modelKey={ipAdapter.ip_adapter_model?.key}
extra={` - ${ipAdapter.weight}`}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
/>
))}
{validT2IAdapters.map((t2iAdapter, index) => (
<ImageMetadataItem
<ModelMetadataItem
key={index}
label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.key} - ${t2iAdapter.weight}`}
modelKey={t2iAdapter.t2i_adapter_model?.key}
extra={` - ${t2iAdapter.weight}`}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
/>
))}

View File

@ -1,8 +1,10 @@
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { memo, useCallback } from 'react';
import { skipToken } from '@reduxjs/toolkit/query';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { PiCopyBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type MetadataItemProps = {
isLink?: boolean;
@ -18,8 +20,9 @@ type MetadataItemProps = {
*/
const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => {
const { t } = useTranslation();
const handleCopy = useCallback(() => navigator.clipboard.writeText(value.toString()), [value]);
const handleCopy = useCallback(() => {
navigator.clipboard.writeText(value?.toString());
}, [value]);
if (!value) {
return null;
@ -68,3 +71,40 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
};
export default memo(ImageMetadataItem);
type VAEMetadataItemProps = {
label: string;
modelKey?: string;
onClick: () => void;
};
export const VAEMetadataItem = memo(({ label, modelKey, onClick }: VAEMetadataItemProps) => {
const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
return (
<ImageMetadataItem label={label} value={modelKey ? modelConfig?.name ?? modelKey : 'Default'} onClick={onClick} />
);
});
VAEMetadataItem.displayName = 'VAEMetadataItem';
type ModelMetadataItemProps = {
label: string;
modelKey?: string;
extra?: string;
onClick: () => void;
};
export const ModelMetadataItem = memo(({ label, modelKey, extra, onClick }: ModelMetadataItemProps) => {
const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
const value = useMemo(() => {
if (modelConfig) {
return `${modelConfig.name}${extra ?? ''}`;
}
return `${modelKey}${extra ?? ''}`;
}, [extra, modelConfig, modelKey]);
return <ImageMetadataItem label={label} value={value} onClick={onClick} />;
});
ModelMetadataItem.displayName = 'ModelMetadataItem';

View File

@ -3,8 +3,8 @@ import { z } from 'zod';
import {
zControlField,
zIPAdapterField,
zLoRAModelField,
zMainModelField,
zModelFieldBase,
zSDXLRefinerModelField,
zT2IAdapterField,
zVAEModelField,
@ -15,7 +15,7 @@ import {
// - https://github.com/colinhacks/zod/issues/2106
// - https://github.com/colinhacks/zod/issues/2854
export const zLoRAMetadataItem = z.object({
lora: zLoRAModelField.deepPartial(),
lora: zModelFieldBase.deepPartial(),
weight: z.number(),
});
const zControlNetMetadataItem = zControlField.deepPartial();

View File

@ -11,6 +11,8 @@ import {
} from 'features/controlAdapters/util/buildControlAdapter';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
import type { ModelIdentifier } from 'features/nodes/types/common';
import { isModelIdentifier } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
CoreMetadata,
@ -37,13 +39,9 @@ import type { ParameterModel } from 'features/parameters/types/parameterSchemas'
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterControlNetModel,
isParameterHeight,
isParameterHRFEnabled,
isParameterHRFMethod,
isParameterIPAdapterModel,
isParameterLoRAModel,
isParameterModel,
isParameterNegativePrompt,
isParameterNegativeStylePromptSDXL,
isParameterPositivePrompt,
@ -56,7 +54,6 @@ import {
isParameterSeed,
isParameterSteps,
isParameterStrength,
isParameterVAEModel,
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
@ -73,15 +70,20 @@ import {
import { isNil } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
loraModelsAdapterSelectors,
mainModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetVaeModelsQuery,
vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models';
import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@ -278,21 +280,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall model with toast
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isParameterModel(model)) {
parameterNotSetToast();
return;
}
dispatch(modelSelected(model));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall scheduler with toast
*/
@ -308,25 +295,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall vae model
*/
const recallVaeModel = useCallback(
(vae: unknown) => {
if (!isParameterVAEModel(vae) && !isNil(vae)) {
parameterNotSetToast();
return;
}
if (isNil(vae)) {
dispatch(vaeSelected(null));
} else {
dispatch(vaeSelected(vae));
}
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall steps with toast
*/
@ -452,6 +420,95 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const prepareMainModelMetadataItem = useCallback(
(model: ModelIdentifier) => {
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
if (!matchingModel) {
return { model: null, error: 'Model is not installed' };
}
return { model: matchingModel, error: null };
},
[mainModels]
);
/**
* Recall model with toast
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isModelIdentifier(model)) {
parameterNotSetToast();
return;
}
const result = prepareMainModelMetadataItem(model);
if (!result.model) {
parameterNotSetToast(result.error);
return;
}
dispatch(modelSelected(result.model));
parameterSetToast();
},
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
const { data: vaeModels } = useGetVaeModelsQuery();
const prepareVAEMetadataItem = useCallback(
(vae: ModelIdentifier, newModel?: ParameterModel) => {
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
if (!matchingModel) {
return { vae: null, error: 'VAE model is not installed' };
}
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
vae: null,
error: 'VAE incompatible with currently-selected model',
};
}
return { vae: matchingModel, error: null };
},
[model, vaeModels]
);
/**
* Recall vae model
*/
const recallVaeModel = useCallback(
(vae: unknown) => {
if (!isModelIdentifier(vae) && !isNil(vae)) {
parameterNotSetToast();
return;
}
if (isNil(vae)) {
dispatch(vaeSelected(null));
parameterSetToast();
return;
}
const result = prepareVAEMetadataItem(vae);
if (!result.vae) {
parameterNotSetToast(result.error);
return;
}
dispatch(vaeSelected(result.vae));
parameterSetToast();
},
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall LoRA with toast
*/
@ -460,7 +517,7 @@ export const useRecallParameters = () => {
const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
if (!isParameterLoRAModel(loraMetadataItem.lora)) {
if (!isModelIdentifier(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' };
}
@ -510,7 +567,7 @@ export const useRecallParameters = () => {
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
@ -584,7 +641,7 @@ export const useRecallParameters = () => {
const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) {
if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
@ -657,7 +714,7 @@ export const useRecallParameters = () => {
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
}
@ -762,9 +819,12 @@ export const useRecallParameters = () => {
let newModel: ParameterModel | undefined = undefined;
if (isParameterModel(model)) {
newModel = model;
dispatch(modelSelected(model));
if (isModelIdentifier(model)) {
const result = prepareMainModelMetadataItem(model);
if (result.model) {
dispatch(modelSelected(result.model));
newModel = result.model;
}
}
if (isParameterCFGScale(cfg_scale)) {
@ -786,11 +846,14 @@ export const useRecallParameters = () => {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
if (isParameterVAEModel(vae) || isNil(vae)) {
if (isModelIdentifier(vae) || isNil(vae)) {
if (isNil(vae)) {
dispatch(vaeSelected(null));
} else {
dispatch(vaeSelected(vae));
const result = prepareVAEMetadataItem(vae, newModel);
if (result.vae) {
dispatch(vaeSelected(result.vae));
}
}
}
@ -898,6 +961,8 @@ export const useRecallParameters = () => {
dispatch,
allParameterSetToast,
allParameterNotSetToast,
prepareMainModelMetadataItem,
prepareVAEMetadataItem,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,