diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx
index 5907ba0700..7eec7e1875 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx
@@ -1,3 +1,4 @@
+import { isModelIdentifier } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
CoreMetadata,
@@ -6,15 +7,10 @@ import type {
T2IAdapterMetadataItem,
} from 'features/nodes/types/metadata';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
-import {
- isParameterControlNetModel,
- isParameterLoRAModel,
- isParameterT2IAdapterModel,
-} from 'features/parameters/types/parameterSchemas';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
-import ImageMetadataItem from './ImageMetadataItem';
+import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
type Props = {
metadata?: CoreMetadata;
@@ -147,19 +143,19 @@ const ImageMetadataActions = (props: Props) => {
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
- ? metadata.controlnets.filter((controlnet) => isParameterControlNetModel(controlnet.control_model))
+ ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
: [];
}, [metadata?.controlnets]);
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters
- ? metadata.ipAdapters.filter((ipAdapter) => isParameterControlNetModel(ipAdapter.ip_adapter_model))
+ ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
: [];
}, [metadata?.ipAdapters]);
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters
- ? metadata.t2iAdapters.filter((t2iAdapter) => isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model))
+ ? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
: [];
}, [metadata?.t2iAdapters]);
@@ -209,7 +205,7 @@ const ImageMetadataActions = (props: Props) => {
)}
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
-
+
)}
{metadata.width && (
@@ -220,11 +216,7 @@ const ImageMetadataActions = (props: Props) => {
{metadata.scheduler && (
)}
-
+
{metadata.steps && (
)}
@@ -264,38 +256,42 @@ const ImageMetadataActions = (props: Props) => {
)}
{metadata.loras &&
metadata.loras.map((lora, index) => {
- if (isParameterLoRAModel(lora.lora)) {
+ if (isModelIdentifier(lora.lora)) {
return (
-
);
}
})}
{validControlNets.map((controlnet, index) => (
-
))}
{validIPAdapters.map((ipAdapter, index) => (
-
))}
{validT2IAdapters.map((t2iAdapter, index) => (
-
))}
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx
index c6dbd16269..7d17a2ad3d 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx
@@ -1,8 +1,10 @@
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
-import { memo, useCallback } from 'react';
+import { skipToken } from '@reduxjs/toolkit/query';
+import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { PiCopyBold } from 'react-icons/pi';
+import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type MetadataItemProps = {
isLink?: boolean;
@@ -18,8 +20,9 @@ type MetadataItemProps = {
*/
const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => {
const { t } = useTranslation();
-
- const handleCopy = useCallback(() => navigator.clipboard.writeText(value.toString()), [value]);
+ const handleCopy = useCallback(() => {
+ navigator.clipboard.writeText(value?.toString());
+ }, [value]);
if (!value) {
return null;
@@ -68,3 +71,40 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
};
export default memo(ImageMetadataItem);
+
+type VAEMetadataItemProps = {
+ label: string;
+ modelKey?: string;
+ onClick: () => void;
+};
+
+export const VAEMetadataItem = memo(({ label, modelKey, onClick }: VAEMetadataItemProps) => {
+ const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
+
+ return (
+
+ );
+});
+
+VAEMetadataItem.displayName = 'VAEMetadataItem';
+
+type ModelMetadataItemProps = {
+ label: string;
+ modelKey?: string;
+
+ extra?: string;
+ onClick: () => void;
+};
+
+export const ModelMetadataItem = memo(({ label, modelKey, extra, onClick }: ModelMetadataItemProps) => {
+ const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken);
+ const value = useMemo(() => {
+ if (modelConfig) {
+ return `${modelConfig.name}${extra ?? ''}`;
+ }
+ return `${modelKey}${extra ?? ''}`;
+ }, [extra, modelConfig, modelKey]);
+ return ;
+});
+
+ModelMetadataItem.displayName = 'ModelMetadataItem';
diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts
index 0cc30499e3..493a0464b3 100644
--- a/invokeai/frontend/web/src/features/nodes/types/metadata.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts
@@ -3,8 +3,8 @@ import { z } from 'zod';
import {
zControlField,
zIPAdapterField,
- zLoRAModelField,
zMainModelField,
+ zModelFieldBase,
zSDXLRefinerModelField,
zT2IAdapterField,
zVAEModelField,
@@ -15,7 +15,7 @@ import {
// - https://github.com/colinhacks/zod/issues/2106
// - https://github.com/colinhacks/zod/issues/2854
export const zLoRAMetadataItem = z.object({
- lora: zLoRAModelField.deepPartial(),
+ lora: zModelFieldBase.deepPartial(),
weight: z.number(),
});
const zControlNetMetadataItem = zControlField.deepPartial();
diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
index c8b17816bb..0d464cd9b9 100644
--- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
+++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
@@ -11,6 +11,8 @@ import {
} from 'features/controlAdapters/util/buildControlAdapter';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
+import type { ModelIdentifier } from 'features/nodes/types/common';
+import { isModelIdentifier } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
CoreMetadata,
@@ -37,13 +39,9 @@ import type { ParameterModel } from 'features/parameters/types/parameterSchemas'
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
- isParameterControlNetModel,
isParameterHeight,
isParameterHRFEnabled,
isParameterHRFMethod,
- isParameterIPAdapterModel,
- isParameterLoRAModel,
- isParameterModel,
isParameterNegativePrompt,
isParameterNegativeStylePromptSDXL,
isParameterPositivePrompt,
@@ -56,7 +54,6 @@ import {
isParameterSeed,
isParameterSteps,
isParameterStrength,
- isParameterVAEModel,
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
@@ -73,15 +70,20 @@ import {
import { isNil } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
+import { ALL_BASE_MODELS } from 'services/api/constants';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
loraModelsAdapterSelectors,
+ mainModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
+ useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
+ useGetVaeModelsQuery,
+ vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models';
import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@@ -278,21 +280,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
- /**
- * Recall model with toast
- */
- const recallModel = useCallback(
- (model: unknown) => {
- if (!isParameterModel(model)) {
- parameterNotSetToast();
- return;
- }
- dispatch(modelSelected(model));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
/**
* Recall scheduler with toast
*/
@@ -308,25 +295,6 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
- /**
- * Recall vae model
- */
- const recallVaeModel = useCallback(
- (vae: unknown) => {
- if (!isParameterVAEModel(vae) && !isNil(vae)) {
- parameterNotSetToast();
- return;
- }
- if (isNil(vae)) {
- dispatch(vaeSelected(null));
- } else {
- dispatch(vaeSelected(vae));
- }
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
/**
* Recall steps with toast
*/
@@ -452,6 +420,95 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
+ const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
+
+ const prepareMainModelMetadataItem = useCallback(
+ (model: ModelIdentifier) => {
+ const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined;
+
+ if (!matchingModel) {
+ return { model: null, error: 'Model is not installed' };
+ }
+
+ return { model: matchingModel, error: null };
+ },
+ [mainModels]
+ );
+
+ /**
+ * Recall model with toast
+ */
+ const recallModel = useCallback(
+ (model: unknown) => {
+ if (!isModelIdentifier(model)) {
+ parameterNotSetToast();
+ return;
+ }
+
+ const result = prepareMainModelMetadataItem(model);
+
+ if (!result.model) {
+ parameterNotSetToast(result.error);
+ return;
+ }
+
+ dispatch(modelSelected(result.model));
+ parameterSetToast();
+ },
+ [prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ const { data: vaeModels } = useGetVaeModelsQuery();
+
+ const prepareVAEMetadataItem = useCallback(
+ (vae: ModelIdentifier, newModel?: ParameterModel) => {
+ const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined;
+ if (!matchingModel) {
+ return { vae: null, error: 'VAE model is not installed' };
+ }
+ const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base;
+
+ if (!isCompatibleBaseModel) {
+ return {
+ vae: null,
+ error: 'VAE incompatible with currently-selected model',
+ };
+ }
+
+ return { vae: matchingModel, error: null };
+ },
+ [model, vaeModels]
+ );
+
+ /**
+ * Recall vae model
+ */
+ const recallVaeModel = useCallback(
+ (vae: unknown) => {
+ if (!isModelIdentifier(vae) && !isNil(vae)) {
+ parameterNotSetToast();
+ return;
+ }
+
+ if (isNil(vae)) {
+ dispatch(vaeSelected(null));
+ parameterSetToast();
+ return;
+ }
+
+ const result = prepareVAEMetadataItem(vae);
+
+ if (!result.vae) {
+ parameterNotSetToast(result.error);
+ return;
+ }
+
+ dispatch(vaeSelected(result.vae));
+ parameterSetToast();
+ },
+ [prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
/**
* Recall LoRA with toast
*/
@@ -460,7 +517,7 @@ export const useRecallParameters = () => {
const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
- if (!isParameterLoRAModel(loraMetadataItem.lora)) {
+ if (!isModelIdentifier(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' };
}
@@ -510,7 +567,7 @@ export const useRecallParameters = () => {
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => {
- if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
+ if (!isModelIdentifier(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
@@ -584,7 +641,7 @@ export const useRecallParameters = () => {
const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => {
- if (!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) {
+ if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
@@ -657,7 +714,7 @@ export const useRecallParameters = () => {
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => {
- if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
+ if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
}
@@ -762,9 +819,12 @@ export const useRecallParameters = () => {
let newModel: ParameterModel | undefined = undefined;
- if (isParameterModel(model)) {
- newModel = model;
- dispatch(modelSelected(model));
+ if (isModelIdentifier(model)) {
+ const result = prepareMainModelMetadataItem(model);
+ if (result.model) {
+ dispatch(modelSelected(result.model));
+ newModel = result.model;
+ }
}
if (isParameterCFGScale(cfg_scale)) {
@@ -786,11 +846,14 @@ export const useRecallParameters = () => {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
- if (isParameterVAEModel(vae) || isNil(vae)) {
+ if (isModelIdentifier(vae) || isNil(vae)) {
if (isNil(vae)) {
dispatch(vaeSelected(null));
} else {
- dispatch(vaeSelected(vae));
+ const result = prepareVAEMetadataItem(vae, newModel);
+ if (result.vae) {
+ dispatch(vaeSelected(result.vae));
+ }
}
}
@@ -898,6 +961,8 @@ export const useRecallParameters = () => {
dispatch,
allParameterSetToast,
allParameterNotSetToast,
+ prepareMainModelMetadataItem,
+ prepareVAEMetadataItem,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,