mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): handle new model format for metadata
This commit is contained in:
parent
9d9b417432
commit
b59d23d608
@ -1,3 +1,4 @@
|
||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ControlNetMetadataItem,
|
||||
CoreMetadata,
|
||||
@ -6,15 +7,10 @@ import type {
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/metadata';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import {
|
||||
isParameterControlNetModel,
|
||||
isParameterLoRAModel,
|
||||
isParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
|
||||
|
||||
type Props = {
|
||||
metadata?: CoreMetadata;
|
||||
@ -147,19 +143,19 @@ const ImageMetadataActions = (props: Props) => {
|
||||
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) => isParameterControlNetModel(controlnet.control_model))
|
||||
? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
|
||||
: [];
|
||||
}, [metadata?.controlnets]);
|
||||
|
||||
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.ipAdapters
|
||||
? metadata.ipAdapters.filter((ipAdapter) => isParameterControlNetModel(ipAdapter.ip_adapter_model))
|
||||
? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
|
||||
: [];
|
||||
}, [metadata?.ipAdapters]);
|
||||
|
||||
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.t2iAdapters
|
||||
? metadata.t2iAdapters.filter((t2iAdapter) => isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model))
|
||||
? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
|
||||
: [];
|
||||
}, [metadata?.t2iAdapters]);
|
||||
|
||||
@ -209,7 +205,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
|
||||
)}
|
||||
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
|
||||
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.key} onClick={handleRecallModel} />
|
||||
<ModelMetadataItem label={t('metadata.model')} modelKey={metadata.model.key} onClick={handleRecallModel} />
|
||||
)}
|
||||
{metadata.width && (
|
||||
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
|
||||
@ -220,11 +216,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
{metadata.scheduler && (
|
||||
<ImageMetadataItem label={t('metadata.scheduler')} value={metadata.scheduler} onClick={handleRecallScheduler} />
|
||||
)}
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.vae')}
|
||||
value={metadata.vae?.key ?? 'Default'}
|
||||
onClick={handleRecallVaeModel}
|
||||
/>
|
||||
<VAEMetadataItem label={t('metadata.vae')} modelKey={metadata.vae?.key} onClick={handleRecallVaeModel} />
|
||||
{metadata.steps && (
|
||||
<ImageMetadataItem label={t('metadata.steps')} value={metadata.steps} onClick={handleRecallSteps} />
|
||||
)}
|
||||
@ -264,38 +256,42 @@ const ImageMetadataActions = (props: Props) => {
|
||||
)}
|
||||
{metadata.loras &&
|
||||
metadata.loras.map((lora, index) => {
|
||||
if (isParameterLoRAModel(lora.lora)) {
|
||||
if (isModelIdentifier(lora.lora)) {
|
||||
return (
|
||||
<ImageMetadataItem
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="LoRA"
|
||||
value={`${lora.lora.key} - ${lora.weight}`}
|
||||
modelKey={lora.lora.key}
|
||||
extra={` - ${lora.weight}`}
|
||||
onClick={handleRecallLoRA.bind(null, lora)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
})}
|
||||
{validControlNets.map((controlnet, index) => (
|
||||
<ImageMetadataItem
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="ControlNet"
|
||||
value={`${controlnet.control_model?.key} - ${controlnet.control_weight}`}
|
||||
modelKey={controlnet.control_model?.key}
|
||||
extra={` - ${controlnet.control_weight}`}
|
||||
onClick={handleRecallControlNet.bind(null, controlnet)}
|
||||
/>
|
||||
))}
|
||||
{validIPAdapters.map((ipAdapter, index) => (
|
||||
<ImageMetadataItem
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="IP Adapter"
|
||||
value={`${ipAdapter.ip_adapter_model?.key} - ${ipAdapter.weight}`}
|
||||
modelKey={ipAdapter.ip_adapter_model?.key}
|
||||
extra={` - ${ipAdapter.weight}`}
|
||||
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
|
||||
/>
|
||||
))}
|
||||
{validT2IAdapters.map((t2iAdapter, index) => (
|
||||
<ImageMetadataItem
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="T2I Adapter"
|
||||
value={`${t2iAdapter.t2i_adapter_model?.key} - ${t2iAdapter.weight}`}
|
||||
modelKey={t2iAdapter.t2i_adapter_model?.key}
|
||||
extra={` - ${t2iAdapter.weight}`}
|
||||
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
|
||||
/>
|
||||
))}
|
||||
|
@ -1,8 +1,10 @@
|
||||
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { PiCopyBold } from 'react-icons/pi';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
@ -18,8 +20,9 @@ type MetadataItemProps = {
|
||||
*/
|
||||
const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleCopy = useCallback(() => navigator.clipboard.writeText(value.toString()), [value]);
|
||||
const handleCopy = useCallback(() => {
|
||||
navigator.clipboard.writeText(value?.toString());
|
||||
}, [value]);
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
@ -68,3 +71,40 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
|
||||
};
|
||||
|
||||
export default memo(ImageMetadataItem);
|
||||
|
||||
type VAEMetadataItemProps = {
|
||||
label: string;
|
||||
modelKey?: string;
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
export const VAEMetadataItem = memo(({ label, modelKey, onClick }: VAEMetadataItemProps) => {
|
||||
const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
|
||||
|
||||
return (
|
||||
<ImageMetadataItem label={label} value={modelKey ? modelConfig?.name ?? modelKey : 'Default'} onClick={onClick} />
|
||||
);
|
||||
});
|
||||
|
||||
VAEMetadataItem.displayName = 'VAEMetadataItem';
|
||||
|
||||
type ModelMetadataItemProps = {
|
||||
label: string;
|
||||
modelKey?: string;
|
||||
|
||||
extra?: string;
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
export const ModelMetadataItem = memo(({ label, modelKey, extra, onClick }: ModelMetadataItemProps) => {
|
||||
const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
|
||||
const value = useMemo(() => {
|
||||
if (modelConfig) {
|
||||
return `${modelConfig.name}${extra ?? ''}`;
|
||||
}
|
||||
return `${modelKey}${extra ?? ''}`;
|
||||
}, [extra, modelConfig, modelKey]);
|
||||
return <ImageMetadataItem label={label} value={value} onClick={onClick} />;
|
||||
});
|
||||
|
||||
ModelMetadataItem.displayName = 'ModelMetadataItem';
|
||||
|
@ -3,8 +3,8 @@ import { z } from 'zod';
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zModelFieldBase,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterField,
|
||||
zVAEModelField,
|
||||
@ -15,7 +15,7 @@ import {
|
||||
// - https://github.com/colinhacks/zod/issues/2106
|
||||
// - https://github.com/colinhacks/zod/issues/2854
|
||||
export const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
lora: zModelFieldBase.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
|
@ -11,6 +11,8 @@ import {
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
||||
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
|
||||
import type { ModelIdentifier } from 'features/nodes/types/common';
|
||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ControlNetMetadataItem,
|
||||
CoreMetadata,
|
||||
@ -37,13 +39,9 @@ import type { ParameterModel } from 'features/parameters/types/parameterSchemas'
|
||||
import {
|
||||
isParameterCFGRescaleMultiplier,
|
||||
isParameterCFGScale,
|
||||
isParameterControlNetModel,
|
||||
isParameterHeight,
|
||||
isParameterHRFEnabled,
|
||||
isParameterHRFMethod,
|
||||
isParameterIPAdapterModel,
|
||||
isParameterLoRAModel,
|
||||
isParameterModel,
|
||||
isParameterNegativePrompt,
|
||||
isParameterNegativeStylePromptSDXL,
|
||||
isParameterPositivePrompt,
|
||||
@ -56,7 +54,6 @@ import {
|
||||
isParameterSeed,
|
||||
isParameterSteps,
|
||||
isParameterStrength,
|
||||
isParameterVAEModel,
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
@ -73,15 +70,20 @@ import {
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
controlNetModelsAdapterSelectors,
|
||||
ipAdapterModelsAdapterSelectors,
|
||||
loraModelsAdapterSelectors,
|
||||
mainModelsAdapterSelectors,
|
||||
t2iAdapterModelsAdapterSelectors,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
vaeModelsAdapterSelectors,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
@ -278,21 +280,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall model with toast
|
||||
*/
|
||||
const recallModel = useCallback(
|
||||
(model: unknown) => {
|
||||
if (!isParameterModel(model)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(model));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall scheduler with toast
|
||||
*/
|
||||
@ -308,25 +295,6 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall vae model
|
||||
*/
|
||||
const recallVaeModel = useCallback(
|
||||
(vae: unknown) => {
|
||||
if (!isParameterVAEModel(vae) && !isNil(vae)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
if (isNil(vae)) {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
dispatch(vaeSelected(vae));
|
||||
}
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall steps with toast
|
||||
*/
|
||||
@ -452,6 +420,95 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
const prepareMainModelMetadataItem = useCallback(
|
||||
(model: ModelIdentifier) => {
|
||||
const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
|
||||
|
||||
if (!matchingModel) {
|
||||
return { model: null, error: 'Model is not installed' };
|
||||
}
|
||||
|
||||
return { model: matchingModel, error: null };
|
||||
},
|
||||
[mainModels]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall model with toast
|
||||
*/
|
||||
const recallModel = useCallback(
|
||||
(model: unknown) => {
|
||||
if (!isModelIdentifier(model)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const result = prepareMainModelMetadataItem(model);
|
||||
|
||||
if (!result.model) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelSelected(result.model));
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
||||
const prepareVAEMetadataItem = useCallback(
|
||||
(vae: ModelIdentifier, newModel?: ParameterModel) => {
|
||||
const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
|
||||
if (!matchingModel) {
|
||||
return { vae: null, error: 'VAE model is not installed' };
|
||||
}
|
||||
const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
vae: null,
|
||||
error: 'VAE incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
return { vae: matchingModel, error: null };
|
||||
},
|
||||
[model, vaeModels]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall vae model
|
||||
*/
|
||||
const recallVaeModel = useCallback(
|
||||
(vae: unknown) => {
|
||||
if (!isModelIdentifier(vae) && !isNil(vae)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
if (isNil(vae)) {
|
||||
dispatch(vaeSelected(null));
|
||||
parameterSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const result = prepareVAEMetadataItem(vae);
|
||||
|
||||
if (!result.vae) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(result.vae));
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall LoRA with toast
|
||||
*/
|
||||
@ -460,7 +517,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
const prepareLoRAMetadataItem = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isParameterLoRAModel(loraMetadataItem.lora)) {
|
||||
if (!isModelIdentifier(loraMetadataItem.lora)) {
|
||||
return { lora: null, error: 'Invalid LoRA model' };
|
||||
}
|
||||
|
||||
@ -510,7 +567,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
const prepareControlNetMetadataItem = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
|
||||
if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
@ -584,7 +641,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
const prepareT2IAdapterMetadataItem = useCallback(
|
||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) {
|
||||
if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
@ -657,7 +714,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
const prepareIPAdapterMetadataItem = useCallback(
|
||||
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
|
||||
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
|
||||
if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
|
||||
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
||||
}
|
||||
|
||||
@ -762,9 +819,12 @@ export const useRecallParameters = () => {
|
||||
|
||||
let newModel: ParameterModel | undefined = undefined;
|
||||
|
||||
if (isParameterModel(model)) {
|
||||
newModel = model;
|
||||
dispatch(modelSelected(model));
|
||||
if (isModelIdentifier(model)) {
|
||||
const result = prepareMainModelMetadataItem(model);
|
||||
if (result.model) {
|
||||
dispatch(modelSelected(result.model));
|
||||
newModel = result.model;
|
||||
}
|
||||
}
|
||||
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
@ -786,11 +846,14 @@ export const useRecallParameters = () => {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
if (isParameterVAEModel(vae) || isNil(vae)) {
|
||||
if (isModelIdentifier(vae) || isNil(vae)) {
|
||||
if (isNil(vae)) {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
dispatch(vaeSelected(vae));
|
||||
const result = prepareVAEMetadataItem(vae, newModel);
|
||||
if (result.vae) {
|
||||
dispatch(vaeSelected(result.vae));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -898,6 +961,8 @@ export const useRecallParameters = () => {
|
||||
dispatch,
|
||||
allParameterSetToast,
|
||||
allParameterNotSetToast,
|
||||
prepareMainModelMetadataItem,
|
||||
prepareVAEMetadataItem,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareControlNetMetadataItem,
|
||||
prepareIPAdapterMetadataItem,
|
||||
|
Loading…
Reference in New Issue
Block a user