diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index b31fa5d93c..12d06ee224 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -656,6 +656,7 @@
}
},
"metadata": {
+ "allPrompts": "All Prompts",
"cfgScale": "CFG scale",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"createdBy": "Created By",
@@ -664,6 +665,7 @@
"height": "Height",
"hiresFix": "High Resolution Optimization",
"imageDetails": "Image Details",
+ "imageDimensions": "Image Dimensions",
"initImage": "Initial image",
"metadata": "Metadata",
"model": "Model",
@@ -671,9 +673,11 @@
"noImageDetails": "No image details found",
"noMetaData": "No metadata found",
"noRecallParameters": "No parameters to recall found",
+ "parameterSet": "Parameter {{parameter}} set",
"perlin": "Perlin Noise",
"positivePrompt": "Positive Prompt",
"recallParameters": "Recall Parameters",
+ "recallParameter": "Recall {{label}}",
"scheduler": "Scheduler",
"seamless": "Seamless",
"seed": "Seed",
@@ -1381,8 +1385,8 @@
"nodesNotValidJSON": "Not a valid JSON",
"nodesSaved": "Nodes Saved",
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
- "parameterNotSet": "Parameter not set",
- "parameterSet": "Parameter set",
+ "parameterNotSet": "{{parameter}} not set",
+ "parameterSet": "{{parameter}} set",
"parametersFailed": "Problem loading parameters",
"parametersFailedDesc": "Unable to load init image.",
"parametersNotSet": "Parameters Not Set",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts
index 6419f840ec..07a1039bef 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts
@@ -1,6 +1,6 @@
-import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
+import { toast } from 'common/util/toast';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
@@ -8,11 +8,6 @@ import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..';
-const { toast } = createStandaloneToast({
- theme: theme,
- defaultOptions: TOAST_OPTIONS.defaultOptions,
-});
-
export const addBatchEnqueuedListener = () => {
// success
startAppListening({
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
index 69adfb5b67..fa5c962ff5 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
@@ -1,7 +1,8 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
-import { createStandaloneToast, ExternalLink, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
+import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { startAppListening } from 'app/store/middleware/listenerMiddleware';
+import { toast } from 'common/util/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
@@ -12,11 +13,6 @@ import {
const log = logger('images');
-const { toast } = createStandaloneToast({
- theme: theme,
- defaultOptions: TOAST_OPTIONS.defaultOptions,
-});
-
export const addBulkDownloadListeners = () => {
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
diff --git a/invokeai/frontend/web/src/common/util/toast.ts b/invokeai/frontend/web/src/common/util/toast.ts
new file mode 100644
index 0000000000..ac61a4a12d
--- /dev/null
+++ b/invokeai/frontend/web/src/common/util/toast.ts
@@ -0,0 +1,6 @@
+import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
+
+export const { toast } = createStandaloneToast({
+ theme: theme,
+ defaultOptions: TOAST_OPTIONS.defaultOptions,
+});
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx
index 7eec7e1875..cfa679c805 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx
@@ -1,300 +1,54 @@
-import { isModelIdentifier } from 'features/nodes/types/common';
-import type {
- ControlNetMetadataItem,
- CoreMetadata,
- IPAdapterMetadataItem,
- LoRAMetadataItem,
- T2IAdapterMetadataItem,
-} from 'features/nodes/types/metadata';
-import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
-import { memo, useCallback, useMemo } from 'react';
-import { useTranslation } from 'react-i18next';
-
-import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
+import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets';
+import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters';
+import { MetadataItem } from 'features/metadata/components/MetadataItem';
+import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs';
+import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters';
+import { handlers } from 'features/metadata/util/handlers';
+import { memo } from 'react';
type Props = {
- metadata?: CoreMetadata;
+ metadata?: unknown;
};
const ImageMetadataActions = (props: Props) => {
const { metadata } = props;
- const { t } = useTranslation();
-
- const {
- recallPositivePrompt,
- recallNegativePrompt,
- recallSeed,
- recallCfgScale,
- recallCfgRescaleMultiplier,
- recallModel,
- recallScheduler,
- recallVaeModel,
- recallSteps,
- recallWidth,
- recallHeight,
- recallStrength,
- recallHrfEnabled,
- recallHrfStrength,
- recallHrfMethod,
- recallLoRA,
- recallControlNet,
- recallIPAdapter,
- recallT2IAdapter,
- recallSDXLPositiveStylePrompt,
- recallSDXLNegativeStylePrompt,
- } = useRecallParameters();
-
- const handleRecallPositivePrompt = useCallback(() => {
- recallPositivePrompt(metadata?.positive_prompt);
- }, [metadata?.positive_prompt, recallPositivePrompt]);
-
- const handleRecallNegativePrompt = useCallback(() => {
- recallNegativePrompt(metadata?.negative_prompt);
- }, [metadata?.negative_prompt, recallNegativePrompt]);
-
- const handleRecallSDXLPositiveStylePrompt = useCallback(() => {
- recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt);
- }, [metadata?.positive_style_prompt, recallSDXLPositiveStylePrompt]);
-
- const handleRecallSDXLNegativeStylePrompt = useCallback(() => {
- recallSDXLNegativeStylePrompt(metadata?.negative__style_prompt);
- }, [metadata?.negative__style_prompt, recallSDXLNegativeStylePrompt]);
-
- const handleRecallSeed = useCallback(() => {
- recallSeed(metadata?.seed);
- }, [metadata?.seed, recallSeed]);
-
- const handleRecallModel = useCallback(() => {
- recallModel(metadata?.model);
- }, [metadata?.model, recallModel]);
-
- const handleRecallWidth = useCallback(() => {
- recallWidth(metadata?.width);
- }, [metadata?.width, recallWidth]);
-
- const handleRecallHeight = useCallback(() => {
- recallHeight(metadata?.height);
- }, [metadata?.height, recallHeight]);
-
- const handleRecallScheduler = useCallback(() => {
- recallScheduler(metadata?.scheduler);
- }, [metadata?.scheduler, recallScheduler]);
-
- const handleRecallVaeModel = useCallback(() => {
- recallVaeModel(metadata?.vae);
- }, [metadata?.vae, recallVaeModel]);
-
- const handleRecallSteps = useCallback(() => {
- recallSteps(metadata?.steps);
- }, [metadata?.steps, recallSteps]);
-
- const handleRecallCfgScale = useCallback(() => {
- recallCfgScale(metadata?.cfg_scale);
- }, [metadata?.cfg_scale, recallCfgScale]);
-
- const handleRecallCfgRescaleMultiplier = useCallback(() => {
- recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
- }, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
-
- const handleRecallStrength = useCallback(() => {
- recallStrength(metadata?.strength);
- }, [metadata?.strength, recallStrength]);
-
- const handleRecallHrfEnabled = useCallback(() => {
- recallHrfEnabled(metadata?.hrf_enabled);
- }, [metadata?.hrf_enabled, recallHrfEnabled]);
-
- const handleRecallHrfStrength = useCallback(() => {
- recallHrfStrength(metadata?.hrf_strength);
- }, [metadata?.hrf_strength, recallHrfStrength]);
-
- const handleRecallHrfMethod = useCallback(() => {
- recallHrfMethod(metadata?.hrf_method);
- }, [metadata?.hrf_method, recallHrfMethod]);
-
- const handleRecallLoRA = useCallback(
- (lora: LoRAMetadataItem) => {
- recallLoRA(lora);
- },
- [recallLoRA]
- );
-
- const handleRecallControlNet = useCallback(
- (controlnet: ControlNetMetadataItem) => {
- recallControlNet(controlnet);
- },
- [recallControlNet]
- );
-
- const handleRecallIPAdapter = useCallback(
- (ipAdapter: IPAdapterMetadataItem) => {
- recallIPAdapter(ipAdapter);
- },
- [recallIPAdapter]
- );
-
- const handleRecallT2IAdapter = useCallback(
- (ipAdapter: T2IAdapterMetadataItem) => {
- recallT2IAdapter(ipAdapter);
- },
- [recallT2IAdapter]
- );
-
- const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
- return metadata?.controlnets
- ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
- : [];
- }, [metadata?.controlnets]);
-
- const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
- return metadata?.ipAdapters
- ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
- : [];
- }, [metadata?.ipAdapters]);
-
- const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
- return metadata?.t2iAdapters
- ? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
- : [];
- }, [metadata?.t2iAdapters]);
-
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
return (
<>
- {metadata.created_by && }
- {metadata.generation_mode && (
-
- )}
- {metadata.positive_prompt && (
-
- )}
- {metadata.negative_prompt && (
-
- )}
- {metadata.positive_style_prompt && (
-
- )}
- {metadata.negative_style_prompt && (
-
- )}
- {metadata.seed !== undefined && metadata.seed !== null && (
-
- )}
- {metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
-
- )}
- {metadata.width && (
-
- )}
- {metadata.height && (
-
- )}
- {metadata.scheduler && (
-
- )}
-
- {metadata.steps && (
-
- )}
- {metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
-
- )}
- {metadata.cfg_rescale_multiplier !== undefined && metadata.cfg_rescale_multiplier !== null && (
-
- )}
- {metadata.strength && (
-
- )}
- {metadata.hrf_enabled && (
-
- )}
- {metadata.hrf_enabled && metadata.hrf_strength && (
-
- )}
- {metadata.hrf_enabled && metadata.hrf_method && (
-
- )}
- {metadata.loras &&
- metadata.loras.map((lora, index) => {
- if (isModelIdentifier(lora.lora)) {
- return (
-
- );
- }
- })}
- {validControlNets.map((controlnet, index) => (
-
- ))}
- {validIPAdapters.map((ipAdapter, index) => (
-
- ))}
- {validT2IAdapters.map((t2iAdapter, index) => (
-
- ))}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
>
);
};
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx
index ce25b54e59..146c1437d1 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx
@@ -1,28 +1,39 @@
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
+import { get } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
-import { IoArrowUndoCircleOutline } from 'react-icons/io5';
-import { PiCopyBold } from 'react-icons/pi';
+import { PiArrowBendUpLeftBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type MetadataItemProps = {
isLink?: boolean;
label: string;
- onClick?: () => void;
- value: number | string | boolean;
+ metadata: unknown;
+ propertyName: string;
+ onRecall?: (value: unknown) => void;
labelPosition?: string;
- withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
-const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => {
+const ImageMetadataItem = ({
+ label,
+ metadata,
+ propertyName,
+ onRecall: _onRecall,
+ isLink,
+ labelPosition,
+}: MetadataItemProps) => {
const { t } = useTranslation();
- const handleCopy = useCallback(() => {
- navigator.clipboard.writeText(value?.toString());
- }, [value]);
+ const value = useMemo(() => get(metadata, propertyName), [metadata, propertyName]);
+ const onRecall = useCallback(() => {
+ if (!_onRecall) {
+ return;
+ }
+ _onRecall(value);
+ }, [_onRecall, value]);
if (!value) {
return null;
@@ -30,27 +41,15 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
return (
- {onClick && (
-
+ {_onRecall && (
+
}
+ aria-label={t('metadata.recallParameter', { parameter: label })}
+ icon={}
size="xs"
variant="ghost"
fontSize={20}
- onClick={onClick}
- />
-
- )}
- {withCopy && (
-
- }
- size="xs"
- variant="ghost"
- fontSize={14}
- onClick={handleCopy}
+ onClick={onRecall}
/>
)}
diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
index 579d45054b..ddcdd58e75 100644
--- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
+++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
@@ -24,29 +24,29 @@ type LoRACardProps = {
export const LoRACard = memo((props: LoRACardProps) => {
const { lora } = props;
const dispatch = useAppDispatch();
- const { data: loraConfig } = useGetModelConfigQuery(lora.key);
+ const { data: loraConfig } = useGetModelConfigQuery(lora.model.key);
const handleChange = useCallback(
(v: number) => {
- dispatch(loraWeightChanged({ key: lora.key, weight: v }));
+ dispatch(loraWeightChanged({ key: lora.model.key, weight: v }));
},
- [dispatch, lora.key]
+ [dispatch, lora.model.key]
);
const handleSetLoraToggle = useCallback(() => {
- dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled }));
- }, [dispatch, lora.key, lora.isEnabled]);
+ dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled }));
+ }, [dispatch, lora.model.key, lora.isEnabled]);
const handleRemoveLora = useCallback(() => {
- dispatch(loraRemoved(lora.key));
- }, [dispatch, lora.key]);
+ dispatch(loraRemoved(lora.model.key));
+ }, [dispatch, lora.model.key]);
return (
- {loraConfig?.name ?? lora.key.substring(0, 8)}
+ {loraConfig?.name ?? lora.model.key.substring(0, 8)}
diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
index 377406b3e5..32bef5fd9b 100644
--- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
+++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
@@ -2,9 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
+import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers';
import type { LoRAModelConfig } from 'services/api/types';
-export type LoRA = ParameterLoRAModel & {
+export type LoRA = {
+ model: ParameterLoRAModel;
weight: number;
isEnabled?: boolean;
};
@@ -29,11 +31,11 @@ export const loraSlice = createSlice({
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction) => {
- const { key, base } = action.payload;
- state.loras[key] = { key, base, ...defaultLoRAConfig };
+ const model = getModelKeyAndBase(action.payload);
+ state.loras[model.key] = { ...defaultLoRAConfig, model };
},
loraRecalled: (state, action: PayloadAction) => {
- state.loras[action.payload.key] = action.payload;
+ state.loras[action.payload.model.key] = action.payload;
},
loraRemoved: (state, action: PayloadAction) => {
const key = action.payload;
@@ -58,7 +60,7 @@ export const loraSlice = createSlice({
}
lora.weight = defaultLoRAConfig.weight;
},
- loraIsEnabledChanged: (state, action: PayloadAction>) => {
+ loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => {
const { key, isEnabled } = action.payload;
const lora = state.loras[key];
if (!lora) {
diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx
new file mode 100644
index 0000000000..2e8e3b6f9a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx
@@ -0,0 +1,66 @@
+import type { ControlNetConfig } from 'features/controlAdapters/store/types';
+import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
+import type { MetadataHandlers } from 'features/metadata/types';
+import { handlers } from 'features/metadata/util/handlers';
+import { useCallback, useEffect, useMemo, useState } from 'react';
+
+type Props = {
+ metadata: unknown;
+};
+
+export const MetadataControlNets = ({ metadata }: Props) => {
+ const [controlNets, setControlNets] = useState([]);
+
+ useEffect(() => {
+ const parse = async () => {
+ try {
+ const parsed = await handlers.controlNets.parse(metadata);
+ setControlNets(parsed);
+ } catch (e) {
+ setControlNets([]);
+ }
+ };
+ parse();
+ }, [metadata]);
+
+ const label = useMemo(() => handlers.controlNets.getLabel(), []);
+
+ return (
+ <>
+ {controlNets.map((controlNet) => (
+
+ ))}
+ >
+ );
+};
+
+const MetadataViewControlNet = ({
+ label,
+ controlNet,
+ handlers,
+}: {
+ label: string;
+ controlNet: ControlNetConfig;
+ handlers: MetadataHandlers;
+}) => {
+ const onRecall = useCallback(() => {
+ if (!handlers.recallItem) {
+ return;
+ }
+ handlers.recallItem(controlNet, true);
+ }, [handlers, controlNet]);
+
+ const renderedValue = useMemo(() => {
+ if (!handlers.renderItemValue) {
+ return null;
+ }
+ return handlers.renderItemValue(controlNet);
+ }, [handlers, controlNet]);
+
+ return ;
+};
diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx
new file mode 100644
index 0000000000..fef281cd09
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx
@@ -0,0 +1,66 @@
+import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
+import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
+import type { MetadataHandlers } from 'features/metadata/types';
+import { handlers } from 'features/metadata/util/handlers';
+import { useCallback, useEffect, useMemo, useState } from 'react';
+
+type Props = {
+ metadata: unknown;
+};
+
+export const MetadataIPAdapters = ({ metadata }: Props) => {
+ const [ipAdapters, setIPAdapters] = useState([]);
+
+ useEffect(() => {
+ const parse = async () => {
+ try {
+ const parsed = await handlers.ipAdapters.parse(metadata);
+ setIPAdapters(parsed);
+ } catch (e) {
+ setIPAdapters([]);
+ }
+ };
+ parse();
+ }, [metadata]);
+
+ const label = useMemo(() => handlers.ipAdapters.getLabel(), []);
+
+ return (
+ <>
+ {ipAdapters.map((ipAdapter) => (
+
+ ))}
+ >
+ );
+};
+
+const MetadataViewIPAdapter = ({
+ label,
+ ipAdapter,
+ handlers,
+}: {
+ label: string;
+ ipAdapter: IPAdapterConfig;
+ handlers: MetadataHandlers;
+}) => {
+ const onRecall = useCallback(() => {
+ if (!handlers.recallItem) {
+ return;
+ }
+ handlers.recallItem(ipAdapter, true);
+ }, [handlers, ipAdapter]);
+
+ const renderedValue = useMemo(() => {
+ if (!handlers.renderItemValue) {
+ return null;
+ }
+ return handlers.renderItemValue(ipAdapter);
+ }, [handlers, ipAdapter]);
+
+ return ;
+};
diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx
new file mode 100644
index 0000000000..66d101f458
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx
@@ -0,0 +1,33 @@
+import { typedMemo } from '@invoke-ai/ui-library';
+import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
+import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
+import type { MetadataHandlers } from 'features/metadata/types';
+import { MetadataParseFailedToken } from 'features/metadata/util/parsers';
+
+type MetadataItemProps = {
+ metadata: unknown;
+ handlers: MetadataHandlers;
+ direction?: 'row' | 'column';
+};
+
+const _MetadataItem = typedMemo(({ metadata, handlers, direction = 'row' }: MetadataItemProps) => {
+ const { label, isDisabled, value, renderedValue, onRecall } = useMetadataItem(metadata, handlers);
+
+ if (value === MetadataParseFailedToken) {
+ return null;
+ }
+
+ return (
+
+ );
+});
+
+export const MetadataItem = typedMemo(_MetadataItem);
+
+MetadataItem.displayName = 'MetadataItem';
diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx
new file mode 100644
index 0000000000..14ccc4ae2b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx
@@ -0,0 +1,29 @@
+import { Flex, Text } from '@invoke-ai/ui-library';
+import { RecallButton } from 'features/metadata/components/RecallButton';
+import { memo } from 'react';
+
+type MetadataItemViewProps = {
+ onRecall: () => void;
+ label: string;
+ renderedValue: React.ReactNode;
+ isDisabled: boolean;
+ direction?: 'row' | 'column';
+};
+
+export const MetadataItemView = memo(
+ ({ label, onRecall, isDisabled, renderedValue, direction = 'row' }: MetadataItemViewProps) => {
+ return (
+
+ {onRecall && }
+
+
+ {label}:
+
+ {renderedValue}
+
+
+ );
+ }
+);
+
+MetadataItemView.displayName = 'MetadataItemView';
diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx
new file mode 100644
index 0000000000..3225a048d4
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx
@@ -0,0 +1,61 @@
+import type { LoRA } from 'features/lora/store/loraSlice';
+import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
+import type { MetadataHandlers } from 'features/metadata/types';
+import { handlers } from 'features/metadata/util/handlers';
+import { useCallback, useEffect, useMemo, useState } from 'react';
+
+type Props = {
+ metadata: unknown;
+};
+
+export const MetadataLoRAs = ({ metadata }: Props) => {
+ const [loras, setLoRAs] = useState([]);
+
+ useEffect(() => {
+ const parse = async () => {
+ try {
+ const parsed = await handlers.loras.parse(metadata);
+ setLoRAs(parsed);
+ } catch (e) {
+ setLoRAs([]);
+ }
+ };
+ parse();
+ }, [metadata]);
+
+ const label = useMemo(() => handlers.loras.getLabel(), []);
+
+ return (
+ <>
+ {loras.map((lora) => (
+
+ ))}
+ >
+ );
+};
+
+const MetadataViewLoRA = ({
+ label,
+ lora,
+ handlers,
+}: {
+ label: string;
+ lora: LoRA;
+ handlers: MetadataHandlers;
+}) => {
+ const onRecall = useCallback(() => {
+ if (!handlers.recallItem) {
+ return;
+ }
+ handlers.recallItem(lora, true);
+ }, [handlers, lora]);
+
+ const renderedValue = useMemo(() => {
+ if (!handlers.renderItemValue) {
+ return null;
+ }
+ return handlers.renderItemValue(lora);
+ }, [handlers, lora]);
+
+ return ;
+};
diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx
new file mode 100644
index 0000000000..4a73fb3ccb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx
@@ -0,0 +1,66 @@
+import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
+import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
+import type { MetadataHandlers } from 'features/metadata/types';
+import { handlers } from 'features/metadata/util/handlers';
+import { useCallback, useEffect, useMemo, useState } from 'react';
+
+type Props = {
+ metadata: unknown;
+};
+
+export const MetadataT2IAdapters = ({ metadata }: Props) => {
+ const [t2iAdapters, setT2IAdapters] = useState([]);
+
+ useEffect(() => {
+ const parse = async () => {
+ try {
+ const parsed = await handlers.t2iAdapters.parse(metadata);
+ setT2IAdapters(parsed);
+ } catch (e) {
+ setT2IAdapters([]);
+ }
+ };
+ parse();
+ }, [metadata]);
+
+ const label = useMemo(() => handlers.t2iAdapters.getLabel(), []);
+
+ return (
+ <>
+ {t2iAdapters.map((t2iAdapter) => (
+
+ ))}
+ >
+ );
+};
+
+const MetadataViewT2IAdapter = ({
+ label,
+ t2iAdapter,
+ handlers,
+}: {
+ label: string;
+ t2iAdapter: T2IAdapterConfig;
+ handlers: MetadataHandlers;
+}) => {
+ const onRecall = useCallback(() => {
+ if (!handlers.recallItem) {
+ return;
+ }
+ handlers.recallItem(t2iAdapter, true);
+ }, [handlers, t2iAdapter]);
+
+ const renderedValue = useMemo(() => {
+ if (!handlers.renderItemValue) {
+ return null;
+ }
+ return handlers.renderItemValue(t2iAdapter);
+ }, [handlers, t2iAdapter]);
+
+ return ;
+};
diff --git a/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx b/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx
new file mode 100644
index 0000000000..8803945b06
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx
@@ -0,0 +1,27 @@
+import type { IconButtonProps} from '@invoke-ai/ui-library';
+import { IconButton, Tooltip } from '@invoke-ai/ui-library';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiArrowBendUpLeftBold } from 'react-icons/pi';
+
+type MetadataItemProps = Omit & {
+ label: string;
+};
+
+export const RecallButton = memo(({ label, ...rest }: MetadataItemProps) => {
+ const { t } = useTranslation();
+
+ return (
+
+ }
+ size="xs"
+ variant="ghost"
+ {...rest}
+ />
+
+ );
+});
+
+RecallButton.displayName = 'RecallButton';
diff --git a/invokeai/frontend/web/src/features/metadata/exceptions.ts b/invokeai/frontend/web/src/features/metadata/exceptions.ts
new file mode 100644
index 0000000000..ffe4e98c79
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/exceptions.ts
@@ -0,0 +1,27 @@
+/**
+ * Raised when metadata parsing fails.
+ */
+export class MetadataParseError extends Error {
+ /**
+ * Create MetadataParseError
+ * @param {String} message
+ */
+ constructor(message: string) {
+ super(message);
+ this.name = this.constructor.name;
+ }
+}
+
+/**
+ * Raised when metadata recall fails.
+ */
+export class MetadataRecallError extends Error {
+ /**
+ * Create MetadataRecallError
+ * @param {String} message
+ */
+ constructor(message: string) {
+ super(message);
+ this.name = this.constructor.name;
+ }
+}
diff --git a/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx
new file mode 100644
index 0000000000..178ac5155a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx
@@ -0,0 +1,51 @@
+import { Text } from '@invoke-ai/ui-library';
+import type { MetadataHandlers } from 'features/metadata/types';
+import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
+import { useCallback, useEffect, useMemo, useState } from 'react';
+
+export const useMetadataItem = (metadata: unknown, handlers: MetadataHandlers) => {
+ const [value, setValue] = useState(
+ MetadataParsePendingToken
+ );
+
+ useEffect(() => {
+ const _parse = async () => {
+ try {
+ const parsed = await handlers.parse(metadata);
+ setValue(parsed);
+ } catch (e) {
+ setValue(MetadataParseFailedToken);
+ }
+ };
+ _parse();
+ }, [handlers, metadata]);
+
+ const isDisabled = useMemo(() => value === MetadataParsePendingToken || value === MetadataParseFailedToken, [value]);
+
+ const label = useMemo(() => handlers.getLabel(), [handlers]);
+
+ const renderedValue = useMemo(() => {
+ if (value === MetadataParsePendingToken) {
+ return Loading;
+ }
+ if (value === MetadataParseFailedToken) {
+ return Parsing Failed;
+ }
+
+ const rendered = handlers.renderValue(value);
+
+ if (typeof rendered === 'string') {
+ return {rendered};
+ }
+ return rendered;
+ }, [handlers, value]);
+
+ const onRecall = useCallback(() => {
+ if (!handlers.recall || value === MetadataParsePendingToken || value === MetadataParseFailedToken) {
+ return null;
+ }
+ handlers.recall(value, true);
+ }, [handlers, value]);
+
+ return { label, isDisabled, value, renderedValue, onRecall };
+};
diff --git a/invokeai/frontend/web/src/features/metadata/types.ts b/invokeai/frontend/web/src/features/metadata/types.ts
new file mode 100644
index 0000000000..366f11701e
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/types.ts
@@ -0,0 +1,136 @@
+/**
+ * Renders a value of type T as a React node.
+ */
+export type MetadataRenderValueFunc = (value: T) => React.ReactNode;
+
+/**
+ * Gets the label of the current metadata item as a string.
+ */
+export type MetadataGetLabelFunc = () => string;
+
+export type MetadataParseOptions = {
+ toastOnFailure?: boolean;
+ toastOnSuccess?: boolean;
+};
+
+export type MetadataRecallOptions = MetadataParseOptions;
+
+/**
+ * A function that recalls a parsed and validated metadata value.
+ *
+ * @param value The value to recall.
+ * @throws MetadataRecallError if the value cannot be recalled.
+ */
+export type MetadataRecallFunc = (value: T) => void;
+
+/**
+ * An async function that receives metadata and returns a parsed value, throwing if the value is invalid or missing.
+ *
+ * The function receives an object of unknown type. It is responsible for extracting the relevant data from the metadata
+ * and returning a value of type T.
+ *
+ * The function should throw a MetadataParseError if the metadata is invalid or missing.
+ *
+ * @param metadata The metadata to parse.
+ * @returns A promise that resolves to the parsed value.
+ * @throws MetadataParseError if the metadata is invalid or missing.
+ */
+export type MetadataParseFunc = (metadata: unknown) => Promise;
+
+/**
+ * A function that performs additional validation logic before recalling a metadata value. It is called with a parsed
+ * value and should throw if the validation logic fails.
+ *
+ * This function is used in cases where some additional logic is required before recalling. For example, when recalling
+ * a LoRA, we need to check if it is compatible with the current base model.
+ *
+ * @param value The value to validate.
+ * @returns A promise that resolves to the validated value.
+ * @throws MetadataRecallError if the value is invalid.
+ */
+export type MetadataValidateFunc = (value: T) => Promise;
+
+export type MetadataHandlers = {
+ /**
+ * Gets the label of the current metadata item as a string.
+ *
+ * @returns The label of the current metadata item.
+ */
+ getLabel: MetadataGetLabelFunc;
+ /**
+ * An async function that receives metadata and returns a parsed metadata value.
+ *
+ * @param metadata The metadata to parse.
+ * @param withToast Whether to show a toast on success or failure.
+ * @returns A promise that resolves to the parsed value.
+ * @throws MetadataParseError if the metadata is invalid or missing.
+ */
+ parse: (metadata: unknown, withToast?: boolean) => Promise;
+ /**
+ * An async function that receives a metadata item and returns a parsed metadata item value.
+ *
+ * This is only provided if the metadata value is an array.
+ *
+ * @param item The item to parse. It should be an item from the array.
+ * @param withToast Whether to show a toast on success or failure.
+ * @returns A promise that resolves to the parsed value.
+ * @throws MetadataParseError if the metadata is invalid or missing.
+ */
+ parseItem?: (item: unknown, withToast?: boolean) => Promise;
+ /**
+ * An async function that recalls a parsed metadata value.
+ *
+ * This function is only provided if the metadata value can be recalled.
+ *
+ * @param value The value to recall.
+ * @param withToast Whether to show a toast on success or failure.
+ * @returns A promise that resolves when the recall operation is complete.
+ * @throws MetadataRecallError if the value cannot be recalled.
+ */
+ recall?: (value: TValue, withToast?: boolean) => Promise;
+ /**
+ * An async function that recalls a parsed metadata item value.
+ *
+ * This function is only provided if the metadata value is an array and the items can be recalled.
+ *
+ * @param item The item to recall. It should be an item from the array.
+ * @param withToast Whether to show a toast on success or failure.
+ * @returns A promise that resolves when the recall operation is complete.
+ * @throws MetadataRecallError if the value cannot be recalled.
+ */
+ recallItem?: (item: TItem, withToast?: boolean) => Promise;
+ /**
+ * Renders a parsed metadata value as a React node.
+ *
+ * @param value The value to render.
+ * @returns The rendered value.
+ */
+ renderValue: MetadataRenderValueFunc;
+ /**
+ * Renders a parsed metadata item value as a React node.
+ *
+ * @param item The item to render.
+ * @returns The rendered item.
+ */
+ renderItemValue?: MetadataRenderValueFunc;
+};
+
+// TODO(psyche): The types for item handlers should be able to be inferred from the type of the value:
+// type MetadataHandlersInferItem = TValue extends Array ? MetadataParseFunc : never
+// While this works for the types as expected, I couldn't satisfy TS in the implementations of the handlers.
+
+export type BuildMetadataHandlersArg = {
+ parser: MetadataParseFunc;
+ itemParser?: MetadataParseFunc;
+ recaller?: MetadataRecallFunc;
+ itemRecaller?: MetadataRecallFunc;
+ validator?: MetadataValidateFunc;
+ itemValidator?: MetadataValidateFunc;
+ getLabel: MetadataGetLabelFunc;
+ renderValue?: MetadataRenderValueFunc;
+ renderItemValue?: MetadataRenderValueFunc;
+};
+
+export type BuildMetadataHandlers = (
+ arg: BuildMetadataHandlersArg
+) => MetadataHandlers;
diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.tsx b/invokeai/frontend/web/src/features/metadata/util/handlers.tsx
new file mode 100644
index 0000000000..1f86257d42
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/util/handlers.tsx
@@ -0,0 +1,316 @@
+import { Text } from '@invoke-ai/ui-library';
+import { toast } from 'common/util/toast';
+import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
+import type { LoRA } from 'features/lora/store/loraSlice';
+import type {
+ BuildMetadataHandlers,
+ MetadataGetLabelFunc,
+ MetadataHandlers,
+ MetadataParseFunc,
+ MetadataRecallFunc,
+ MetadataRenderValueFunc,
+ MetadataValidateFunc,
+} from 'features/metadata/types';
+import { validators } from 'features/metadata/util/validators';
+import { t } from 'i18next';
+import type { AnyModelConfig } from 'services/api/types';
+
+import { parsers } from './parsers';
+import { recallers } from './recallers';
+
+const renderModelConfigValue: MetadataRenderValueFunc = (value) =>
+ `${value.name} (${value.base.toUpperCase()}, ${value.key})`;
+const renderLoRAValue: MetadataRenderValueFunc = (value) => {`${value.model.key} (${value.weight})`};
+const renderControlAdapterValue: MetadataRenderValueFunc = (
+ value
+) => {`${value.model?.key} (${value.weight})`};
+
+const parameterSetToast = (parameter: string, description?: string) => {
+ toast({
+ title: t('toast.parameterSet', { parameter }),
+ description,
+ status: 'info',
+ duration: 2500,
+ isClosable: true,
+ });
+};
+
+const parameterNotSetToast = (parameter: string, description?: string) => {
+ toast({
+ title: t('toast.parameterNotSet', { parameter }),
+ description,
+ status: 'warning',
+ duration: 2500,
+ isClosable: true,
+ });
+};
+
+// const allParameterSetToast = (description?: string) => {
+// toast({
+// title: t('toast.parametersSet'),
+// status: 'info',
+// description,
+// duration: 2500,
+// isClosable: true,
+// });
+// };
+
+// const allParameterNotSetToast = (description?: string) => {
+// toast({
+// title: t('toast.parametersNotSet'),
+// status: 'warning',
+// description,
+// duration: 2500,
+// isClosable: true,
+// });
+// };
+
+const buildParse =
+ (arg: {
+ parser: MetadataParseFunc;
+ getLabel: MetadataGetLabelFunc;
+ }): MetadataHandlers['parse'] =>
+ async (metadata, withToast = false) => {
+ try {
+ const parsed = await arg.parser(metadata);
+ withToast && parameterSetToast(arg.getLabel());
+ return parsed;
+ } catch (e) {
+ withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
+ throw e;
+ }
+ };
+
+const buildParseItem =
+ (arg: {
+ itemParser: MetadataParseFunc;
+ getLabel: MetadataGetLabelFunc;
+ }): MetadataHandlers['parseItem'] =>
+ async (item, withToast = false) => {
+ try {
+ const parsed = await arg.itemParser(item);
+ withToast && parameterSetToast(arg.getLabel());
+ return parsed;
+ } catch (e) {
+ withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
+ throw e;
+ }
+ };
+
+const buildRecall =
+ (arg: {
+ recaller: MetadataRecallFunc;
+ validator?: MetadataValidateFunc;
+ getLabel: MetadataGetLabelFunc;
+ }): NonNullable['recall']> =>
+ async (value, withToast = false) => {
+ try {
+ arg.validator && (await arg.validator(value));
+ await arg.recaller(value);
+ withToast && parameterSetToast(arg.getLabel());
+ } catch (e) {
+ withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
+ throw e;
+ }
+ };
+
+const buildRecallItem =
+ (arg: {
+ itemRecaller: MetadataRecallFunc;
+ itemValidator?: MetadataValidateFunc;
+ getLabel: MetadataGetLabelFunc;
+ }): NonNullable['recallItem']> =>
+ async (item, withToast = false) => {
+ try {
+ arg.itemValidator && (await arg.itemValidator(item));
+ await arg.itemRecaller(item);
+ withToast && parameterSetToast(arg.getLabel());
+ } catch (e) {
+ withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
+ throw e;
+ }
+ };
+
+const buildHandlers: BuildMetadataHandlers = ({
+ getLabel,
+ parser,
+ itemParser,
+ recaller,
+ itemRecaller,
+ validator,
+ itemValidator,
+ renderValue,
+ renderItemValue,
+}) => ({
+ parse: buildParse({ parser, getLabel }),
+ parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined,
+ recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
+ recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
+ getLabel,
+ renderValue: renderValue ?? String,
+ renderItemValue: renderItemValue ?? String,
+});
+
+export const handlers = {
+ // Misc
+ createdBy: buildHandlers({ getLabel: () => t('metadata.createdBy'), parser: parsers.createdBy }),
+ generationMode: buildHandlers({ getLabel: () => t('metadata.generationMode'), parser: parsers.generationMode }),
+
+ // Core parameters
+ cfgRescaleMultiplier: buildHandlers({
+ getLabel: () => t('metadata.cfgRescaleMultiplier'),
+ parser: parsers.cfgRescaleMultiplier,
+ recaller: recallers.cfgRescaleMultiplier,
+ }),
+ cfgScale: buildHandlers({
+ getLabel: () => t('metadata.cfgScale'),
+ parser: parsers.cfgScale,
+ recaller: recallers.cfgScale,
+ }),
+ height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
+ negativePrompt: buildHandlers({
+ getLabel: () => t('metadata.negativePrompt'),
+ parser: parsers.negativePrompt,
+ recaller: recallers.negativePrompt,
+ }),
+ positivePrompt: buildHandlers({
+ getLabel: () => t('metadata.positivePrompt'),
+ parser: parsers.positivePrompt,
+ recaller: recallers.positivePrompt,
+ }),
+ scheduler: buildHandlers({
+ getLabel: () => t('metadata.scheduler'),
+ parser: parsers.scheduler,
+ recaller: recallers.scheduler,
+ }),
+ sdxlNegativeStylePrompt: buildHandlers({
+ getLabel: () => t('sdxl.negStylePrompt'),
+ parser: parsers.sdxlNegativeStylePrompt,
+ recaller: recallers.sdxlNegativeStylePrompt,
+ }),
+ sdxlPositiveStylePrompt: buildHandlers({
+ getLabel: () => t('sdxl.posStylePrompt'),
+ parser: parsers.sdxlPositiveStylePrompt,
+ recaller: recallers.sdxlPositiveStylePrompt,
+ }),
+ seed: buildHandlers({ getLabel: () => t('metadata.seed'), parser: parsers.seed, recaller: recallers.seed }),
+ steps: buildHandlers({ getLabel: () => t('metadata.steps'), parser: parsers.steps, recaller: recallers.steps }),
+ strength: buildHandlers({
+ getLabel: () => t('metadata.strength'),
+ parser: parsers.strength,
+ recaller: recallers.strength,
+ }),
+ width: buildHandlers({ getLabel: () => t('metadata.width'), parser: parsers.width, recaller: recallers.width }),
+
+ // HRF
+ hrfEnabled: buildHandlers({
+ getLabel: () => t('hrf.metadata.enabled'),
+ parser: parsers.hrfEnabled,
+ recaller: recallers.hrfEnabled,
+ }),
+ hrfMethod: buildHandlers({
+ getLabel: () => t('hrf.metadata.method'),
+ parser: parsers.hrfMethod,
+ recaller: recallers.hrfMethod,
+ }),
+ hrfStrength: buildHandlers({
+ getLabel: () => t('hrf.metadata.strength'),
+ parser: parsers.hrfStrength,
+ recaller: recallers.hrfStrength,
+ }),
+
+ // Refiner
+ refinerCFGScale: buildHandlers({
+ getLabel: () => t('sdxl.cfgScale'),
+ parser: parsers.refinerCFGScale,
+ recaller: recallers.refinerCFGScale,
+ }),
+ refinerModel: buildHandlers({
+ getLabel: () => t('sdxl.refinerModel'),
+ parser: parsers.refinerModel,
+ recaller: recallers.refinerModel,
+ validator: validators.refinerModel,
+ }),
+ refinerNegativeAestheticScore: buildHandlers({
+ getLabel: () => t('sdxl.posAestheticScore'),
+ parser: parsers.refinerNegativeAestheticScore,
+ recaller: recallers.refinerNegativeAestheticScore,
+ }),
+ refinerPositiveAestheticScore: buildHandlers({
+ getLabel: () => t('sdxl.negAestheticScore'),
+ parser: parsers.refinerPositiveAestheticScore,
+ recaller: recallers.refinerPositiveAestheticScore,
+ }),
+ refinerScheduler: buildHandlers({
+ getLabel: () => t('sdxl.scheduler'),
+ parser: parsers.refinerScheduler,
+ recaller: recallers.refinerScheduler,
+ }),
+ refinerStart: buildHandlers({
+ getLabel: () => t('sdxl.refiner_start'),
+ parser: parsers.refinerStart,
+ recaller: recallers.refinerStart,
+ }),
+ refinerSteps: buildHandlers({
+ getLabel: () => t('sdxl.refiner_steps'),
+ parser: parsers.refinerSteps,
+ recaller: recallers.refinerSteps,
+ }),
+
+ // Models
+ model: buildHandlers({
+ getLabel: () => t('metadata.model'),
+ parser: parsers.mainModel,
+ recaller: recallers.model,
+ renderValue: renderModelConfigValue,
+ }),
+ vae: buildHandlers({
+ getLabel: () => t('metadata.vae'),
+ parser: parsers.vaeModel,
+ recaller: recallers.vae,
+ renderValue: renderModelConfigValue,
+ validator: validators.vaeModel,
+ }),
+
+ // Arrays of models
+ controlNets: buildHandlers({
+ getLabel: () => t('common.controlNet'),
+ parser: parsers.controlNets,
+ itemParser: parsers.controlNet,
+ recaller: recallers.controlNets,
+ itemRecaller: recallers.controlNet,
+ validator: validators.controlNets,
+ itemValidator: validators.controlNet,
+ renderItemValue: renderControlAdapterValue,
+ }),
+ ipAdapters: buildHandlers({
+ getLabel: () => t('common.ipAdapter'),
+ parser: parsers.ipAdapters,
+ itemParser: parsers.ipAdapter,
+ recaller: recallers.ipAdapters,
+ itemRecaller: recallers.ipAdapter,
+ validator: validators.ipAdapters,
+ itemValidator: validators.ipAdapter,
+ renderItemValue: renderControlAdapterValue,
+ }),
+ loras: buildHandlers({
+ getLabel: () => t('models.lora'),
+ parser: parsers.loras,
+ itemParser: parsers.lora,
+ recaller: recallers.loras,
+ itemRecaller: recallers.lora,
+ validator: validators.loras,
+ itemValidator: validators.lora,
+ renderItemValue: renderLoRAValue,
+ }),
+ t2iAdapters: buildHandlers({
+ getLabel: () => t('common.t2iAdapter'),
+ parser: parsers.t2iAdapters,
+ itemParser: parsers.t2iAdapter,
+ recaller: recallers.t2iAdapters,
+ itemRecaller: recallers.t2iAdapter,
+ validator: validators.t2iAdapters,
+ itemValidator: validators.t2iAdapter,
+ renderItemValue: renderControlAdapterValue,
+ }),
+} as const;
diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts
new file mode 100644
index 0000000000..c7f2a6d09f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts
@@ -0,0 +1,396 @@
+import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
+import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
+import {
+ initialControlNet,
+ initialIPAdapter,
+ initialT2IAdapter,
+} from 'features/controlAdapters/util/buildControlAdapter';
+import type { LoRA } from 'features/lora/store/loraSlice';
+import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
+import { MetadataParseError } from 'features/metadata/exceptions';
+import type { MetadataParseFunc } from 'features/metadata/types';
+import {
+ zControlField,
+ zIPAdapterField,
+ zModelIdentifierWithBase,
+ zT2IAdapterField,
+} from 'features/nodes/types/common';
+import type {
+ ParameterCFGRescaleMultiplier,
+ ParameterCFGScale,
+ ParameterHeight,
+ ParameterHRFEnabled,
+ ParameterHRFMethod,
+ ParameterNegativePrompt,
+ ParameterNegativeStylePromptSDXL,
+ ParameterPositivePrompt,
+ ParameterPositiveStylePromptSDXL,
+ ParameterScheduler,
+ ParameterSDXLRefinerNegativeAestheticScore,
+ ParameterSDXLRefinerPositiveAestheticScore,
+ ParameterSDXLRefinerStart,
+ ParameterSeed,
+ ParameterSteps,
+ ParameterStrength,
+ ParameterWidth,
+} from 'features/parameters/types/parameterSchemas';
+import {
+ isParameterCFGRescaleMultiplier,
+ isParameterCFGScale,
+ isParameterHeight,
+ isParameterHRFEnabled,
+ isParameterHRFMethod,
+ isParameterLoRAWeight,
+ isParameterNegativePrompt,
+ isParameterNegativeStylePromptSDXL,
+ isParameterPositivePrompt,
+ isParameterPositiveStylePromptSDXL,
+ isParameterScheduler,
+ isParameterSDXLRefinerNegativeAestheticScore,
+ isParameterSDXLRefinerPositiveAestheticScore,
+ isParameterSDXLRefinerStart,
+ isParameterSeed,
+ isParameterSteps,
+ isParameterStrength,
+ isParameterWidth,
+} from 'features/parameters/types/parameterSchemas';
+import {
+ fetchModelConfigWithTypeGuard,
+ getModelKey,
+ getModelKeyAndBase,
+} from 'features/parameters/util/modelFetchingHelpers';
+import { get, isArray, isString } from 'lodash-es';
+import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
+import {
+ isControlNetModelConfig,
+ isIPAdapterModelConfig,
+ isLoRAModelConfig,
+ isNonRefinerMainModelConfig,
+ isRefinerMainModelModelConfig,
+ isT2IAdapterModelConfig,
+ isVAEModelConfig,
+} from 'services/api/types';
+import { v4 as uuidv4 } from 'uuid';
+
+export const MetadataParsePendingToken = Symbol('pending');
+export const MetadataParseFailedToken = Symbol('failed');
+
+/**
+ * An async function that a property from an object and validates its type using a type guard. If the property is missing
+ * or invalid, the function should throw a MetadataParseError.
+ * @param obj The object to get the property from.
+ * @param property The property to get.
+ * @param typeGuard A type guard function to check the type of the property. Provide `undefined` to opt out of type
+ * validation and always return the property value.
+ * @returns A promise that resolves to the property value if it exists and is of the expected type.
+ * @throws MetadataParseError if a type guard is provided and the property is not of the expected type.
+ */
+const getProperty = (
+ obj: unknown,
+ property: string,
+ typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true
+): Promise => {
+ return new Promise((resolve, reject) => {
+ const val = get(obj, property) as unknown;
+ if (typeGuard(val)) {
+ resolve(val);
+ }
+ reject(new MetadataParseError(`Property ${property} is not of expected type`));
+ });
+};
+
+const parseCreatedBy: MetadataParseFunc = (metadata) => getProperty(metadata, 'created_by', isString);
+
+const parseGenerationMode: MetadataParseFunc = (metadata) => getProperty(metadata, 'generation_mode', isString);
+
+const parsePositivePrompt: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'positive_prompt', isParameterPositivePrompt);
+
+const parseNegativePrompt: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'negative_prompt', isParameterNegativePrompt);
+
+const parseSDXLPositiveStylePrompt: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL);
+
+const parseSDXLNegativeStylePrompt: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL);
+
+const parseSeed: MetadataParseFunc = (metadata) => getProperty(metadata, 'seed', isParameterSeed);
+
+const parseCFGScale: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'cfg_scale', isParameterCFGScale);
+
+const parseCFGRescaleMultiplier: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier);
+
+const parseScheduler: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'scheduler', isParameterScheduler);
+
+const parseWidth: MetadataParseFunc = (metadata) => getProperty(metadata, 'width', isParameterWidth);
+
+const parseHeight: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'height', isParameterHeight);
+
+const parseSteps: MetadataParseFunc = (metadata) => getProperty(metadata, 'steps', isParameterSteps);
+
+const parseStrength: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'strength', isParameterStrength);
+
+const parseHRFEnabled: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
+
+const parseHRFStrength: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'hrf_strength', isParameterStrength);
+
+const parseHRFMethod: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'hrf_method', isParameterHRFMethod);
+
+const parseRefinerSteps: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'refiner_steps', isParameterSteps);
+
+const parseRefinerCFGScale: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale);
+
+const parseRefinerScheduler: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'refiner_scheduler', isParameterScheduler);
+
+const parseRefinerPositiveAestheticScore: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'refiner_positive_aesthetic_score', isParameterSDXLRefinerPositiveAestheticScore);
+
+const parseRefinerNegativeAestheticScore: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'refiner_negative_aesthetic_score', isParameterSDXLRefinerNegativeAestheticScore);
+
+const parseRefinerStart: MetadataParseFunc = (metadata) =>
+ getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
+
+const parseMainModel: MetadataParseFunc = async (metadata) => {
+ const model = await getProperty(metadata, 'model', undefined);
+ const key = await getModelKey(model, 'main');
+ const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
+ return mainModelConfig;
+};
+
+const parseRefinerModel: MetadataParseFunc = async (metadata) => {
+ const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
+ const key = await getModelKey(refiner_model, 'main');
+ const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
+ return refinerModelConfig;
+};
+
+const parseVAEModel: MetadataParseFunc = async (metadata) => {
+ const vae = await getProperty(metadata, 'vae', undefined);
+ const key = await getModelKey(vae, 'vae');
+ const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
+ return vaeModelConfig;
+};
+
+const parseLoRA: MetadataParseFunc = async (metadataItem) => {
+ // Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
+ const modelV1 = await getProperty(metadataItem, 'lora', undefined);
+ // Now, the LoRA model is stored in a `model` property of the LoRA metadata: `{model: {key: ...}, weight: 0.75}`
+ const modelV2 = await getProperty(metadataItem, 'model', undefined);
+ const weight = await getProperty(metadataItem, 'weight', undefined);
+ const key = await getModelKey(modelV2 ?? modelV1, 'lora');
+ const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
+
+ return {
+ model: getModelKeyAndBase(loraModelConfig),
+ weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
+ isEnabled: true,
+ };
+};
+
+const parseAllLoRAs: MetadataParseFunc = async (metadata) => {
+ const lorasRaw = await getProperty(metadata, 'loras', isArray);
+ const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
+ const loras = parseResults
+ .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled')
+ .map((result) => result.value);
+ return loras;
+};
+
+const parseControlNet: MetadataParseFunc = async (metadataItem) => {
+ const control_model = await getProperty(metadataItem, 'control_model');
+ const key = await getModelKey(control_model, 'controlnet');
+ const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
+
+ const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
+ const control_weight = zControlField.shape.control_weight
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'control_weight'));
+ const begin_step_percent = zControlField.shape.begin_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'begin_step_percent'));
+ const end_step_percent = zControlField.shape.end_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'end_step_percent'));
+ const control_mode = zControlField.shape.control_mode
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'control_mode'));
+ const resize_mode = zControlField.shape.resize_mode
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'resize_mode'));
+
+ const processorType = 'none';
+ const processorNode = CONTROLNET_PROCESSORS.none.default;
+
+ const controlNet: ControlNetConfig = {
+ type: 'controlnet',
+ isEnabled: true,
+ model: zModelIdentifierWithBase.parse(controlNetModel),
+ weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
+ beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
+ endStepPct: end_step_percent ?? initialControlNet.endStepPct,
+ controlMode: control_mode ?? initialControlNet.controlMode,
+ resizeMode: resize_mode ?? initialControlNet.resizeMode,
+ controlImage: image?.image_name ?? null,
+ processedControlImage: image?.image_name ?? null,
+ processorType,
+ processorNode,
+ shouldAutoConfig: true,
+ id: uuidv4(),
+ };
+
+ return controlNet;
+};
+
+const parseAllControlNets: MetadataParseFunc = async (metadata) => {
+ const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
+ const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
+ const controlNets = parseResults
+ .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled')
+ .map((result) => result.value);
+ return controlNets;
+};
+
+const parseT2IAdapter: MetadataParseFunc = async (metadataItem) => {
+ const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
+ const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
+ const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
+
+ const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
+ const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
+ const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'begin_step_percent'));
+ const end_step_percent = zT2IAdapterField.shape.end_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'end_step_percent'));
+ const resize_mode = zT2IAdapterField.shape.resize_mode
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'resize_mode'));
+
+ const processorType = 'none';
+ const processorNode = CONTROLNET_PROCESSORS.none.default;
+
+ const t2iAdapter: T2IAdapterConfig = {
+ type: 't2i_adapter',
+ isEnabled: true,
+ model: zModelIdentifierWithBase.parse(t2iAdapterModel),
+ weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
+ beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
+ endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
+ resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
+ controlImage: image?.image_name ?? null,
+ processedControlImage: image?.image_name ?? null,
+ processorType,
+ processorNode,
+ shouldAutoConfig: true,
+ id: uuidv4(),
+ };
+
+ return t2iAdapter;
+};
+
+const parseAllT2IAdapters: MetadataParseFunc = async (metadata) => {
+ const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
+ const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
+ const t2iAdapters = parseResults
+ .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled')
+ .map((result) => result.value);
+ return t2iAdapters;
+};
+
+const parseIPAdapter: MetadataParseFunc = async (metadataItem) => {
+ const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
+ const key = await getModelKey(ip_adapter_model, 'ip_adapter');
+ const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
+
+ const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
+ const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
+ const begin_step_percent = zIPAdapterField.shape.begin_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'begin_step_percent'));
+ const end_step_percent = zIPAdapterField.shape.end_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'end_step_percent'));
+
+ const ipAdapter: IPAdapterConfig = {
+ id: uuidv4(),
+ type: 'ip_adapter',
+ isEnabled: true,
+ model: zModelIdentifierWithBase.parse(ipAdapterModel),
+ controlImage: image?.image_name ?? null,
+ weight: weight ?? initialIPAdapter.weight,
+ beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
+ endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
+ };
+
+ return ipAdapter;
+};
+
+const parseAllIPAdapters: MetadataParseFunc = async (metadata) => {
+ const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
+ const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
+ const ipAdapters = parseResults
+ .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled')
+ .map((result) => result.value);
+ return ipAdapters;
+};
+
+export const parsers = {
+ createdBy: parseCreatedBy,
+ generationMode: parseGenerationMode,
+ positivePrompt: parsePositivePrompt,
+ negativePrompt: parseNegativePrompt,
+ sdxlPositiveStylePrompt: parseSDXLPositiveStylePrompt,
+ sdxlNegativeStylePrompt: parseSDXLNegativeStylePrompt,
+ seed: parseSeed,
+ cfgScale: parseCFGScale,
+ cfgRescaleMultiplier: parseCFGRescaleMultiplier,
+ scheduler: parseScheduler,
+ width: parseWidth,
+ height: parseHeight,
+ steps: parseSteps,
+ strength: parseStrength,
+ hrfEnabled: parseHRFEnabled,
+ hrfStrength: parseHRFStrength,
+ hrfMethod: parseHRFMethod,
+ refinerSteps: parseRefinerSteps,
+ refinerCFGScale: parseRefinerCFGScale,
+ refinerScheduler: parseRefinerScheduler,
+ refinerPositiveAestheticScore: parseRefinerPositiveAestheticScore,
+ refinerNegativeAestheticScore: parseRefinerNegativeAestheticScore,
+ refinerStart: parseRefinerStart,
+ mainModel: parseMainModel,
+ refinerModel: parseRefinerModel,
+ vaeModel: parseVAEModel,
+ lora: parseLoRA,
+ loras: parseAllLoRAs,
+ controlNet: parseControlNet,
+ controlNets: parseAllControlNets,
+ t2iAdapter: parseT2IAdapter,
+ t2iAdapters: parseAllT2IAdapters,
+ ipAdapter: parseIPAdapter,
+ ipAdapters: parseAllIPAdapters,
+} as const;
diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts
new file mode 100644
index 0000000000..e33589d5ce
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts
@@ -0,0 +1,295 @@
+import { getStore } from 'app/store/nanostores/store';
+import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
+import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
+import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
+import type { LoRA } from 'features/lora/store/loraSlice';
+import { loraRecalled } from 'features/lora/store/loraSlice';
+import type { MetadataRecallFunc } from 'features/metadata/types';
+import { zModelIdentifierWithBase } from 'features/nodes/types/common';
+import { modelSelected } from 'features/parameters/store/actions';
+import {
+ heightRecalled,
+ setCfgRescaleMultiplier,
+ setCfgScale,
+ setImg2imgStrength,
+ setNegativePrompt,
+ setPositivePrompt,
+ setScheduler,
+ setSeed,
+ setSteps,
+ vaeSelected,
+ widthRecalled,
+} from 'features/parameters/store/generationSlice';
+import type {
+ ParameterCFGRescaleMultiplier,
+ ParameterCFGScale,
+ ParameterHeight,
+ ParameterHRFEnabled,
+ ParameterHRFMethod,
+ ParameterNegativePrompt,
+ ParameterNegativeStylePromptSDXL,
+ ParameterPositivePrompt,
+ ParameterPositiveStylePromptSDXL,
+ ParameterScheduler,
+ ParameterSDXLRefinerNegativeAestheticScore,
+ ParameterSDXLRefinerPositiveAestheticScore,
+ ParameterSDXLRefinerStart,
+ ParameterSeed,
+ ParameterSteps,
+ ParameterStrength,
+ ParameterWidth,
+} from 'features/parameters/types/parameterSchemas';
+import {
+ refinerModelChanged,
+ setNegativeStylePromptSDXL,
+ setPositiveStylePromptSDXL,
+ setRefinerCFGScale,
+ setRefinerNegativeAestheticScore,
+ setRefinerPositiveAestheticScore,
+ setRefinerScheduler,
+ setRefinerStart,
+ setRefinerSteps,
+} from 'features/sdxl/store/sdxlSlice';
+import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
+
+const recallPositivePrompt: MetadataRecallFunc = (positivePrompt) => {
+ getStore().dispatch(setPositivePrompt(positivePrompt));
+};
+
+const recallNegativePrompt: MetadataRecallFunc = (negativePrompt) => {
+ getStore().dispatch(setNegativePrompt(negativePrompt));
+};
+
+const recallSDXLPositiveStylePrompt: MetadataRecallFunc = (positiveStylePrompt) => {
+ getStore().dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
+};
+
+const recallSDXLNegativeStylePrompt: MetadataRecallFunc = (negativeStylePrompt) => {
+ getStore().dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
+};
+
+const recallSeed: MetadataRecallFunc = (seed) => {
+ getStore().dispatch(setSeed(seed));
+};
+
+const recallCFGScale: MetadataRecallFunc = (cfgScale) => {
+ getStore().dispatch(setCfgScale(cfgScale));
+};
+
+const recallCFGRescaleMultiplier: MetadataRecallFunc = (cfgRescaleMultiplier) => {
+ getStore().dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
+};
+
+const recallScheduler: MetadataRecallFunc = (scheduler) => {
+ getStore().dispatch(setScheduler(scheduler));
+};
+
+const recallWidth: MetadataRecallFunc = (width) => {
+ getStore().dispatch(widthRecalled(width));
+};
+
+const recallHeight: MetadataRecallFunc = (height) => {
+ getStore().dispatch(heightRecalled(height));
+};
+
+const recallSteps: MetadataRecallFunc = (steps) => {
+ getStore().dispatch(setSteps(steps));
+};
+
+const recallStrength: MetadataRecallFunc = (strength) => {
+ getStore().dispatch(setImg2imgStrength(strength));
+};
+
+const recallHRFEnabled: MetadataRecallFunc = (hrfEnabled) => {
+ getStore().dispatch(setHrfEnabled(hrfEnabled));
+};
+
+const recallHRFStrength: MetadataRecallFunc = (hrfStrength) => {
+ getStore().dispatch(setHrfStrength(hrfStrength));
+};
+
+const recallHRFMethod: MetadataRecallFunc = (hrfMethod) => {
+ getStore().dispatch(setHrfMethod(hrfMethod));
+};
+
+const recallRefinerSteps: MetadataRecallFunc = (refinerSteps) => {
+ getStore().dispatch(setRefinerSteps(refinerSteps));
+};
+
+const recallRefinerCFGScale: MetadataRecallFunc = (refinerCFGScale) => {
+ getStore().dispatch(setRefinerCFGScale(refinerCFGScale));
+};
+
+const recallRefinerScheduler: MetadataRecallFunc = (refinerScheduler) => {
+ getStore().dispatch(setRefinerScheduler(refinerScheduler));
+};
+
+const recallRefinerPositiveAestheticScore: MetadataRecallFunc = (
+ refinerPositiveAestheticScore
+) => {
+ getStore().dispatch(setRefinerPositiveAestheticScore(refinerPositiveAestheticScore));
+};
+
+const recallRefinerNegativeAestheticScore: MetadataRecallFunc = (
+ refinerNegativeAestheticScore
+) => {
+ getStore().dispatch(setRefinerNegativeAestheticScore(refinerNegativeAestheticScore));
+};
+
+const recallRefinerStart: MetadataRecallFunc = (refinerStart) => {
+ getStore().dispatch(setRefinerStart(refinerStart));
+};
+
+const recallModel: MetadataRecallFunc = (model) => {
+ const modelIdentifier = zModelIdentifierWithBase.parse(model);
+ getStore().dispatch(modelSelected(modelIdentifier));
+};
+
+const recallRefinerModel: MetadataRecallFunc = (refinerModel) => {
+ const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel);
+ getStore().dispatch(refinerModelChanged(modelIdentifier));
+};
+
+const recallVAE: MetadataRecallFunc = (vaeModel) => {
+ if (!vaeModel) {
+ getStore().dispatch(vaeSelected(null));
+ return;
+ }
+ const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel);
+ getStore().dispatch(vaeSelected(modelIdentifier));
+};
+
+const recallLoRA: MetadataRecallFunc = (lora) => {
+ getStore().dispatch(loraRecalled(lora));
+};
+
+const recallAllLoRAs: MetadataRecallFunc = (loras) => {
+ const { dispatch } = getStore();
+ loras.forEach((lora) => {
+ dispatch(loraRecalled(lora));
+ });
+};
+
+const recallControlNet: MetadataRecallFunc = (controlNet) => {
+ getStore().dispatch(controlAdapterRecalled(controlNet));
+};
+
+const recallControlNets: MetadataRecallFunc = (controlNets) => {
+ const { dispatch } = getStore();
+ controlNets.forEach((controlNet) => {
+ dispatch(controlAdapterRecalled(controlNet));
+ });
+};
+
+const recallT2IAdapter: MetadataRecallFunc = (t2iAdapter) => {
+ getStore().dispatch(controlAdapterRecalled(t2iAdapter));
+};
+
+const recallT2IAdapters: MetadataRecallFunc = (t2iAdapters) => {
+ const { dispatch } = getStore();
+ t2iAdapters.forEach((t2iAdapter) => {
+ dispatch(controlAdapterRecalled(t2iAdapter));
+ });
+};
+
+const recallIPAdapter: MetadataRecallFunc = (ipAdapter) => {
+ getStore().dispatch(controlAdapterRecalled(ipAdapter));
+};
+
+const recallIPAdapters: MetadataRecallFunc = (ipAdapters) => {
+ const { dispatch } = getStore();
+ ipAdapters.forEach((ipAdapter) => {
+ dispatch(controlAdapterRecalled(ipAdapter));
+ });
+};
+
+export const recallPrompts = (_metadata: unknown) => {
+ // recallPositivePrompt(metadata);
+ // recallNegativePrompt(metadata);
+ // recallSDXLPositiveStylePrompt(metadata);
+ // recallSDXLNegativeStylePrompt(metadata);
+ // parameterNotSetToast(t('metadata.allPrompts'));
+ // parameterSetToast(t('metadata.allPrompts'));
+};
+
+export const recallWidthAndHeight = (_metadata: unknown) => {
+ // recallWidth(metadata);
+ // recallHeight(metadata);
+};
+
+export const recallAll = async (_metadata: unknown) => {
+ // if (!metadata) {
+ // allParameterNotSetToast();
+ // return;
+ // }
+ // // Update the main model first, as other parameters may depend on it.
+ // await recallModel(metadata);
+ // await Promise.allSettled([
+ // // Core parameters
+ // recallCFGScale(metadata),
+ // recallCFGRescaleMultiplier(metadata),
+ // recallPositivePrompt(metadata),
+ // recallNegativePrompt(metadata),
+ // recallScheduler(metadata),
+ // recallSeed(metadata),
+ // recallSteps(metadata),
+ // recallWidth(metadata),
+ // recallHeight(metadata),
+ // recallStrength(metadata),
+ // recallHRFEnabled(metadata),
+ // recallHRFMethod(metadata),
+ // recallHRFStrength(metadata),
+ // // SDXL parameters
+ // recallSDXLPositiveStylePrompt(metadata),
+ // recallSDXLNegativeStylePrompt(metadata),
+ // recallRefinerSteps(metadata),
+ // recallRefinerCFGScale(metadata),
+ // recallRefinerScheduler(metadata),
+ // recallRefinerPositiveAestheticScore(metadata),
+ // recallRefinerNegativeAestheticScore(metadata),
+ // recallRefinerStart(metadata),
+ // // Models
+ // recallVAE(metadata),
+ // recallRefinerModel(metadata),
+ // recallAllLoRAs(metadata),
+ // recallControlNets(metadata),
+ // recallT2IAdapters(metadata),
+ // recallIPAdapters(metadata),
+ // ]);
+ // allParameterSetToast();
+};
+
+export const recallers = {
+ positivePrompt: recallPositivePrompt,
+ negativePrompt: recallNegativePrompt,
+ sdxlPositiveStylePrompt: recallSDXLPositiveStylePrompt,
+ sdxlNegativeStylePrompt: recallSDXLNegativeStylePrompt,
+ seed: recallSeed,
+ cfgScale: recallCFGScale,
+ cfgRescaleMultiplier: recallCFGRescaleMultiplier,
+ scheduler: recallScheduler,
+ width: recallWidth,
+ height: recallHeight,
+ steps: recallSteps,
+ strength: recallStrength,
+ hrfEnabled: recallHRFEnabled,
+ hrfStrength: recallHRFStrength,
+ hrfMethod: recallHRFMethod,
+ refinerSteps: recallRefinerSteps,
+ refinerCFGScale: recallRefinerCFGScale,
+ refinerScheduler: recallRefinerScheduler,
+ refinerPositiveAestheticScore: recallRefinerPositiveAestheticScore,
+ refinerNegativeAestheticScore: recallRefinerNegativeAestheticScore,
+ refinerStart: recallRefinerStart,
+ model: recallModel,
+ refinerModel: recallRefinerModel,
+ vae: recallVAE,
+ lora: recallLoRA,
+ loras: recallAllLoRAs,
+ controlNets: recallControlNets,
+ controlNet: recallControlNet,
+ t2iAdapters: recallT2IAdapters,
+ t2iAdapter: recallT2IAdapter,
+ ipAdapters: recallIPAdapters,
+ ipAdapter: recallIPAdapter,
+} as const;
diff --git a/invokeai/frontend/web/src/features/metadata/util/validators.ts b/invokeai/frontend/web/src/features/metadata/util/validators.ts
new file mode 100644
index 0000000000..e627afeda3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/metadata/util/validators.ts
@@ -0,0 +1,117 @@
+import { getStore } from 'app/store/nanostores/store';
+import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
+import type { LoRA } from 'features/lora/store/loraSlice';
+import type { MetadataValidateFunc } from 'features/metadata/types';
+import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
+import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
+
+/**
+ * Checks the given base model type against the currently-selected model's base type and throws an error if they are
+ * incompatible.
+ * @param base The base model type to validate.
+ * @param message An optional message to use in the error if the base model is incompatible.
+ */
+const validateBaseCompatibility = (base?: BaseModelType, message?: string) => {
+ if (!base) {
+ throw new InvalidModelConfigError(message || 'Missing base');
+ }
+ const currentBase = getStore().getState().generation.model?.base;
+ if (currentBase && base !== currentBase) {
+ throw new InvalidModelConfigError(message || `Incompatible base models: ${base} and ${currentBase}`);
+ }
+};
+
+const validateRefinerModel: MetadataValidateFunc = (refinerModel) => {
+ validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
+ return new Promise((resolve) => resolve(refinerModel));
+};
+
+const validateVAEModel: MetadataValidateFunc = (vaeModel) => {
+ validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
+ return new Promise((resolve) => resolve(vaeModel));
+};
+
+const validateLoRA: MetadataValidateFunc = (lora) => {
+ validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
+ return new Promise((resolve) => resolve(lora));
+};
+
+const validateLoRAs: MetadataValidateFunc = (loras) => {
+ const validatedLoRAs: LoRA[] = [];
+ loras.forEach((lora) => {
+ try {
+ validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
+ validatedLoRAs.push(lora);
+ } catch {
+ // This is a no-op - we want to continue validating the rest of the LoRAs, and an empty list is valid.
+ }
+ });
+ return new Promise((resolve) => resolve(validatedLoRAs));
+};
+
+const validateControlNet: MetadataValidateFunc = (controlNet) => {
+ validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
+ return new Promise((resolve) => resolve(controlNet));
+};
+
+const validateControlNets: MetadataValidateFunc = (controlNets) => {
+ const validatedControlNets: ControlNetConfig[] = [];
+ controlNets.forEach((controlNet) => {
+ try {
+ validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
+ validatedControlNets.push(controlNet);
+ } catch {
+ // This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid.
+ }
+ });
+ return new Promise((resolve) => resolve(validatedControlNets));
+};
+
+const validateT2IAdapter: MetadataValidateFunc = (t2iAdapter) => {
+ validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
+ return new Promise((resolve) => resolve(t2iAdapter));
+};
+
+const validateT2IAdapters: MetadataValidateFunc = (t2iAdapters) => {
+ const validatedT2IAdapters: T2IAdapterConfig[] = [];
+ t2iAdapters.forEach((t2iAdapter) => {
+ try {
+ validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
+ validatedT2IAdapters.push(t2iAdapter);
+ } catch {
+ // This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid.
+ }
+ });
+ return new Promise((resolve) => resolve(validatedT2IAdapters));
+};
+
+const validateIPAdapter: MetadataValidateFunc = (ipAdapter) => {
+ validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
+ return new Promise((resolve) => resolve(ipAdapter));
+};
+
+const validateIPAdapters: MetadataValidateFunc = (ipAdapters) => {
+ const validatedIPAdapters: IPAdapterConfig[] = [];
+ ipAdapters.forEach((ipAdapter) => {
+ try {
+ validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
+ validatedIPAdapters.push(ipAdapter);
+ } catch {
+ // This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid.
+ }
+ });
+ return new Promise((resolve) => resolve(validatedIPAdapters));
+};
+
+export const validators = {
+ refinerModel: validateRefinerModel,
+ vaeModel: validateVAEModel,
+ lora: validateLoRA,
+ loras: validateLoRAs,
+ controlNet: validateControlNet,
+ controlNets: validateControlNets,
+ t2iAdapter: validateT2IAdapter,
+ t2iAdapters: validateT2IAdapters,
+ ipAdapter: validateIPAdapter,
+ ipAdapters: validateIPAdapters,
+} as const;
diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts
index b195ce4434..501286c785 100644
--- a/invokeai/frontend/web/src/features/nodes/types/common.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/common.ts
@@ -1,5 +1,8 @@
import { z } from 'zod';
+import type { ModelIdentifier as ModelIdentifierV2 } from './v2/common';
+import { zModelIdentifier as zModelIdentifierV2 } from './v2/common';
+
// #region Field data schemas
export const zImageField = z.object({
image_name: z.string().trim().min(1),
@@ -69,6 +72,8 @@ export const zModelIdentifier = z.object({
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;
+export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
+ zModelIdentifierV2.safeParse(field).success;
export const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer;
diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts
index 493a0464b3..2789abaaca 100644
--- a/invokeai/frontend/web/src/features/nodes/types/metadata.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts
@@ -1,29 +1,22 @@
import { z } from 'zod';
-import {
- zControlField,
- zIPAdapterField,
- zMainModelField,
- zModelFieldBase,
- zSDXLRefinerModelField,
- zT2IAdapterField,
- zVAEModelField,
-} from './common';
+import { zControlField, zIPAdapterField, zT2IAdapterField } from './common';
+export const zLoRAWeight = z.number().nullish();
// #region Metadata-optimized versions of schemas
// TODO: It's possible that `deepPartial` will be deprecated:
// - https://github.com/colinhacks/zod/issues/2106
// - https://github.com/colinhacks/zod/issues/2854
export const zLoRAMetadataItem = z.object({
- lora: zModelFieldBase.deepPartial(),
- weight: z.number(),
+ lora: z.unknown(),
+ weight: zLoRAWeight,
});
-const zControlNetMetadataItem = zControlField.deepPartial();
-const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
-const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
-const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
-const zModelMetadataItem = zMainModelField.deepPartial();
-const zVAEModelMetadataItem = zVAEModelField.deepPartial();
+const zControlNetMetadataItem = zControlField.merge(z.object({ control_model: z.unknown() })).deepPartial();
+const zIPAdapterMetadataItem = zIPAdapterField.merge(z.object({ ip_adapter_model: z.unknown() })).deepPartial();
+const zT2IAdapterMetadataItem = zT2IAdapterField.merge(z.object({ t2i_adapter_model: z.unknown() })).deepPartial();
+const zSDXLRefinerModelMetadataItem = z.unknown();
+const zModelMetadataItem = z.unknown();
+const zVAEModelMetadataItem = z.unknown();
export type LoRAMetadataItem = z.infer;
export type ControlNetMetadataItem = z.infer;
export type IPAdapterMetadataItem = z.infer;
diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
index 0929fc1dc3..8f9a76feca 100644
--- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
+++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
@@ -1,17 +1,25 @@
-import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { getStore } from 'app/store/nanostores/store';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { toast } from 'common/util/toast';
+import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
+import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
+import {
+ initialControlNet,
+ initialIPAdapter,
+ initialT2IAdapter,
+} from 'features/controlAdapters/util/buildControlAdapter';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
-import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
-import { isModelIdentifier } from 'features/nodes/types/common';
-import type {
- ControlNetMetadataItem,
- CoreMetadata,
- IPAdapterMetadataItem,
- LoRAMetadataItem,
- T2IAdapterMetadataItem,
-} from 'features/nodes/types/metadata';
+import type { LoRA } from 'features/lora/store/loraSlice';
+import { defaultLoRAConfig, loraRecalled, lorasCleared } from 'features/lora/store/loraSlice';
+import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
+import {
+ zControlField,
+ zIPAdapterField,
+ zModelIdentifierWithBase,
+ zT2IAdapterField,
+} from 'features/nodes/types/common';
import { initialImageSelected, modelSelected } from 'features/parameters/store/actions';
import {
heightRecalled,
@@ -27,19 +35,18 @@ import {
vaeSelected,
widthRecalled,
} from 'features/parameters/store/generationSlice';
-import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterHeight,
isParameterHRFEnabled,
isParameterHRFMethod,
+ isParameterLoRAWeight,
isParameterNegativePrompt,
isParameterNegativeStylePromptSDXL,
isParameterPositivePrompt,
isParameterPositiveStylePromptSDXL,
isParameterScheduler,
- isParameterSDXLRefinerModel,
isParameterSDXLRefinerNegativeAestheticScore,
isParameterSDXLRefinerPositiveAestheticScore,
isParameterSDXLRefinerStart,
@@ -49,13 +56,16 @@ import {
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
- prepareControlNetMetadataItem,
- prepareIPAdapterMetadataItem,
- prepareLoRAMetadataItem,
- prepareMainModelMetadataItem,
- prepareT2IAdapterMetadataItem,
- prepareVAEMetadataItem,
-} from 'features/parameters/util/modelMetadataHelpers';
+ fetchControlNetModel,
+ fetchIPAdapterModel,
+ fetchLoRAModel,
+ fetchMainModelConfig,
+ fetchRefinerModelConfig,
+ fetchT2IAdapterModel,
+ fetchVAEModelConfig,
+ getModelKey,
+ raiseIfBaseIncompatible,
+} from 'features/parameters/util/modelFetchingHelpers';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
@@ -67,229 +77,663 @@ import {
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
-import { isNil } from 'lodash-es';
+import { t } from 'i18next';
+import { get, isArray, isNil } from 'lodash-es';
import { useCallback } from 'react';
-import { useTranslation } from 'react-i18next';
-import type { ImageDTO } from 'services/api/types';
+import type { BaseModelType, ImageDTO } from 'services/api/types';
+import { v4 as uuidv4 } from 'uuid';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
+/**
+ * A function that recalls from metadata from the full metadata object.
+ */
+type MetadataRecallFunc = (metadata: unknown, withToast?: boolean) => void;
+
+/**
+ * A function that recalls metadata from a specific metadata item.
+ */
+type MetadataItemRecallFunc = (metadataItem: unknown, withToast?: boolean) => void;
+
+/**
+ * Raised when metadata recall fails.
+ */
+export class MetadataRecallError extends Error {
+ /**
+ * Create MetadataRecallError
+ * @param {String} message
+ */
+ constructor(message: string) {
+ super(message);
+ this.name = this.constructor.name;
+ }
+}
+
+export class InvalidMetadataPropertyType extends MetadataRecallError {}
+
+const getProperty = (
+ obj: unknown,
+ property: string,
+ typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true
+): T => {
+ const val = get(obj, property) as unknown;
+ if (typeGuard(val)) {
+ return val;
+ }
+ throw new InvalidMetadataPropertyType(`Property ${property} is not of expected type`);
+};
+
+const getCurrentBase = () => selectModel(getStore().getState())?.base;
+
+const parameterSetToast = (parameter: string, description?: string) => {
+ toast({
+ title: t('toast.parameterSet', { parameter }),
+ description,
+ status: 'info',
+ duration: 2500,
+ isClosable: true,
+ });
+};
+
+const parameterNotSetToast = (parameter: string, description?: string) => {
+ toast({
+ title: t('toast.parameterNotSet', { parameter }),
+ description,
+ status: 'warning',
+ duration: 2500,
+ isClosable: true,
+ });
+};
+
+const allParameterSetToast = (description?: string) => {
+ toast({
+ title: t('toast.parametersSet'),
+ status: 'info',
+ description,
+ duration: 2500,
+ isClosable: true,
+ });
+};
+
+const allParameterNotSetToast = (description?: string) => {
+ toast({
+ title: t('toast.parametersNotSet'),
+ status: 'warning',
+ description,
+ duration: 2500,
+ isClosable: true,
+ });
+};
+
+const recall = (callback: () => void, parameter: string, withToast = true) => {
+ try {
+ callback();
+ withToast && parameterSetToast(parameter);
+ } catch (e) {
+ withToast && parameterNotSetToast(parameter, (e as Error).message);
+ }
+};
+
+const recallAsync = async (callback: () => Promise, parameter: string, withToast = true) => {
+ try {
+ await callback();
+ withToast && parameterSetToast(parameter);
+ } catch (e) {
+ withToast && parameterNotSetToast(parameter, (e as Error).message);
+ }
+};
+
+export const recallPositivePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const positive_prompt = getProperty(metadata, 'positive_prompt', isParameterPositivePrompt);
+ getStore().dispatch(setPositivePrompt(positive_prompt));
+ },
+ t('metadata.positivePrompt'),
+ withToast
+ );
+};
+
+export const recallNegativePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const negative_prompt = getProperty(metadata, 'negative_prompt', isParameterNegativePrompt);
+ getStore().dispatch(setNegativePrompt(negative_prompt));
+ },
+ t('metadata.negativePrompt'),
+ withToast
+ );
+};
+
+export const recallSDXLPositiveStylePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const positive_style_prompt = getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL);
+ getStore().dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
+ },
+ t('sdxl.posStylePrompt'),
+ withToast
+ );
+};
+
+export const recallSDXLNegativeStylePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const negative_style_prompt = getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL);
+ getStore().dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
+ },
+ t('sdxl.negStylePrompt'),
+ withToast
+ );
+};
+
+export const recallSeed: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const seed = getProperty(metadata, 'seed', isParameterSeed);
+ getStore().dispatch(setSeed(seed));
+ },
+ t('metadata.seed'),
+ withToast
+ );
+};
+
+export const recallCFGScale: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const cfg_scale = getProperty(metadata, 'cfg_scale', isParameterCFGScale);
+ getStore().dispatch(setCfgScale(cfg_scale));
+ },
+ t('metadata.cfgScale'),
+ withToast
+ );
+};
+
+export const recallCFGRescaleMultiplier: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const cfg_rescale_multiplier = getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier);
+ getStore().dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
+ },
+ t('metadata.cfgRescaleMultiplier'),
+ withToast
+ );
+};
+
+export const recallScheduler: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const scheduler = getProperty(metadata, 'scheduler', isParameterScheduler);
+ getStore().dispatch(setScheduler(scheduler));
+ },
+ t('metadata.scheduler'),
+ withToast
+ );
+};
+
+export const recallWidth: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const width = getProperty(metadata, 'width', isParameterWidth);
+ getStore().dispatch(widthRecalled(width));
+ },
+ t('metadata.width'),
+ withToast
+ );
+};
+
+export const recallHeight: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const height = getProperty(metadata, 'height', isParameterHeight);
+ getStore().dispatch(heightRecalled(height));
+ },
+ t('metadata.height'),
+ withToast
+ );
+};
+
+export const recallSteps: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const steps = getProperty(metadata, 'steps', isParameterSteps);
+ getStore().dispatch(setSteps(steps));
+ },
+ t('metadata.steps'),
+ withToast
+ );
+};
+
+export const recallStrength: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const strength = getProperty(metadata, 'strength', isParameterStrength);
+ getStore().dispatch(setImg2imgStrength(strength));
+ },
+ t('metadata.strength'),
+ withToast
+ );
+};
+
+export const recallHRFEnabled: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const hrf_enabled = getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
+ getStore().dispatch(setHrfEnabled(hrf_enabled));
+ },
+ t('hrf.metadata.enabled'),
+ withToast
+ );
+};
+
+export const recallHRFStrength: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const hrf_strength = getProperty(metadata, 'hrf_strength', isParameterStrength);
+ getStore().dispatch(setHrfStrength(hrf_strength));
+ },
+ t('hrf.metadata.strength'),
+ withToast
+ );
+};
+
+export const recallHRFMethod: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const hrf_method = getProperty(metadata, 'hrf_method', isParameterHRFMethod);
+ getStore().dispatch(setHrfMethod(hrf_method));
+ },
+ t('hrf.metadata.method'),
+ withToast
+ );
+};
+
+export const recallRefinerSteps: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const refiner_steps = getProperty(metadata, 'refiner_steps', isParameterSteps);
+ getStore().dispatch(setRefinerSteps(refiner_steps));
+ },
+ t('sdxl.steps'),
+ withToast
+ );
+};
+
+export const recallRefinerCFGScale: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const refiner_cfg_scale = getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale);
+ getStore().dispatch(setRefinerCFGScale(refiner_cfg_scale));
+ },
+ t('sdxl.cfgScale'),
+ withToast
+ );
+};
+
+export const recallRefinerScheduler: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const refiner_scheduler = getProperty(metadata, 'refiner_scheduler', isParameterScheduler);
+ getStore().dispatch(setRefinerScheduler(refiner_scheduler));
+ },
+ t('sdxl.cfgScale'),
+ withToast
+ );
+};
+
+export const recallRefinerPositiveAestheticScore: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const refiner_positive_aesthetic_score = getProperty(
+ metadata,
+ 'refiner_positive_aesthetic_score',
+ isParameterSDXLRefinerPositiveAestheticScore
+ );
+ getStore().dispatch(setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score));
+ },
+ t('sdxl.posAestheticScore'),
+ withToast
+ );
+};
+
+export const recallRefinerNegativeAestheticScore: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const refiner_negative_aesthetic_score = getProperty(
+ metadata,
+ 'refiner_negative_aesthetic_score',
+ isParameterSDXLRefinerNegativeAestheticScore
+ );
+ getStore().dispatch(setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score));
+ },
+ t('sdxl.negAestheticScore'),
+ withToast
+ );
+};
+
+export const recallRefinerStart: MetadataRecallFunc = (metadata: unknown, withToast = true) => {
+ recall(
+ () => {
+ const refiner_start = getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
+ getStore().dispatch(setRefinerStart(refiner_start));
+ },
+ t('sdxl.refinerStart'),
+ withToast
+ );
+};
+
+export const prepareMainModelMetadataItem = async (model: unknown): Promise => {
+ const key = await getModelKey(model, 'main');
+ const mainModel = await fetchMainModelConfig(key);
+ return zModelIdentifierWithBase.parse(mainModel);
+};
+
+const recallModelAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => {
+ await recallAsync(
+ async () => {
+ const modelMetadataItem = getProperty(metadata, 'model');
+ const model = await prepareMainModelMetadataItem(modelMetadataItem);
+ getStore().dispatch(modelSelected(model));
+ },
+ t('metadata.model'),
+ withToast
+ );
+};
+
+export const prepareRefinerMetadataItem = async (
+ model: unknown,
+ currentBase: BaseModelType | undefined
+): Promise => {
+ const key = await getModelKey(model, 'main');
+ const refinerModel = await fetchRefinerModelConfig(key);
+ raiseIfBaseIncompatible('sdxl-refiner', currentBase, 'Refiner incompatible with currently-selected model');
+ return zModelIdentifierWithBase.parse(refinerModel);
+};
+
+const recallRefinerModelAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => {
+ await recallAsync(
+ async () => {
+ const refinerMetadataItem = getProperty(metadata, 'refiner_model');
+ const refiner = await prepareRefinerMetadataItem(refinerMetadataItem, getCurrentBase());
+ getStore().dispatch(refinerModelChanged(refiner));
+ },
+ t('sdxl.refinerModel'),
+ withToast
+ );
+};
+
+export const prepareVAEMetadataItem = async (
+ vae: unknown,
+ currentBase: BaseModelType | undefined
+): Promise => {
+ const key = await getModelKey(vae, 'vae');
+ const vaeModel = await fetchVAEModelConfig(key);
+ raiseIfBaseIncompatible(vaeModel.base, currentBase, 'VAE incompatible with currently-selected model');
+ return zModelIdentifierWithBase.parse(vaeModel);
+};
+
+const recallVAEAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => {
+ await recallAsync(
+ async () => {
+ const currentBase = getCurrentBase();
+ const vaeMetadataItem = getProperty(metadata, 'vae');
+ if (isNil(vaeMetadataItem)) {
+ getStore().dispatch(vaeSelected(null));
+ } else {
+ const vae = await prepareVAEMetadataItem(vaeMetadataItem, currentBase);
+ getStore().dispatch(vaeSelected(vae));
+ }
+ },
+ t('metadata.vae'),
+ withToast
+ );
+};
+
+export const prepareLoRAMetadataItem = async (
+ loraMetadataItem: unknown,
+ currentBase: BaseModelType | undefined
+): Promise => {
+ const lora = getProperty(loraMetadataItem, 'lora');
+ const weight = getProperty(loraMetadataItem, 'weight');
+ const key = await getModelKey(lora, 'lora');
+ const loraModel = await fetchLoRAModel(key);
+ raiseIfBaseIncompatible(loraModel.base, currentBase, 'LoRA incompatible with currently-selected model');
+ return {
+ key: loraModel.key,
+ base: loraModel.base,
+ weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
+ isEnabled: true,
+ };
+};
+
+const recallLoRAAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => {
+ await recallAsync(
+ async () => {
+ const currentBase = getCurrentBase();
+ const lora = await prepareLoRAMetadataItem(metadataItem, currentBase);
+ getStore().dispatch(loraRecalled(lora));
+ },
+ t('models.lora'),
+ withToast
+ );
+};
+
+export const prepareControlNetMetadataItem = async (
+ metadataItem: unknown,
+ currentBase: BaseModelType | undefined
+): Promise => {
+ const control_model = getProperty(metadataItem, 'control_model');
+ const key = await getModelKey(control_model, 'controlnet');
+ const controlNetModel = await fetchControlNetModel(key);
+ raiseIfBaseIncompatible(controlNetModel.base, currentBase, 'ControlNet incompatible with currently-selected model');
+
+ const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
+ const control_weight = zControlField.shape.control_weight
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'control_weight'));
+ const begin_step_percent = zControlField.shape.begin_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'begin_step_percent'));
+ const end_step_percent = zControlField.shape.end_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'end_step_percent'));
+ const control_mode = zControlField.shape.control_mode
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'control_mode'));
+ const resize_mode = zControlField.shape.resize_mode
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'resize_mode'));
+
+ // We don't save the original image that was processed into a control image, only the processed image
+ const processorType = 'none';
+ const processorNode = CONTROLNET_PROCESSORS.none.default;
+
+ const controlnet: ControlNetConfig = {
+ type: 'controlnet',
+ isEnabled: true,
+ model: zModelIdentifierWithBase.parse(controlNetModel),
+ weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
+ beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
+ endStepPct: end_step_percent ?? initialControlNet.endStepPct,
+ controlMode: control_mode ?? initialControlNet.controlMode,
+ resizeMode: resize_mode ?? initialControlNet.resizeMode,
+ controlImage: image?.image_name ?? null,
+ processedControlImage: image?.image_name ?? null,
+ processorType,
+ processorNode,
+ shouldAutoConfig: true,
+ id: uuidv4(),
+ };
+
+ return controlnet;
+};
+
+const recallControlNetAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => {
+ await recallAsync(
+ async () => {
+ const currentBase = getCurrentBase();
+ const controlNetConfig = await prepareControlNetMetadataItem(metadataItem, currentBase);
+ getStore().dispatch(controlAdapterRecalled(controlNetConfig));
+ },
+ t('common.controlNet'),
+ withToast
+ );
+};
+
+export const prepareT2IAdapterMetadataItem = async (
+ metadataItem: unknown,
+ currentBase: BaseModelType | undefined
+): Promise => {
+ const t2i_adapter_model = getProperty(metadataItem, 't2i_adapter_model');
+ const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
+ const t2iAdapterModel = await fetchT2IAdapterModel(key);
+ raiseIfBaseIncompatible(t2iAdapterModel.base, currentBase, 'T2I Adapter incompatible with currently-selected model');
+
+ const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
+ const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
+ const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'begin_step_percent'));
+ const end_step_percent = zT2IAdapterField.shape.end_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'end_step_percent'));
+ const resize_mode = zT2IAdapterField.shape.resize_mode
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'resize_mode'));
+
+ // We don't save the original image that was processed into a control image, only the processed image
+ const processorType = 'none';
+ const processorNode = CONTROLNET_PROCESSORS.none.default;
+
+ const t2iAdapter: T2IAdapterConfig = {
+ type: 't2i_adapter',
+ isEnabled: true,
+ model: zModelIdentifierWithBase.parse(t2iAdapterModel),
+ weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
+ beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
+ endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
+ resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
+ controlImage: image?.image_name ?? null,
+ processedControlImage: image?.image_name ?? null,
+ processorType,
+ processorNode,
+ shouldAutoConfig: true,
+ id: uuidv4(),
+ };
+
+ return t2iAdapter;
+};
+
+const recallT2IAdapterAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => {
+ await recallAsync(
+ async () => {
+ const currentBase = getCurrentBase();
+ const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(metadataItem, currentBase);
+ getStore().dispatch(controlAdapterRecalled(t2iAdapterConfig));
+ },
+ t('common.t2iAdapter'),
+ withToast
+ );
+};
+
+export const prepareIPAdapterMetadataItem = async (
+ metadataItem: unknown,
+ currentBase: BaseModelType | undefined
+): Promise => {
+ const ip_adapter_model = getProperty(metadataItem, 'ip_adapter_model');
+ const key = await getModelKey(ip_adapter_model, 'ip_adapter');
+ const ipAdapterModel = await fetchIPAdapterModel(key);
+ raiseIfBaseIncompatible(ipAdapterModel.base, currentBase, 'T2I Adapter incompatible with currently-selected model');
+
+ const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
+ const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
+ const begin_step_percent = zIPAdapterField.shape.begin_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'begin_step_percent'));
+ const end_step_percent = zIPAdapterField.shape.end_step_percent
+ .nullish()
+ .catch(null)
+ .parse(getProperty(metadataItem, 'end_step_percent'));
+
+ const ipAdapter: IPAdapterConfig = {
+ id: uuidv4(),
+ type: 'ip_adapter',
+ isEnabled: true,
+ model: zModelIdentifierWithBase.parse(ipAdapterModel),
+ controlImage: image?.image_name ?? null,
+ weight: weight ?? initialIPAdapter.weight,
+ beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
+ endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
+ };
+
+ return ipAdapter;
+};
+
+const recallIPAdapterAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast) => {
+ await recallAsync(
+ async () => {
+ const currentBase = getCurrentBase();
+ const ipAdapterConfig = await prepareIPAdapterMetadataItem(metadataItem, currentBase);
+ getStore().dispatch(controlAdapterRecalled(ipAdapterConfig));
+ },
+ t('common.ipAdapter'),
+ withToast
+ );
+};
+
export const useRecallParameters = () => {
const dispatch = useAppDispatch();
- const toaster = useAppToaster();
- const { t } = useTranslation();
- const model = useAppSelector(selectModel);
- const parameterSetToast = useCallback(() => {
- toaster({
- title: t('toast.parameterSet'),
- status: 'info',
- duration: 2500,
- isClosable: true,
- });
- }, [t, toaster]);
-
- const parameterNotSetToast = useCallback(
- (description?: string) => {
- toaster({
- title: t('toast.parameterNotSet'),
- description,
- status: 'warning',
- duration: 2500,
- isClosable: true,
- });
- },
- [t, toaster]
- );
-
- const allParameterSetToast = useCallback(() => {
- toaster({
- title: t('toast.parametersSet'),
- status: 'info',
- duration: 2500,
- isClosable: true,
- });
- }, [t, toaster]);
-
- const allParameterNotSetToast = useCallback(
- (description?: string) => {
- toaster({
- title: t('toast.parametersNotSet'),
- status: 'warning',
- description,
- duration: 2500,
- isClosable: true,
- });
- },
- [t, toaster]
- );
-
- const recallBothPrompts = useCallback(
- (positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => {
+ const recallBothPrompts = useCallback(
+ (metadata: unknown) => {
+ const positive_prompt = getProperty(metadata, 'positive_prompt');
+ const negative_prompt = getProperty(metadata, 'negative_prompt');
+ const positive_style_prompt = getProperty(metadata, 'positive_style_prompt');
+ const negative_style_prompt = getProperty(metadata, 'negative_style_prompt');
if (
- isParameterPositivePrompt(positivePrompt) ||
- isParameterNegativePrompt(negativePrompt) ||
- isParameterPositiveStylePromptSDXL(positiveStylePrompt) ||
- isParameterNegativeStylePromptSDXL(negativeStylePrompt)
+ isParameterPositivePrompt(positive_prompt) ||
+ isParameterNegativePrompt(negative_prompt) ||
+ isParameterPositiveStylePromptSDXL(positive_style_prompt) ||
+ isParameterNegativeStylePromptSDXL(negative_style_prompt)
) {
- if (isParameterPositivePrompt(positivePrompt)) {
- dispatch(setPositivePrompt(positivePrompt));
+ if (isParameterPositivePrompt(positive_prompt)) {
+ dispatch(setPositivePrompt(positive_prompt));
}
- if (isParameterNegativePrompt(negativePrompt)) {
- dispatch(setNegativePrompt(negativePrompt));
+ if (isParameterNegativePrompt(negative_prompt)) {
+ dispatch(setNegativePrompt(negative_prompt));
}
- if (isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
- dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
+ if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) {
+ dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
}
- if (isParameterPositiveStylePromptSDXL(negativeStylePrompt)) {
- dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
+ if (isParameterPositiveStylePromptSDXL(negative_style_prompt)) {
+ dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
}
- parameterSetToast();
+ parameterSetToast(t('metadata.allPrompts'));
return;
}
- parameterNotSetToast();
+ parameterNotSetToast(t('metadata.allPrompts'));
},
- [dispatch, parameterSetToast, parameterNotSetToast]
+ [dispatch]
);
- const recallPositivePrompt = useCallback(
- (positivePrompt: unknown) => {
- if (!isParameterPositivePrompt(positivePrompt)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setPositivePrompt(positivePrompt));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
+ const recallWidthAndHeight = useCallback(
+ (metadata: unknown) => {
+ const width = getProperty(metadata, 'width');
+ const height = getProperty(metadata, 'height');
- const recallNegativePrompt = useCallback(
- (negativePrompt: unknown) => {
- if (!isParameterNegativePrompt(negativePrompt)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setNegativePrompt(negativePrompt));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallSDXLPositiveStylePrompt = useCallback(
- (positiveStylePrompt: unknown) => {
- if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallSDXLNegativeStylePrompt = useCallback(
- (negativeStylePrompt: unknown) => {
- if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallSeed = useCallback(
- (seed: unknown) => {
- if (!isParameterSeed(seed)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setSeed(seed));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallCfgScale = useCallback(
- (cfgScale: unknown) => {
- if (!isParameterCFGScale(cfgScale)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setCfgScale(cfgScale));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallCfgRescaleMultiplier = useCallback(
- (cfgRescaleMultiplier: unknown) => {
- if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallScheduler = useCallback(
- (scheduler: unknown) => {
- if (!isParameterScheduler(scheduler)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setScheduler(scheduler));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallSteps = useCallback(
- (steps: unknown) => {
- if (!isParameterSteps(steps)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setSteps(steps));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallWidth = useCallback(
- (width: unknown) => {
- if (!isParameterWidth(width)) {
- parameterNotSetToast();
- return;
- }
- dispatch(widthRecalled(width));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallHeight = useCallback(
- (height: unknown) => {
- if (!isParameterHeight(height)) {
- parameterNotSetToast();
- return;
- }
- dispatch(heightRecalled(height));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallWidthAndHeight = useCallback(
- (width: unknown, height: unknown) => {
if (!isParameterWidth(width)) {
allParameterNotSetToast();
return;
@@ -302,144 +746,7 @@ export const useRecallParameters = () => {
dispatch(widthRecalled(width));
allParameterSetToast();
},
- [dispatch, allParameterSetToast, allParameterNotSetToast]
- );
-
- const recallStrength = useCallback(
- (strength: unknown) => {
- if (!isParameterStrength(strength)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setImg2imgStrength(strength));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallHrfEnabled = useCallback(
- (hrfEnabled: unknown) => {
- if (!isParameterHRFEnabled(hrfEnabled)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setHrfEnabled(hrfEnabled));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallHrfStrength = useCallback(
- (hrfStrength: unknown) => {
- if (!isParameterStrength(hrfStrength)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setHrfStrength(hrfStrength));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallHrfMethod = useCallback(
- (hrfMethod: unknown) => {
- if (!isParameterHRFMethod(hrfMethod)) {
- parameterNotSetToast();
- return;
- }
- dispatch(setHrfMethod(hrfMethod));
- parameterSetToast();
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallModel = useCallback(
- async (modelMetadataItem: unknown) => {
- try {
- const model = await prepareMainModelMetadataItem(modelMetadataItem);
- dispatch(modelSelected(model));
- parameterSetToast();
- } catch (e) {
- parameterNotSetToast((e as unknown as Error).message);
- return;
- }
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallVaeModel = useCallback(
- async (vaeMetadataItem: unknown) => {
- if (isNil(vaeMetadataItem)) {
- dispatch(vaeSelected(null));
- parameterSetToast();
- return;
- }
- try {
- const vae = await prepareVAEMetadataItem(vaeMetadataItem);
- dispatch(vaeSelected(vae));
- parameterSetToast();
- } catch (e) {
- parameterNotSetToast((e as unknown as Error).message);
- return;
- }
- },
- [dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallLoRA = useCallback(
- async (loraMetadataItem: LoRAMetadataItem) => {
- try {
- const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base);
- dispatch(loraRecalled(lora));
- parameterSetToast();
- } catch (e) {
- parameterNotSetToast((e as unknown as Error).message);
- return;
- }
- },
- [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallControlNet = useCallback(
- async (controlnetMetadataItem: ControlNetMetadataItem) => {
- try {
- const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base);
- dispatch(controlAdapterRecalled(controlNetConfig));
- parameterSetToast();
- } catch (e) {
- parameterNotSetToast((e as unknown as Error).message);
- return;
- }
- },
- [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallT2IAdapter = useCallback(
- async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
- try {
- const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base);
- dispatch(controlAdapterRecalled(t2iAdapterConfig));
- parameterSetToast();
- } catch (e) {
- parameterNotSetToast((e as unknown as Error).message);
- return;
- }
- },
- [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
- );
-
- const recallIPAdapter = useCallback(
- async (ipAdapterMetadataItem: IPAdapterMetadataItem) => {
- try {
- const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base);
- dispatch(controlAdapterRecalled(ipAdapterConfig));
- parameterSetToast();
- } catch (e) {
- parameterNotSetToast((e as unknown as Error).message);
- return;
- }
- },
- [model?.base, dispatch, parameterSetToast, parameterNotSetToast]
+ [dispatch]
);
const sendToImageToImage = useCallback(
@@ -449,197 +756,75 @@ export const useRecallParameters = () => {
[dispatch]
);
- const recallAllParameters = useCallback(
- async (metadata: CoreMetadata | undefined) => {
+ const recallAllParameters = useCallback(
+ async (metadata: unknown) => {
if (!metadata) {
allParameterNotSetToast();
return;
}
- const {
- cfg_scale,
- cfg_rescale_multiplier,
- height,
- model,
- positive_prompt,
- negative_prompt,
- scheduler,
- vae,
- seed,
- steps,
- width,
- strength,
- hrf_enabled,
- hrf_strength,
- hrf_method,
- positive_style_prompt,
- negative_style_prompt,
- refiner_model,
- refiner_cfg_scale,
- refiner_steps,
- refiner_scheduler,
- refiner_positive_aesthetic_score,
- refiner_negative_aesthetic_score,
- refiner_start,
- loras,
- controlnets,
- ipAdapters,
- t2iAdapters,
- } = metadata;
+ await recallModelAsync(metadata, false);
- let newModel: ParameterModel | undefined = undefined;
+ recallCFGScale(metadata, false);
+ recallCFGRescaleMultiplier(metadata, false);
+ recallPositivePrompt(metadata, false);
+ recallNegativePrompt(metadata, false);
+ recallScheduler(metadata, false);
+ recallSeed(metadata, false);
+ recallSteps(metadata, false);
+ recallWidth(metadata, false);
+ recallHeight(metadata, false);
+ recallStrength(metadata, false);
+ recallHRFEnabled(metadata, false);
+ recallHRFMethod(metadata, false);
+ recallHRFStrength(metadata, false);
- if (isModelIdentifier(model)) {
- try {
- const _model = await prepareMainModelMetadataItem(model);
- dispatch(modelSelected(_model));
- newModel = _model;
- } catch {
- return;
- }
- }
+ // SDXL
+ recallSDXLPositiveStylePrompt(metadata, false);
+ recallSDXLNegativeStylePrompt(metadata, false);
+ recallRefinerSteps(metadata, false);
+ recallRefinerCFGScale(metadata, false);
+ recallRefinerScheduler(metadata, false);
+ recallRefinerPositiveAestheticScore(metadata, false);
+ recallRefinerNegativeAestheticScore(metadata, false);
+ recallRefinerStart(metadata, false);
- if (isParameterCFGScale(cfg_scale)) {
- dispatch(setCfgScale(cfg_scale));
- }
-
- if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
- dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
- }
-
- if (isParameterPositivePrompt(positive_prompt)) {
- dispatch(setPositivePrompt(positive_prompt));
- }
-
- if (isParameterNegativePrompt(negative_prompt)) {
- dispatch(setNegativePrompt(negative_prompt));
- }
-
- if (isParameterScheduler(scheduler)) {
- dispatch(setScheduler(scheduler));
- }
- if (isModelIdentifier(vae) || isNil(vae)) {
- if (isNil(vae)) {
- dispatch(vaeSelected(null));
- } else {
- try {
- const _vae = await prepareVAEMetadataItem(vae, newModel?.base);
- dispatch(vaeSelected(_vae));
- } catch {
- return;
- }
- }
- }
-
- if (isParameterSeed(seed)) {
- dispatch(setSeed(seed));
- }
-
- if (isParameterSteps(steps)) {
- dispatch(setSteps(steps));
- }
-
- if (isParameterWidth(width)) {
- dispatch(widthRecalled(width));
- }
-
- if (isParameterHeight(height)) {
- dispatch(heightRecalled(height));
- }
-
- if (isParameterStrength(strength)) {
- dispatch(setImg2imgStrength(strength));
- }
-
- if (isParameterHRFEnabled(hrf_enabled)) {
- dispatch(setHrfEnabled(hrf_enabled));
- }
-
- if (isParameterStrength(hrf_strength)) {
- dispatch(setHrfStrength(hrf_strength));
- }
-
- if (isParameterHRFMethod(hrf_method)) {
- dispatch(setHrfMethod(hrf_method));
- }
-
- if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) {
- dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
- }
-
- if (isParameterNegativeStylePromptSDXL(negative_style_prompt)) {
- dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
- }
-
- if (isParameterSDXLRefinerModel(refiner_model)) {
- dispatch(refinerModelChanged(refiner_model));
- }
-
- if (isParameterSteps(refiner_steps)) {
- dispatch(setRefinerSteps(refiner_steps));
- }
-
- if (isParameterCFGScale(refiner_cfg_scale)) {
- dispatch(setRefinerCFGScale(refiner_cfg_scale));
- }
-
- if (isParameterScheduler(refiner_scheduler)) {
- dispatch(setRefinerScheduler(refiner_scheduler));
- }
-
- if (isParameterSDXLRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)) {
- dispatch(setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score));
- }
-
- if (isParameterSDXLRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)) {
- dispatch(setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score));
- }
-
- if (isParameterSDXLRefinerStart(refiner_start)) {
- dispatch(setRefinerStart(refiner_start));
- }
+ await recallVAEAsync(metadata, false);
+ await recallRefinerModelAsync(metadata, false);
dispatch(lorasCleared());
- loras?.forEach(async (loraMetadataItem) => {
- try {
- const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base);
- dispatch(loraRecalled(lora));
- } catch {
- return;
- }
- });
+ const loraMetadataArray = getProperty(metadata, 'loras');
+ if (isArray(loraMetadataArray)) {
+ loraMetadataArray.forEach(async (loraMetadataItem) => {
+ await recallLoRAAsync(loraMetadataItem, false);
+ });
+ }
dispatch(controlAdaptersReset());
- controlnets?.forEach(async (controlNetMetadataItem) => {
- try {
- const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base);
- dispatch(controlAdapterRecalled(controlNet));
- } catch {
- return;
- }
- });
+ const controlNetMetadataArray = getProperty(metadata, 'controlnets');
+ if (isArray(controlNetMetadataArray)) {
+ controlNetMetadataArray.forEach(async (controlNetMetadataItem) => {
+ await recallControlNetAsync(controlNetMetadataItem, false);
+ });
+ }
- ipAdapters?.forEach(async (ipAdapterMetadataItem) => {
- try {
- const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base);
- dispatch(controlAdapterRecalled(ipAdapter));
- } catch {
- return;
- }
- });
+ const ipAdapterMetadataArray = getProperty(metadata, 'ipAdapters');
+ if (isArray(ipAdapterMetadataArray)) {
+ ipAdapterMetadataArray.forEach(async (ipAdapterMetadataItem) => {
+ await recallIPAdapterAsync(ipAdapterMetadataItem, false);
+ });
+ }
- t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => {
- try {
- const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base);
- dispatch(controlAdapterRecalled(t2iAdapter));
- } catch {
- return;
- }
- });
+ const t2iAdapterMetadataArray = getProperty(metadata, 't2iAdapters');
+ if (isArray(t2iAdapterMetadataArray)) {
+ t2iAdapterMetadataArray.forEach(async (t2iAdapterMetadataItem) => {
+ await recallT2IAdapterAsync(t2iAdapterMetadataItem, false);
+ });
+ }
allParameterSetToast();
},
- [dispatch, allParameterSetToast, allParameterNotSetToast]
+ [dispatch]
);
return {
@@ -649,24 +834,25 @@ export const useRecallParameters = () => {
recallSDXLPositiveStylePrompt,
recallSDXLNegativeStylePrompt,
recallSeed,
- recallCfgScale,
- recallCfgRescaleMultiplier,
- recallModel,
+ recallCFGScale,
+ recallCFGRescaleMultiplier,
+ recallModel: recallModelAsync,
recallScheduler,
- recallVaeModel,
+ recallVaeModel: recallVAEAsync,
recallSteps,
recallWidth,
recallHeight,
recallWidthAndHeight,
recallStrength,
- recallHrfEnabled,
- recallHrfStrength,
- recallHrfMethod,
- recallLoRA,
- recallControlNet,
- recallIPAdapter,
- recallT2IAdapter,
+ recallHRFEnabled,
+ recallHRFStrength,
+ recallHRFMethod,
+ recallLoRA: recallLoRAAsync,
+ recallControlNet: recallControlNetAsync,
+ recallIPAdapter: recallIPAdapterAsync,
+ recallT2IAdapter: recallT2IAdapterAsync,
recallAllParameters,
+ recallRefinerModel: recallRefinerModelAsync,
sendToImageToImage,
};
};
diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
index 20953ec266..8d46add8c8 100644
--- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
+++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
@@ -1,14 +1,7 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import {
- zBaseModel,
- zControlNetModelField,
- zIPAdapterModelField,
- zLoRAModelField,
zModelIdentifierWithBase,
zSchedulerField,
- zSDXLRefinerModelField,
- zT2IAdapterModelField,
- zVAEModelField,
} from 'features/nodes/types/common';
import { z } from 'zod';
@@ -111,42 +104,42 @@ export const isParameterModel = (val: unknown): val is ParameterModel => zParame
// #endregion
// #region SDXL Refiner Model
-export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel });
+export const zParameterSDXLRefinerModel = zModelIdentifierWithBase;
export type ParameterSDXLRefinerModel = z.infer;
export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel =>
zParameterSDXLRefinerModel.safeParse(val).success;
// #endregion
// #region VAE Model
-export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel });
+export const zParameterVAEModel = zModelIdentifierWithBase;
export type ParameterVAEModel = z.infer;
export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel =>
zParameterVAEModel.safeParse(val).success;
// #endregion
// #region LoRA Model
-export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel });
+export const zParameterLoRAModel = zModelIdentifierWithBase;
export type ParameterLoRAModel = z.infer;
export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel =>
zParameterLoRAModel.safeParse(val).success;
// #endregion
// #region ControlNet Model
-export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel });
+export const zParameterControlNetModel = zModelIdentifierWithBase;
export type ParameterControlNetModel = z.infer;
export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel =>
zParameterControlNetModel.safeParse(val).success;
// #endregion
// #region IP Adapter Model
-export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel });
+export const zParameterIPAdapterModel = zModelIdentifierWithBase;
export type ParameterIPAdapterModel = z.infer;
export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel =>
zParameterIPAdapterModel.safeParse(val).success;
// #endregion
// #region T2I Adapter Model
-export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel });
+export const zParameterT2IAdapterModel = zModelIdentifierWithBase;
export type ParameterT2IAdapterModel = z.infer;
export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel =>
zParameterT2IAdapterModel.safeParse(val).success;
@@ -218,3 +211,9 @@ export type ParameterCanvasCoherenceMode = z.infer
zParameterCanvasCoherenceMode.safeParse(val).success;
// #endregion
+
+// #region LoRA weight
+export const zLoRAWeight = z.number();
+export type ParameterLoRAWeight = z.infer;
+export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
+// #endregion
diff --git a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts
index c7d25fed8b..bd3e42a314 100644
--- a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts
+++ b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts
@@ -1,7 +1,8 @@
import { getStore } from 'app/store/nanostores/store';
-import { isModelIdentifier } from 'features/nodes/types/common';
+import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
+import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
-import type { AnyModelConfig, BaseModelType } from 'services/api/types';
+import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
@@ -41,6 +42,12 @@ export class InvalidModelConfigError extends Error {
}
}
+/**
+ * Fetches the model config for a given model key.
+ * @param key The model key.
+ * @returns A promise that resolves to the model config.
+ * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
+ */
export const fetchModelConfig = async (key: string): Promise => {
const { dispatch } = getStore();
try {
@@ -52,6 +59,37 @@ export const fetchModelConfig = async (key: string): Promise =>
}
};
+/**
+ * Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
+ * for MM1 model identifiers.
+ * @param name The model name.
+ * @param base The base model.
+ * @param type The model type.
+ * @returns A promise that resolves to the model config.
+ * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
+ */
+export const fetchModelConfigByAttrs = async (
+ name: string,
+ base: BaseModelType,
+ type: ModelType
+): Promise => {
+ const { dispatch } = getStore();
+ try {
+ const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
+ req.unsubscribe();
+ return await req.unwrap();
+ } catch {
+ throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
+ }
+};
+
+/**
+ * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
+ * @param key The model key.
+ * @param typeGuard A type guard function that checks if the model config is of the expected type.
+ * @returns A promise that resolves to the model config. The model config is guaranteed to be of the expected type.
+ * @throws {InvalidModelConfigError} If the model config is unable to be fetched or is of an unexpected type.
+ */
export const fetchModelConfigWithTypeGuard = async (
key: string,
typeGuard: (config: AnyModelConfig) => config is T
@@ -63,15 +101,17 @@ export const fetchModelConfigWithTypeGuard = async (
return modelConfig;
};
-export const fetchMainModel = async (key: string) => {
+// TODO(psyche): Remove these helpers once `useRecallParameters` is removed
+
+export const fetchMainModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
};
-export const fetchRefinerModel = async (key: string) => {
+export const fetchRefinerModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
};
-export const fetchVAEModel = async (key: string) => {
+export const fetchVAEModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
};
@@ -95,19 +135,39 @@ export const fetchTextualInversionModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
};
-export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => {
- return sourceBase === targetBase;
-};
-
+/**
+ * Raises an error if the source base model is incompatible with the target base model.
+ * @param sourceBase The source base model.
+ * @param targetBase The target base model.
+ * @param message An optional custom message to include in the error.
+ * @throws {InvalidModelConfigError} If the source base model is incompatible with the target base model.
+ */
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
- if (targetBase && !isBaseCompatible(sourceBase, targetBase)) {
+ if (targetBase && sourceBase !== targetBase) {
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
}
};
-export const getModelKey = (modelIdentifier: unknown, message?: string): string => {
- if (!isModelIdentifier(modelIdentifier)) {
- throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
+/**
+ * Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
+ * @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
+ * `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
+ * @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
+ * @param message An optional custom message to include in the error if the model identifier is invalid.
+ * @returns A promise that resolves to the model key.
+ * @throws {InvalidModelConfigError} If the model identifier is invalid.
+ */
+export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise => {
+ if (isModelIdentifier(modelIdentifier)) {
+ return modelIdentifier.key;
}
- return modelIdentifier.key;
+ if (isModelIdentifierV2(modelIdentifier)) {
+ return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
+ }
+ throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};
+
+export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
+ key: modelConfig.key,
+ base: modelConfig.base,
+});
diff --git a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts
deleted file mode 100644
index 722073366f..0000000000
--- a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts
+++ /dev/null
@@ -1,150 +0,0 @@
-import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
-import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
-import {
- initialControlNet,
- initialIPAdapter,
- initialT2IAdapter,
-} from 'features/controlAdapters/util/buildControlAdapter';
-import type { LoRA } from 'features/lora/store/loraSlice';
-import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
-import { zModelIdentifierWithBase } from 'features/nodes/types/common';
-import type {
- ControlNetMetadataItem,
- IPAdapterMetadataItem,
- LoRAMetadataItem,
- T2IAdapterMetadataItem,
-} from 'features/nodes/types/metadata';
-import {
- fetchControlNetModel,
- fetchIPAdapterModel,
- fetchLoRAModel,
- fetchMainModel,
- fetchRefinerModel,
- fetchT2IAdapterModel,
- fetchVAEModel,
- getModelKey,
- raiseIfBaseIncompatible,
-} from 'features/parameters/util/modelFetchingHelpers';
-import type { BaseModelType } from 'services/api/types';
-import { v4 as uuidv4 } from 'uuid';
-
-export const prepareMainModelMetadataItem = async (model: unknown): Promise => {
- const key = getModelKey(model);
- const mainModel = await fetchMainModel(key);
- return zModelIdentifierWithBase.parse(mainModel);
-};
-
-export const prepareRefinerMetadataItem = async (model: unknown): Promise => {
- const key = getModelKey(model);
- const refinerModel = await fetchRefinerModel(key);
- return zModelIdentifierWithBase.parse(refinerModel);
-};
-
-export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise => {
- const key = getModelKey(vae);
- const vaeModel = await fetchVAEModel(key);
- raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
- return zModelIdentifierWithBase.parse(vaeModel);
-};
-
-export const prepareLoRAMetadataItem = async (
- loraMetadataItem: LoRAMetadataItem,
- base?: BaseModelType
-): Promise => {
- const key = getModelKey(loraMetadataItem.lora);
- const loraModel = await fetchLoRAModel(key);
- raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
- return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
-};
-
-export const prepareControlNetMetadataItem = async (
- controlnetMetadataItem: ControlNetMetadataItem,
- base?: BaseModelType
-): Promise => {
- const key = getModelKey(controlnetMetadataItem.control_model);
- const controlNetModel = await fetchControlNetModel(key);
- raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
-
- const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
- controlnetMetadataItem;
-
- // We don't save the original image that was processed into a control image, only the processed image
- const processorType = 'none';
- const processorNode = CONTROLNET_PROCESSORS.none.default;
-
- const controlnet: ControlNetConfig = {
- type: 'controlnet',
- isEnabled: true,
- model: zModelIdentifierWithBase.parse(controlNetModel),
- weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
- beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
- endStepPct: end_step_percent || initialControlNet.endStepPct,
- controlMode: control_mode || initialControlNet.controlMode,
- resizeMode: resize_mode || initialControlNet.resizeMode,
- controlImage: image?.image_name || null,
- processedControlImage: image?.image_name || null,
- processorType,
- processorNode,
- shouldAutoConfig: true,
- id: uuidv4(),
- };
-
- return controlnet;
-};
-
-export const prepareT2IAdapterMetadataItem = async (
- t2iAdapterMetadataItem: T2IAdapterMetadataItem,
- base?: BaseModelType
-): Promise => {
- const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
- const t2iAdapterModel = await fetchT2IAdapterModel(key);
- raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
-
- const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
-
- // We don't save the original image that was processed into a control image, only the processed image
- const processorType = 'none';
- const processorNode = CONTROLNET_PROCESSORS.none.default;
-
- const t2iAdapter: T2IAdapterConfig = {
- type: 't2i_adapter',
- isEnabled: true,
- model: zModelIdentifierWithBase.parse(t2iAdapterModel),
- weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
- beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
- endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
- resizeMode: resize_mode || initialT2IAdapter.resizeMode,
- controlImage: image?.image_name || null,
- processedControlImage: image?.image_name || null,
- processorType,
- processorNode,
- shouldAutoConfig: true,
- id: uuidv4(),
- };
-
- return t2iAdapter;
-};
-
-export const prepareIPAdapterMetadataItem = async (
- ipAdapterMetadataItem: IPAdapterMetadataItem,
- base?: BaseModelType
-): Promise => {
- const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
- const ipAdapterModel = await fetchIPAdapterModel(key);
- raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
-
- const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
-
- const ipAdapter: IPAdapterConfig = {
- id: uuidv4(),
- type: 'ip_adapter',
- isEnabled: true,
- controlImage: image?.image_name ?? null,
- model: zModelIdentifierWithBase.parse(ipAdapterModel),
- weight: weight ?? initialIPAdapter.weight,
- beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
- endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
- };
-
- return ipAdapter;
-};
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx
index 8b10d9bddd..05ca73927c 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx
@@ -25,7 +25,7 @@ const formLabelProps2: FormLabelProps = {
export const AdvancedSettingsAccordion = memo(() => {
const vaeKey = useAppSelector((state) => state.generation.vae?.key);
- const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
+ const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectGenerationSlice, (generation) => {
diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts
index 49eb28390f..1849e2fde4 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/images.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts
@@ -1,10 +1,8 @@
import type { EntityState, Update } from '@reduxjs/toolkit';
import type { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
-import { logger } from 'app/logging/logger';
+import type { JSONObject } from 'common/types';
import type { BoardId } from 'features/gallery/store/types';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types';
-import type { CoreMetadata } from 'features/nodes/types/metadata';
-import { zCoreMetadata } from 'features/nodes/types/metadata';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { keyBy } from 'lodash-es';
@@ -118,22 +116,9 @@ export const imagesApi = api.injectEndpoints({
providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),
- getImageMetadata: build.query({
+ getImageMetadata: build.query({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }),
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
- transformResponse: (
- response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
- ) => {
- if (response) {
- const result = zCoreMetadata.safeParse(response);
- if (result.success) {
- return result.data;
- } else {
- logger('images').warn('Problem parsing metadata');
- }
- }
- return;
- },
keepUnusedDataFor: 86400, // 24 hours
}),
getImageWorkflow: build.query<
diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts
index 6618fda58e..3ccccf62e1 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/models.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts
@@ -70,6 +70,8 @@ export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
+export type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
+
export const mainModelsAdapter = createEntityAdapter({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
@@ -223,6 +225,18 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
}),
+ getModelConfigByAttrs: build.query({
+ query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`),
+ providesTags: (result) => {
+ const tags: ApiTagDescription[] = ['Model'];
+
+ if (result) {
+ tags.push({ type: 'ModelConfig', id: result.key });
+ }
+
+ return tags;
+ },
+ }),
syncModels: build.mutation({
query: () => {
return {
@@ -300,6 +314,7 @@ export const modelsApi = api.injectEndpoints({
});
export const {
+ useGetModelConfigByAttrsQuery,
useGetModelConfigQuery,
useGetMainModelsQuery,
useGetControlNetModelsQuery,