diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 5907ba0700..7eec7e1875 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,3 +1,4 @@ +import { isModelIdentifier } from 'features/nodes/types/common'; import type { ControlNetMetadataItem, CoreMetadata, @@ -6,15 +7,10 @@ import type { T2IAdapterMetadataItem, } from 'features/nodes/types/metadata'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { - isParameterControlNetModel, - isParameterLoRAModel, - isParameterT2IAdapterModel, -} from 'features/parameters/types/parameterSchemas'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import ImageMetadataItem from './ImageMetadataItem'; +import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem'; type Props = { metadata?: CoreMetadata; @@ -147,19 +143,19 @@ const ImageMetadataActions = (props: Props) => { const validControlNets: ControlNetMetadataItem[] = useMemo(() => { return metadata?.controlnets - ? metadata.controlnets.filter((controlnet) => isParameterControlNetModel(controlnet.control_model)) + ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model)) : []; }, [metadata?.controlnets]); const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { return metadata?.ipAdapters - ? metadata.ipAdapters.filter((ipAdapter) => isParameterControlNetModel(ipAdapter.ip_adapter_model)) + ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model)) : []; }, [metadata?.ipAdapters]); const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { return metadata?.t2iAdapters - ? metadata.t2iAdapters.filter((t2iAdapter) => isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)) + ? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model)) : []; }, [metadata?.t2iAdapters]); @@ -209,7 +205,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( - + )} {metadata.width && ( @@ -220,11 +216,7 @@ const ImageMetadataActions = (props: Props) => { {metadata.scheduler && ( )} - + {metadata.steps && ( )} @@ -264,38 +256,42 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.loras && metadata.loras.map((lora, index) => { - if (isParameterLoRAModel(lora.lora)) { + if (isModelIdentifier(lora.lora)) { return ( - ); } })} {validControlNets.map((controlnet, index) => ( - ))} {validIPAdapters.map((ipAdapter, index) => ( - ))} {validT2IAdapters.map((t2iAdapter, index) => ( - ))} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx index c6dbd16269..7d17a2ad3d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx @@ -1,8 +1,10 @@ import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; -import { memo, useCallback } from 'react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { PiCopyBold } from 'react-icons/pi'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; type MetadataItemProps = { isLink?: boolean; @@ -18,8 +20,9 @@ type MetadataItemProps = { */ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => { const { t } = useTranslation(); - - const handleCopy = useCallback(() => navigator.clipboard.writeText(value.toString()), [value]); + const handleCopy = useCallback(() => { + navigator.clipboard.writeText(value?.toString()); + }, [value]); if (!value) { return null; @@ -68,3 +71,40 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC }; export default memo(ImageMetadataItem); + +type VAEMetadataItemProps = { + label: string; + modelKey?: string; + onClick: () => void; +}; + +export const VAEMetadataItem = memo(({ label, modelKey, onClick }: VAEMetadataItemProps) => { + const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken); + + return ( + + ); +}); + +VAEMetadataItem.displayName = 'VAEMetadataItem'; + +type ModelMetadataItemProps = { + label: string; + modelKey?: string; + + extra?: string; + onClick: () => void; +}; + +export const ModelMetadataItem = memo(({ label, modelKey, extra, onClick }: ModelMetadataItemProps) => { + const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken); + const value = useMemo(() => { + if (modelConfig) { + return `${modelConfig.name}${extra ?? ''}`; + } + return `${modelKey}${extra ?? ''}`; + }, [extra, modelConfig, modelKey]); + return ; +}); + +ModelMetadataItem.displayName = 'ModelMetadataItem'; diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts index 0cc30499e3..493a0464b3 100644 --- a/invokeai/frontend/web/src/features/nodes/types/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts @@ -3,8 +3,8 @@ import { z } from 'zod'; import { zControlField, zIPAdapterField, - zLoRAModelField, zMainModelField, + zModelFieldBase, zSDXLRefinerModelField, zT2IAdapterField, zVAEModelField, @@ -15,7 +15,7 @@ import { // - https://github.com/colinhacks/zod/issues/2106 // - https://github.com/colinhacks/zod/issues/2854 export const zLoRAMetadataItem = z.object({ - lora: zLoRAModelField.deepPartial(), + lora: zModelFieldBase.deepPartial(), weight: z.number(), }); const zControlNetMetadataItem = zControlField.deepPartial(); diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index c8b17816bb..0d464cd9b9 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -11,6 +11,8 @@ import { } from 'features/controlAdapters/util/buildControlAdapter'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; +import type { ModelIdentifier } from 'features/nodes/types/common'; +import { isModelIdentifier } from 'features/nodes/types/common'; import type { ControlNetMetadataItem, CoreMetadata, @@ -37,13 +39,9 @@ import type { ParameterModel } from 'features/parameters/types/parameterSchemas' import { isParameterCFGRescaleMultiplier, isParameterCFGScale, - isParameterControlNetModel, isParameterHeight, isParameterHRFEnabled, isParameterHRFMethod, - isParameterIPAdapterModel, - isParameterLoRAModel, - isParameterModel, isParameterNegativePrompt, isParameterNegativeStylePromptSDXL, isParameterPositivePrompt, @@ -56,7 +54,6 @@ import { isParameterSeed, isParameterSteps, isParameterStrength, - isParameterVAEModel, isParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { @@ -73,15 +70,20 @@ import { import { isNil } from 'lodash-es'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { ALL_BASE_MODELS } from 'services/api/constants'; import { controlNetModelsAdapterSelectors, ipAdapterModelsAdapterSelectors, loraModelsAdapterSelectors, + mainModelsAdapterSelectors, t2iAdapterModelsAdapterSelectors, useGetControlNetModelsQuery, useGetIPAdapterModelsQuery, useGetLoRAModelsQuery, + useGetMainModelsQuery, useGetT2IAdapterModelsQuery, + useGetVaeModelsQuery, + vaeModelsAdapterSelectors, } from 'services/api/endpoints/models'; import type { ImageDTO } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; @@ -278,21 +280,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall model with toast - */ - const recallModel = useCallback( - (model: unknown) => { - if (!isParameterModel(model)) { - parameterNotSetToast(); - return; - } - dispatch(modelSelected(model)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - /** * Recall scheduler with toast */ @@ -308,25 +295,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall vae model - */ - const recallVaeModel = useCallback( - (vae: unknown) => { - if (!isParameterVAEModel(vae) && !isNil(vae)) { - parameterNotSetToast(); - return; - } - if (isNil(vae)) { - dispatch(vaeSelected(null)); - } else { - dispatch(vaeSelected(vae)); - } - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - /** * Recall steps with toast */ @@ -452,6 +420,95 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); + const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS); + + const prepareMainModelMetadataItem = useCallback( + (model: ModelIdentifier) => { + const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined; + + if (!matchingModel) { + return { model: null, error: 'Model is not installed' }; + } + + return { model: matchingModel, error: null }; + }, + [mainModels] + ); + + /** + * Recall model with toast + */ + const recallModel = useCallback( + (model: unknown) => { + if (!isModelIdentifier(model)) { + parameterNotSetToast(); + return; + } + + const result = prepareMainModelMetadataItem(model); + + if (!result.model) { + parameterNotSetToast(result.error); + return; + } + + dispatch(modelSelected(result.model)); + parameterSetToast(); + }, + [prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + ); + + const { data: vaeModels } = useGetVaeModelsQuery(); + + const prepareVAEMetadataItem = useCallback( + (vae: ModelIdentifier, newModel?: ParameterModel) => { + const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined; + if (!matchingModel) { + return { vae: null, error: 'VAE model is not installed' }; + } + const isCompatibleBaseModel = matchingModel?.base === (newModel ?? model)?.base; + + if (!isCompatibleBaseModel) { + return { + vae: null, + error: 'VAE incompatible with currently-selected model', + }; + } + + return { vae: matchingModel, error: null }; + }, + [model, vaeModels] + ); + + /** + * Recall vae model + */ + const recallVaeModel = useCallback( + (vae: unknown) => { + if (!isModelIdentifier(vae) && !isNil(vae)) { + parameterNotSetToast(); + return; + } + + if (isNil(vae)) { + dispatch(vaeSelected(null)); + parameterSetToast(); + return; + } + + const result = prepareVAEMetadataItem(vae); + + if (!result.vae) { + parameterNotSetToast(result.error); + return; + } + + dispatch(vaeSelected(result.vae)); + parameterSetToast(); + }, + [prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + ); + /** * Recall LoRA with toast */ @@ -460,7 +517,7 @@ export const useRecallParameters = () => { const prepareLoRAMetadataItem = useCallback( (loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => { - if (!isParameterLoRAModel(loraMetadataItem.lora)) { + if (!isModelIdentifier(loraMetadataItem.lora)) { return { lora: null, error: 'Invalid LoRA model' }; } @@ -510,7 +567,7 @@ export const useRecallParameters = () => { const prepareControlNetMetadataItem = useCallback( (controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => { - if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) { + if (!isModelIdentifier(controlnetMetadataItem.control_model)) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -584,7 +641,7 @@ export const useRecallParameters = () => { const prepareT2IAdapterMetadataItem = useCallback( (t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) { + if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -657,7 +714,7 @@ export const useRecallParameters = () => { const prepareIPAdapterMetadataItem = useCallback( (ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { + if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) { return { ipAdapter: null, error: 'Invalid IP Adapter model' }; } @@ -762,9 +819,12 @@ export const useRecallParameters = () => { let newModel: ParameterModel | undefined = undefined; - if (isParameterModel(model)) { - newModel = model; - dispatch(modelSelected(model)); + if (isModelIdentifier(model)) { + const result = prepareMainModelMetadataItem(model); + if (result.model) { + dispatch(modelSelected(result.model)); + newModel = result.model; + } } if (isParameterCFGScale(cfg_scale)) { @@ -786,11 +846,14 @@ export const useRecallParameters = () => { if (isParameterScheduler(scheduler)) { dispatch(setScheduler(scheduler)); } - if (isParameterVAEModel(vae) || isNil(vae)) { + if (isModelIdentifier(vae) || isNil(vae)) { if (isNil(vae)) { dispatch(vaeSelected(null)); } else { - dispatch(vaeSelected(vae)); + const result = prepareVAEMetadataItem(vae, newModel); + if (result.vae) { + dispatch(vaeSelected(result.vae)); + } } } @@ -898,6 +961,8 @@ export const useRecallParameters = () => { dispatch, allParameterSetToast, allParameterNotSetToast, + prepareMainModelMetadataItem, + prepareVAEMetadataItem, prepareLoRAMetadataItem, prepareControlNetMetadataItem, prepareIPAdapterMetadataItem,