feat(ui): refactor metadata handling (again)

Add concepts for metadata handlers. Handlers include parsers, recallers and validators for different metadata types:
- Parsers parse a raw metadata object of any shape to a structured object.
- Recallers load the parsed metadata into state. Recallers are optional, as some metadata types don't need to be loaded into state.
- Validators provide an additional layer of validation before recalling the metadata. This is needed because a metadata object may be valid, but not able to be recalled due to some other requirement, like base model compatibility. Validators are optional.

Sometimes metadata is not a single object but a list of items - like LoRAs. Metadata handlers may implement an optional set of "item" handlers which operate on individual items in the list.

Parsers and validators are async to allow fetching additional data, like a model config. Recallers are synchronous.

The these handlers are composed into a public API, exported as a `handlers` object. Besides the handlers functions, a metadata handler set includes:
- A function to get the label of the metadata type.
- An optional function to render the value of the metadata type.
- An optional function to render the _item_ value of the metadata type.
This commit is contained in:
psychedelicious 2024-02-24 01:25:58 +11:00 committed by Kent Keirsey
parent 0f10faf0d4
commit e174ce038f
31 changed files with 2627 additions and 1092 deletions

View File

@ -656,6 +656,7 @@
} }
}, },
"metadata": { "metadata": {
"allPrompts": "All Prompts",
"cfgScale": "CFG scale", "cfgScale": "CFG scale",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)", "cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"createdBy": "Created By", "createdBy": "Created By",
@ -664,6 +665,7 @@
"height": "Height", "height": "Height",
"hiresFix": "High Resolution Optimization", "hiresFix": "High Resolution Optimization",
"imageDetails": "Image Details", "imageDetails": "Image Details",
"imageDimensions": "Image Dimensions",
"initImage": "Initial image", "initImage": "Initial image",
"metadata": "Metadata", "metadata": "Metadata",
"model": "Model", "model": "Model",
@ -671,9 +673,11 @@
"noImageDetails": "No image details found", "noImageDetails": "No image details found",
"noMetaData": "No metadata found", "noMetaData": "No metadata found",
"noRecallParameters": "No parameters to recall found", "noRecallParameters": "No parameters to recall found",
"parameterSet": "Parameter {{parameter}} set",
"perlin": "Perlin Noise", "perlin": "Perlin Noise",
"positivePrompt": "Positive Prompt", "positivePrompt": "Positive Prompt",
"recallParameters": "Recall Parameters", "recallParameters": "Recall Parameters",
"recallParameter": "Recall {{label}}",
"scheduler": "Scheduler", "scheduler": "Scheduler",
"seamless": "Seamless", "seamless": "Seamless",
"seed": "Seed", "seed": "Seed",
@ -1381,8 +1385,8 @@
"nodesNotValidJSON": "Not a valid JSON", "nodesNotValidJSON": "Not a valid JSON",
"nodesSaved": "Nodes Saved", "nodesSaved": "Nodes Saved",
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types", "nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
"parameterNotSet": "Parameter not set", "parameterNotSet": "{{parameter}} not set",
"parameterSet": "Parameter set", "parameterSet": "{{parameter}} set",
"parametersFailed": "Problem loading parameters", "parametersFailed": "Problem loading parameters",
"parametersFailedDesc": "Unable to load init image.", "parametersFailedDesc": "Unable to load init image.",
"parametersNotSet": "Parameters Not Set", "parametersNotSet": "Parameters Not Set",

View File

@ -1,6 +1,6 @@
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { toast } from 'common/util/toast';
import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { t } from 'i18next'; import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es'; import { truncate, upperFirst } from 'lodash-es';
@ -8,11 +8,6 @@ import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..'; import { startAppListening } from '..';
const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});
export const addBatchEnqueuedListener = () => { export const addBatchEnqueuedListener = () => {
// success // success
startAppListening({ startAppListening({

View File

@ -1,7 +1,8 @@
import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { UseToastOptions } from '@invoke-ai/ui-library';
import { createStandaloneToast, ExternalLink, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { startAppListening } from 'app/store/middleware/listenerMiddleware'; import { startAppListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'common/util/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { import {
@ -12,11 +13,6 @@ import {
const log = logger('images'); const log = logger('images');
const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});
export const addBulkDownloadListeners = () => { export const addBulkDownloadListeners = () => {
startAppListening({ startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled, matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,

View File

@ -0,0 +1,6 @@
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
export const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});

View File

@ -1,300 +1,54 @@
import { isModelIdentifier } from 'features/nodes/types/common'; import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets';
import type { import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters';
ControlNetMetadataItem, import { MetadataItem } from 'features/metadata/components/MetadataItem';
CoreMetadata, import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs';
IPAdapterMetadataItem, import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters';
LoRAMetadataItem, import { handlers } from 'features/metadata/util/handlers';
T2IAdapterMetadataItem, import { memo } from 'react';
} from 'features/nodes/types/metadata';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
type Props = { type Props = {
metadata?: CoreMetadata; metadata?: unknown;
}; };
const ImageMetadataActions = (props: Props) => { const ImageMetadataActions = (props: Props) => {
const { metadata } = props; const { metadata } = props;
const { t } = useTranslation();
const {
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallCfgScale,
recallCfgRescaleMultiplier,
recallModel,
recallScheduler,
recallVaeModel,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallHrfEnabled,
recallHrfStrength,
recallHrfMethod,
recallLoRA,
recallControlNet,
recallIPAdapter,
recallT2IAdapter,
recallSDXLPositiveStylePrompt,
recallSDXLNegativeStylePrompt,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
recallPositivePrompt(metadata?.positive_prompt);
}, [metadata?.positive_prompt, recallPositivePrompt]);
const handleRecallNegativePrompt = useCallback(() => {
recallNegativePrompt(metadata?.negative_prompt);
}, [metadata?.negative_prompt, recallNegativePrompt]);
const handleRecallSDXLPositiveStylePrompt = useCallback(() => {
recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt);
}, [metadata?.positive_style_prompt, recallSDXLPositiveStylePrompt]);
const handleRecallSDXLNegativeStylePrompt = useCallback(() => {
recallSDXLNegativeStylePrompt(metadata?.negative__style_prompt);
}, [metadata?.negative__style_prompt, recallSDXLNegativeStylePrompt]);
const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleRecallModel = useCallback(() => {
recallModel(metadata?.model);
}, [metadata?.model, recallModel]);
const handleRecallWidth = useCallback(() => {
recallWidth(metadata?.width);
}, [metadata?.width, recallWidth]);
const handleRecallHeight = useCallback(() => {
recallHeight(metadata?.height);
}, [metadata?.height, recallHeight]);
const handleRecallScheduler = useCallback(() => {
recallScheduler(metadata?.scheduler);
}, [metadata?.scheduler, recallScheduler]);
const handleRecallVaeModel = useCallback(() => {
recallVaeModel(metadata?.vae);
}, [metadata?.vae, recallVaeModel]);
const handleRecallSteps = useCallback(() => {
recallSteps(metadata?.steps);
}, [metadata?.steps, recallSteps]);
const handleRecallCfgScale = useCallback(() => {
recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]);
const handleRecallCfgRescaleMultiplier = useCallback(() => {
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
const handleRecallHrfEnabled = useCallback(() => {
recallHrfEnabled(metadata?.hrf_enabled);
}, [metadata?.hrf_enabled, recallHrfEnabled]);
const handleRecallHrfStrength = useCallback(() => {
recallHrfStrength(metadata?.hrf_strength);
}, [metadata?.hrf_strength, recallHrfStrength]);
const handleRecallHrfMethod = useCallback(() => {
recallHrfMethod(metadata?.hrf_method);
}, [metadata?.hrf_method, recallHrfMethod]);
const handleRecallLoRA = useCallback(
(lora: LoRAMetadataItem) => {
recallLoRA(lora);
},
[recallLoRA]
);
const handleRecallControlNet = useCallback(
(controlnet: ControlNetMetadataItem) => {
recallControlNet(controlnet);
},
[recallControlNet]
);
const handleRecallIPAdapter = useCallback(
(ipAdapter: IPAdapterMetadataItem) => {
recallIPAdapter(ipAdapter);
},
[recallIPAdapter]
);
const handleRecallT2IAdapter = useCallback(
(ipAdapter: T2IAdapterMetadataItem) => {
recallT2IAdapter(ipAdapter);
},
[recallT2IAdapter]
);
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
: [];
}, [metadata?.controlnets]);
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters
? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
: [];
}, [metadata?.ipAdapters]);
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters
? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
: [];
}, [metadata?.t2iAdapters]);
if (!metadata || Object.keys(metadata).length === 0) { if (!metadata || Object.keys(metadata).length === 0) {
return null; return null;
} }
return ( return (
<> <>
{metadata.created_by && <ImageMetadataItem label={t('metadata.createdBy')} value={metadata.created_by} />} <MetadataItem metadata={metadata} handlers={handlers.createdBy} />
{metadata.generation_mode && ( <MetadataItem metadata={metadata} handlers={handlers.generationMode} />
<ImageMetadataItem label={t('metadata.generationMode')} value={metadata.generation_mode} /> <MetadataItem metadata={metadata} handlers={handlers.positivePrompt} direction="column" />
)} <MetadataItem metadata={metadata} handlers={handlers.negativePrompt} direction="column" />
{metadata.positive_prompt && ( <MetadataItem metadata={metadata} handlers={handlers.sdxlPositiveStylePrompt} direction="column" />
<ImageMetadataItem <MetadataItem metadata={metadata} handlers={handlers.sdxlNegativeStylePrompt} direction="column" />
label={t('metadata.positivePrompt')} <MetadataItem metadata={metadata} handlers={handlers.model} />
labelPosition="top" <MetadataItem metadata={metadata} handlers={handlers.vae} />
value={metadata.positive_prompt} <MetadataItem metadata={metadata} handlers={handlers.width} />
onClick={handleRecallPositivePrompt} <MetadataItem metadata={metadata} handlers={handlers.height} />
/> <MetadataItem metadata={metadata} handlers={handlers.seed} />
)} <MetadataItem metadata={metadata} handlers={handlers.steps} />
{metadata.negative_prompt && ( <MetadataItem metadata={metadata} handlers={handlers.scheduler} />
<ImageMetadataItem <MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
label={t('metadata.negativePrompt')} <MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
labelPosition="top" <MetadataItem metadata={metadata} handlers={handlers.strength} />
value={metadata.negative_prompt} <MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
onClick={handleRecallNegativePrompt} <MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />
/> <MetadataItem metadata={metadata} handlers={handlers.hrfStrength} />
)} <MetadataItem metadata={metadata} handlers={handlers.refinerCFGScale} />
{metadata.positive_style_prompt && ( <MetadataItem metadata={metadata} handlers={handlers.refinerModel} />
<ImageMetadataItem <MetadataItem metadata={metadata} handlers={handlers.refinerNegativeAestheticScore} />
label={t('sdxl.posStylePrompt')} <MetadataItem metadata={metadata} handlers={handlers.refinerPositiveAestheticScore} />
labelPosition="top" <MetadataItem metadata={metadata} handlers={handlers.refinerScheduler} />
value={metadata.positive_style_prompt} <MetadataItem metadata={metadata} handlers={handlers.refinerStart} />
onClick={handleRecallSDXLPositiveStylePrompt} <MetadataItem metadata={metadata} handlers={handlers.refinerSteps} />
/> <MetadataLoRAs metadata={metadata} />
)} <MetadataControlNets metadata={metadata} />
{metadata.negative_style_prompt && ( <MetadataT2IAdapters metadata={metadata} />
<ImageMetadataItem <MetadataIPAdapters metadata={metadata} />
label={t('sdxl.negStylePrompt')}
labelPosition="top"
value={metadata.negative_style_prompt}
onClick={handleRecallSDXLNegativeStylePrompt}
/>
)}
{metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
)}
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
<ModelMetadataItem label={t('metadata.model')} modelKey={metadata.model.key} onClick={handleRecallModel} />
)}
{metadata.width && (
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
)}
{metadata.height && (
<ImageMetadataItem label={t('metadata.height')} value={metadata.height} onClick={handleRecallHeight} />
)}
{metadata.scheduler && (
<ImageMetadataItem label={t('metadata.scheduler')} value={metadata.scheduler} onClick={handleRecallScheduler} />
)}
<VAEMetadataItem label={t('metadata.vae')} modelKey={metadata.vae?.key} onClick={handleRecallVaeModel} />
{metadata.steps && (
<ImageMetadataItem label={t('metadata.steps')} value={metadata.steps} onClick={handleRecallSteps} />
)}
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
<ImageMetadataItem label={t('metadata.cfgScale')} value={metadata.cfg_scale} onClick={handleRecallCfgScale} />
)}
{metadata.cfg_rescale_multiplier !== undefined && metadata.cfg_rescale_multiplier !== null && (
<ImageMetadataItem
label={t('metadata.cfgRescaleMultiplier')}
value={metadata.cfg_rescale_multiplier}
onClick={handleRecallCfgRescaleMultiplier}
/>
)}
{metadata.strength && (
<ImageMetadataItem label={t('metadata.strength')} value={metadata.strength} onClick={handleRecallStrength} />
)}
{metadata.hrf_enabled && (
<ImageMetadataItem
label={t('hrf.metadata.enabled')}
value={metadata.hrf_enabled}
onClick={handleRecallHrfEnabled}
/>
)}
{metadata.hrf_enabled && metadata.hrf_strength && (
<ImageMetadataItem
label={t('hrf.metadata.strength')}
value={metadata.hrf_strength}
onClick={handleRecallHrfStrength}
/>
)}
{metadata.hrf_enabled && metadata.hrf_method && (
<ImageMetadataItem
label={t('hrf.metadata.method')}
value={metadata.hrf_method}
onClick={handleRecallHrfMethod}
/>
)}
{metadata.loras &&
metadata.loras.map((lora, index) => {
if (isModelIdentifier(lora.lora)) {
return (
<ModelMetadataItem
key={index}
label="LoRA"
modelKey={lora.lora.key}
extra={` - ${lora.weight}`}
onClick={handleRecallLoRA.bind(null, lora)}
/>
);
}
})}
{validControlNets.map((controlnet, index) => (
<ModelMetadataItem
key={index}
label="ControlNet"
modelKey={controlnet.control_model?.key}
extra={` - ${controlnet.control_weight}`}
onClick={handleRecallControlNet.bind(null, controlnet)}
/>
))}
{validIPAdapters.map((ipAdapter, index) => (
<ModelMetadataItem
key={index}
label="IP Adapter"
modelKey={ipAdapter.ip_adapter_model?.key}
extra={` - ${ipAdapter.weight}`}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
/>
))}
{validT2IAdapters.map((t2iAdapter, index) => (
<ModelMetadataItem
key={index}
label="T2I Adapter"
modelKey={t2iAdapter.t2i_adapter_model?.key}
extra={` - ${t2iAdapter.weight}`}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
/>
))}
</> </>
); );
}; };

View File

@ -1,28 +1,39 @@
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { get } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { PiArrowBendUpLeftBold } from 'react-icons/pi';
import { PiCopyBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type MetadataItemProps = { type MetadataItemProps = {
isLink?: boolean; isLink?: boolean;
label: string; label: string;
onClick?: () => void; metadata: unknown;
value: number | string | boolean; propertyName: string;
onRecall?: (value: unknown) => void;
labelPosition?: string; labelPosition?: string;
withCopy?: boolean;
}; };
/** /**
* Component to display an individual metadata item or parameter. * Component to display an individual metadata item or parameter.
*/ */
const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => { const ImageMetadataItem = ({
label,
metadata,
propertyName,
onRecall: _onRecall,
isLink,
labelPosition,
}: MetadataItemProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const handleCopy = useCallback(() => { const value = useMemo(() => get(metadata, propertyName), [metadata, propertyName]);
navigator.clipboard.writeText(value?.toString()); const onRecall = useCallback(() => {
}, [value]); if (!_onRecall) {
return;
}
_onRecall(value);
}, [_onRecall, value]);
if (!value) { if (!value) {
return null; return null;
@ -30,27 +41,15 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
return ( return (
<Flex gap={2}> <Flex gap={2}>
{onClick && ( {_onRecall && (
<Tooltip label={`Recall ${label}`}> <Tooltip label={t('metadata.recallParameter', { parameter: label })}>
<IconButton <IconButton
aria-label={t('accessibility.useThisParameter')} aria-label={t('metadata.recallParameter', { parameter: label })}
icon={<IoArrowUndoCircleOutline />} icon={<PiArrowBendUpLeftBold />}
size="xs" size="xs"
variant="ghost" variant="ghost"
fontSize={20} fontSize={20}
onClick={onClick} onClick={onRecall}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<PiCopyBold />}
size="xs"
variant="ghost"
fontSize={14}
onClick={handleCopy}
/> />
</Tooltip> </Tooltip>
)} )}

View File

@ -23,29 +23,29 @@ type LoRACardProps = {
export const LoRACard = memo((props: LoRACardProps) => { export const LoRACard = memo((props: LoRACardProps) => {
const { lora } = props; const { lora } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: loraConfig } = useGetModelConfigQuery(lora.key); const { data: loraConfig } = useGetModelConfigQuery(lora.model.key);
const handleChange = useCallback( const handleChange = useCallback(
(v: number) => { (v: number) => {
dispatch(loraWeightChanged({ key: lora.key, weight: v })); dispatch(loraWeightChanged({ key: lora.model.key, weight: v }));
}, },
[dispatch, lora.key] [dispatch, lora.model.key]
); );
const handleSetLoraToggle = useCallback(() => { const handleSetLoraToggle = useCallback(() => {
dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled })); dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.key, lora.isEnabled]); }, [dispatch, lora.model.key, lora.isEnabled]);
const handleRemoveLora = useCallback(() => { const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.key)); dispatch(loraRemoved(lora.model.key));
}, [dispatch, lora.key]); }, [dispatch, lora.model.key]);
return ( return (
<Card variant="lora"> <Card variant="lora">
<CardHeader> <CardHeader>
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}> <Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}> <Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{loraConfig?.name ?? lora.key.substring(0, 8)} {loraConfig?.name ?? lora.model.key.substring(0, 8)}
</Text> </Text>
<Flex alignItems="center" gap={2}> <Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} /> <Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />

View File

@ -2,9 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = ParameterLoRAModel & { export type LoRA = {
model: ParameterLoRAModel;
weight: number; weight: number;
isEnabled?: boolean; isEnabled?: boolean;
}; };
@ -29,11 +31,11 @@ export const loraSlice = createSlice({
initialState: initialLoraState, initialState: initialLoraState,
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const { key, base } = action.payload; const model = getModelKeyAndBase(action.payload);
state.loras[key] = { key, base, ...defaultLoRAConfig }; state.loras[model.key] = { ...defaultLoRAConfig, model };
}, },
loraRecalled: (state, action: PayloadAction<LoRA>) => { loraRecalled: (state, action: PayloadAction<LoRA>) => {
state.loras[action.payload.key] = action.payload; state.loras[action.payload.model.key] = action.payload;
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const key = action.payload; const key = action.payload;
@ -58,7 +60,7 @@ export const loraSlice = createSlice({
} }
lora.weight = defaultLoRAConfig.weight; lora.weight = defaultLoRAConfig.weight;
}, },
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'key' | 'isEnabled'>>) => { loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => {
const { key, isEnabled } = action.payload; const { key, isEnabled } = action.payload;
const lora = state.loras[key]; const lora = state.loras[key];
if (!lora) { if (!lora) {

View File

@ -0,0 +1,66 @@
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataControlNets = ({ metadata }: Props) => {
const [controlNets, setControlNets] = useState<ControlNetConfig[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.controlNets.parse(metadata);
setControlNets(parsed);
} catch (e) {
setControlNets([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.controlNets.getLabel(), []);
return (
<>
{controlNets.map((controlNet) => (
<MetadataViewControlNet
key={controlNet.model.key}
label={label}
controlNet={controlNet}
handlers={handlers.controlNets}
/>
))}
</>
);
};
const MetadataViewControlNet = ({
label,
controlNet,
handlers,
}: {
label: string;
controlNet: ControlNetConfig;
handlers: MetadataHandlers<ControlNetConfig[], ControlNetConfig>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(controlNet, true);
}, [handlers, controlNet]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return handlers.renderItemValue(controlNet);
}, [handlers, controlNet]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -0,0 +1,66 @@
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataIPAdapters = ({ metadata }: Props) => {
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfig[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.ipAdapters.parse(metadata);
setIPAdapters(parsed);
} catch (e) {
setIPAdapters([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.ipAdapters.getLabel(), []);
return (
<>
{ipAdapters.map((ipAdapter) => (
<MetadataViewIPAdapter
key={ipAdapter.model.key}
label={label}
ipAdapter={ipAdapter}
handlers={handlers.ipAdapters}
/>
))}
</>
);
};
const MetadataViewIPAdapter = ({
label,
ipAdapter,
handlers,
}: {
label: string;
ipAdapter: IPAdapterConfig;
handlers: MetadataHandlers<IPAdapterConfig[], IPAdapterConfig>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(ipAdapter, true);
}, [handlers, ipAdapter]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return handlers.renderItemValue(ipAdapter);
}, [handlers, ipAdapter]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -0,0 +1,33 @@
import { typedMemo } from '@invoke-ai/ui-library';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
import type { MetadataHandlers } from 'features/metadata/types';
import { MetadataParseFailedToken } from 'features/metadata/util/parsers';
type MetadataItemProps<T> = {
metadata: unknown;
handlers: MetadataHandlers<T>;
direction?: 'row' | 'column';
};
const _MetadataItem = typedMemo(<T,>({ metadata, handlers, direction = 'row' }: MetadataItemProps<T>) => {
const { label, isDisabled, value, renderedValue, onRecall } = useMetadataItem(metadata, handlers);
if (value === MetadataParseFailedToken) {
return null;
}
return (
<MetadataItemView
label={label}
onRecall={onRecall}
isDisabled={isDisabled}
renderedValue={renderedValue}
direction={direction}
/>
);
});
export const MetadataItem = typedMemo(_MetadataItem);
MetadataItem.displayName = 'MetadataItem';

View File

@ -0,0 +1,29 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { RecallButton } from 'features/metadata/components/RecallButton';
import { memo } from 'react';
type MetadataItemViewProps = {
onRecall: () => void;
label: string;
renderedValue: React.ReactNode;
isDisabled: boolean;
direction?: 'row' | 'column';
};
export const MetadataItemView = memo(
({ label, onRecall, isDisabled, renderedValue, direction = 'row' }: MetadataItemViewProps) => {
return (
<Flex gap={2}>
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
<Flex direction={direction}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{renderedValue}
</Flex>
</Flex>
);
}
);
MetadataItemView.displayName = 'MetadataItemView';

View File

@ -0,0 +1,61 @@
import type { LoRA } from 'features/lora/store/loraSlice';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataLoRAs = ({ metadata }: Props) => {
const [loras, setLoRAs] = useState<LoRA[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.loras.parse(metadata);
setLoRAs(parsed);
} catch (e) {
setLoRAs([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.loras.getLabel(), []);
return (
<>
{loras.map((lora) => (
<MetadataViewLoRA key={lora.model.key} label={label} lora={lora} handlers={handlers.loras} />
))}
</>
);
};
const MetadataViewLoRA = ({
label,
lora,
handlers,
}: {
label: string;
lora: LoRA;
handlers: MetadataHandlers<LoRA[], LoRA>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(lora, true);
}, [handlers, lora]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return handlers.renderItemValue(lora);
}, [handlers, lora]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -0,0 +1,66 @@
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataT2IAdapters = ({ metadata }: Props) => {
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfig[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.t2iAdapters.parse(metadata);
setT2IAdapters(parsed);
} catch (e) {
setT2IAdapters([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.t2iAdapters.getLabel(), []);
return (
<>
{t2iAdapters.map((t2iAdapter) => (
<MetadataViewT2IAdapter
key={t2iAdapter.model.key}
label={label}
t2iAdapter={t2iAdapter}
handlers={handlers.t2iAdapters}
/>
))}
</>
);
};
const MetadataViewT2IAdapter = ({
label,
t2iAdapter,
handlers,
}: {
label: string;
t2iAdapter: T2IAdapterConfig;
handlers: MetadataHandlers<T2IAdapterConfig[], T2IAdapterConfig>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(t2iAdapter, true);
}, [handlers, t2iAdapter]);
const renderedValue = useMemo(() => {
if (!handlers.renderItemValue) {
return null;
}
return handlers.renderItemValue(t2iAdapter);
}, [handlers, t2iAdapter]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -0,0 +1,27 @@
import type { IconButtonProps} from '@invoke-ai/ui-library';
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowBendUpLeftBold } from 'react-icons/pi';
type MetadataItemProps = Omit<IconButtonProps, 'aria-label'> & {
label: string;
};
export const RecallButton = memo(({ label, ...rest }: MetadataItemProps) => {
const { t } = useTranslation();
return (
<Tooltip label={t('metadata.recallParameter', { label })}>
<IconButton
aria-label={t('metadata.recallParameter', { label })}
icon={<PiArrowBendUpLeftBold />}
size="xs"
variant="ghost"
{...rest}
/>
</Tooltip>
);
});
RecallButton.displayName = 'RecallButton';

View File

@ -0,0 +1,27 @@
/**
* Raised when metadata parsing fails.
*/
export class MetadataParseError extends Error {
/**
* Create MetadataParseError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Raised when metadata recall fails.
*/
export class MetadataRecallError extends Error {
/**
* Create MetadataRecallError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}

View File

@ -0,0 +1,51 @@
import { Text } from '@invoke-ai/ui-library';
import type { MetadataHandlers } from 'features/metadata/types';
import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
import { useCallback, useEffect, useMemo, useState } from 'react';
export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandlers<T>) => {
const [value, setValue] = useState<T | typeof MetadataParsePendingToken | typeof MetadataParseFailedToken>(
MetadataParsePendingToken
);
useEffect(() => {
const _parse = async () => {
try {
const parsed = await handlers.parse(metadata);
setValue(parsed);
} catch (e) {
setValue(MetadataParseFailedToken);
}
};
_parse();
}, [handlers, metadata]);
const isDisabled = useMemo(() => value === MetadataParsePendingToken || value === MetadataParseFailedToken, [value]);
const label = useMemo(() => handlers.getLabel(), [handlers]);
const renderedValue = useMemo(() => {
if (value === MetadataParsePendingToken) {
return <Text>Loading</Text>;
}
if (value === MetadataParseFailedToken) {
return <Text>Parsing Failed</Text>;
}
const rendered = handlers.renderValue(value);
if (typeof rendered === 'string') {
return <Text>{rendered}</Text>;
}
return rendered;
}, [handlers, value]);
const onRecall = useCallback(() => {
if (!handlers.recall || value === MetadataParsePendingToken || value === MetadataParseFailedToken) {
return null;
}
handlers.recall(value, true);
}, [handlers, value]);
return { label, isDisabled, value, renderedValue, onRecall };
};

View File

@ -0,0 +1,136 @@
/**
* Renders a value of type T as a React node.
*/
export type MetadataRenderValueFunc<T> = (value: T) => React.ReactNode;
/**
* Gets the label of the current metadata item as a string.
*/
export type MetadataGetLabelFunc = () => string;
export type MetadataParseOptions = {
toastOnFailure?: boolean;
toastOnSuccess?: boolean;
};
export type MetadataRecallOptions = MetadataParseOptions;
/**
* A function that recalls a parsed and validated metadata value.
*
* @param value The value to recall.
* @throws MetadataRecallError if the value cannot be recalled.
*/
export type MetadataRecallFunc<T> = (value: T) => void;
/**
* An async function that receives metadata and returns a parsed value, throwing if the value is invalid or missing.
*
* The function receives an object of unknown type. It is responsible for extracting the relevant data from the metadata
* and returning a value of type T.
*
* The function should throw a MetadataParseError if the metadata is invalid or missing.
*
* @param metadata The metadata to parse.
* @returns A promise that resolves to the parsed value.
* @throws MetadataParseError if the metadata is invalid or missing.
*/
export type MetadataParseFunc<T = unknown> = (metadata: unknown) => Promise<T>;
/**
* A function that performs additional validation logic before recalling a metadata value. It is called with a parsed
* value and should throw if the validation logic fails.
*
* This function is used in cases where some additional logic is required before recalling. For example, when recalling
* a LoRA, we need to check if it is compatible with the current base model.
*
* @param value The value to validate.
* @returns A promise that resolves to the validated value.
* @throws MetadataRecallError if the value is invalid.
*/
export type MetadataValidateFunc<T> = (value: T) => Promise<T>;
export type MetadataHandlers<TValue = unknown, TItem = unknown> = {
/**
* Gets the label of the current metadata item as a string.
*
* @returns The label of the current metadata item.
*/
getLabel: MetadataGetLabelFunc;
/**
* An async function that receives metadata and returns a parsed metadata value.
*
* @param metadata The metadata to parse.
* @param withToast Whether to show a toast on success or failure.
* @returns A promise that resolves to the parsed value.
* @throws MetadataParseError if the metadata is invalid or missing.
*/
parse: (metadata: unknown, withToast?: boolean) => Promise<TValue>;
/**
* An async function that receives a metadata item and returns a parsed metadata item value.
*
* This is only provided if the metadata value is an array.
*
* @param item The item to parse. It should be an item from the array.
* @param withToast Whether to show a toast on success or failure.
* @returns A promise that resolves to the parsed value.
* @throws MetadataParseError if the metadata is invalid or missing.
*/
parseItem?: (item: unknown, withToast?: boolean) => Promise<TItem>;
/**
* An async function that recalls a parsed metadata value.
*
* This function is only provided if the metadata value can be recalled.
*
* @param value The value to recall.
* @param withToast Whether to show a toast on success or failure.
* @returns A promise that resolves when the recall operation is complete.
* @throws MetadataRecallError if the value cannot be recalled.
*/
recall?: (value: TValue, withToast?: boolean) => Promise<void>;
/**
* An async function that recalls a parsed metadata item value.
*
* This function is only provided if the metadata value is an array and the items can be recalled.
*
* @param item The item to recall. It should be an item from the array.
* @param withToast Whether to show a toast on success or failure.
* @returns A promise that resolves when the recall operation is complete.
* @throws MetadataRecallError if the value cannot be recalled.
*/
recallItem?: (item: TItem, withToast?: boolean) => Promise<void>;
/**
* Renders a parsed metadata value as a React node.
*
* @param value The value to render.
* @returns The rendered value.
*/
renderValue: MetadataRenderValueFunc<TValue>;
/**
* Renders a parsed metadata item value as a React node.
*
* @param item The item to render.
* @returns The rendered item.
*/
renderItemValue?: MetadataRenderValueFunc<TItem>;
};
// TODO(psyche): The types for item handlers should be able to be inferred from the type of the value:
// type MetadataHandlersInferItem<TValue> = TValue extends Array<infer TItem> ? MetadataParseFunc<TItem> : never
// While this works for the types as expected, I couldn't satisfy TS in the implementations of the handlers.
export type BuildMetadataHandlersArg<TValue, TItem> = {
parser: MetadataParseFunc<TValue>;
itemParser?: MetadataParseFunc<TItem>;
recaller?: MetadataRecallFunc<TValue>;
itemRecaller?: MetadataRecallFunc<TItem>;
validator?: MetadataValidateFunc<TValue>;
itemValidator?: MetadataValidateFunc<TItem>;
getLabel: MetadataGetLabelFunc;
renderValue?: MetadataRenderValueFunc<TValue>;
renderItemValue?: MetadataRenderValueFunc<TItem>;
};
export type BuildMetadataHandlers = <TValue, TItem>(
arg: BuildMetadataHandlersArg<TValue, TItem>
) => MetadataHandlers<TValue, TItem>;

View File

@ -0,0 +1,316 @@
import { Text } from '@invoke-ai/ui-library';
import { toast } from 'common/util/toast';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
BuildMetadataHandlers,
MetadataGetLabelFunc,
MetadataHandlers,
MetadataParseFunc,
MetadataRecallFunc,
MetadataRenderValueFunc,
MetadataValidateFunc,
} from 'features/metadata/types';
import { validators } from 'features/metadata/util/validators';
import { t } from 'i18next';
import type { AnyModelConfig } from 'services/api/types';
import { parsers } from './parsers';
import { recallers } from './recallers';
const renderModelConfigValue: MetadataRenderValueFunc<AnyModelConfig> = (value) =>
`${value.name} (${value.base.toUpperCase()}, ${value.key})`;
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = (value) => <Text>{`${value.model.key} (${value.weight})`}</Text>;
const renderControlAdapterValue: MetadataRenderValueFunc<ControlNetConfig | T2IAdapterConfig | IPAdapterConfig> = (
value
) => <Text>{`${value.model?.key} (${value.weight})`}</Text>;
const parameterSetToast = (parameter: string, description?: string) => {
toast({
title: t('toast.parameterSet', { parameter }),
description,
status: 'info',
duration: 2500,
isClosable: true,
});
};
const parameterNotSetToast = (parameter: string, description?: string) => {
toast({
title: t('toast.parameterNotSet', { parameter }),
description,
status: 'warning',
duration: 2500,
isClosable: true,
});
};
// const allParameterSetToast = (description?: string) => {
// toast({
// title: t('toast.parametersSet'),
// status: 'info',
// description,
// duration: 2500,
// isClosable: true,
// });
// };
// const allParameterNotSetToast = (description?: string) => {
// toast({
// title: t('toast.parametersNotSet'),
// status: 'warning',
// description,
// duration: 2500,
// isClosable: true,
// });
// };
const buildParse =
<TValue, TItem>(arg: {
parser: MetadataParseFunc<TValue>;
getLabel: MetadataGetLabelFunc;
}): MetadataHandlers<TValue, TItem>['parse'] =>
async (metadata, withToast = false) => {
try {
const parsed = await arg.parser(metadata);
withToast && parameterSetToast(arg.getLabel());
return parsed;
} catch (e) {
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
throw e;
}
};
const buildParseItem =
<TValue, TItem>(arg: {
itemParser: MetadataParseFunc<TItem>;
getLabel: MetadataGetLabelFunc;
}): MetadataHandlers<TValue, TItem>['parseItem'] =>
async (item, withToast = false) => {
try {
const parsed = await arg.itemParser(item);
withToast && parameterSetToast(arg.getLabel());
return parsed;
} catch (e) {
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
throw e;
}
};
const buildRecall =
<TValue, TItem>(arg: {
recaller: MetadataRecallFunc<TValue>;
validator?: MetadataValidateFunc<TValue>;
getLabel: MetadataGetLabelFunc;
}): NonNullable<MetadataHandlers<TValue, TItem>['recall']> =>
async (value, withToast = false) => {
try {
arg.validator && (await arg.validator(value));
await arg.recaller(value);
withToast && parameterSetToast(arg.getLabel());
} catch (e) {
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
throw e;
}
};
const buildRecallItem =
<TValue, TItem>(arg: {
itemRecaller: MetadataRecallFunc<TItem>;
itemValidator?: MetadataValidateFunc<TItem>;
getLabel: MetadataGetLabelFunc;
}): NonNullable<MetadataHandlers<TValue, TItem>['recallItem']> =>
async (item, withToast = false) => {
try {
arg.itemValidator && (await arg.itemValidator(item));
await arg.itemRecaller(item);
withToast && parameterSetToast(arg.getLabel());
} catch (e) {
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
throw e;
}
};
const buildHandlers: BuildMetadataHandlers = ({
getLabel,
parser,
itemParser,
recaller,
itemRecaller,
validator,
itemValidator,
renderValue,
renderItemValue,
}) => ({
parse: buildParse({ parser, getLabel }),
parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined,
recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
getLabel,
renderValue: renderValue ?? String,
renderItemValue: renderItemValue ?? String,
});
export const handlers = {
// Misc
createdBy: buildHandlers({ getLabel: () => t('metadata.createdBy'), parser: parsers.createdBy }),
generationMode: buildHandlers({ getLabel: () => t('metadata.generationMode'), parser: parsers.generationMode }),
// Core parameters
cfgRescaleMultiplier: buildHandlers({
getLabel: () => t('metadata.cfgRescaleMultiplier'),
parser: parsers.cfgRescaleMultiplier,
recaller: recallers.cfgRescaleMultiplier,
}),
cfgScale: buildHandlers({
getLabel: () => t('metadata.cfgScale'),
parser: parsers.cfgScale,
recaller: recallers.cfgScale,
}),
height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
negativePrompt: buildHandlers({
getLabel: () => t('metadata.negativePrompt'),
parser: parsers.negativePrompt,
recaller: recallers.negativePrompt,
}),
positivePrompt: buildHandlers({
getLabel: () => t('metadata.positivePrompt'),
parser: parsers.positivePrompt,
recaller: recallers.positivePrompt,
}),
scheduler: buildHandlers({
getLabel: () => t('metadata.scheduler'),
parser: parsers.scheduler,
recaller: recallers.scheduler,
}),
sdxlNegativeStylePrompt: buildHandlers({
getLabel: () => t('sdxl.negStylePrompt'),
parser: parsers.sdxlNegativeStylePrompt,
recaller: recallers.sdxlNegativeStylePrompt,
}),
sdxlPositiveStylePrompt: buildHandlers({
getLabel: () => t('sdxl.posStylePrompt'),
parser: parsers.sdxlPositiveStylePrompt,
recaller: recallers.sdxlPositiveStylePrompt,
}),
seed: buildHandlers({ getLabel: () => t('metadata.seed'), parser: parsers.seed, recaller: recallers.seed }),
steps: buildHandlers({ getLabel: () => t('metadata.steps'), parser: parsers.steps, recaller: recallers.steps }),
strength: buildHandlers({
getLabel: () => t('metadata.strength'),
parser: parsers.strength,
recaller: recallers.strength,
}),
width: buildHandlers({ getLabel: () => t('metadata.width'), parser: parsers.width, recaller: recallers.width }),
// HRF
hrfEnabled: buildHandlers({
getLabel: () => t('hrf.metadata.enabled'),
parser: parsers.hrfEnabled,
recaller: recallers.hrfEnabled,
}),
hrfMethod: buildHandlers({
getLabel: () => t('hrf.metadata.method'),
parser: parsers.hrfMethod,
recaller: recallers.hrfMethod,
}),
hrfStrength: buildHandlers({
getLabel: () => t('hrf.metadata.strength'),
parser: parsers.hrfStrength,
recaller: recallers.hrfStrength,
}),
// Refiner
refinerCFGScale: buildHandlers({
getLabel: () => t('sdxl.cfgScale'),
parser: parsers.refinerCFGScale,
recaller: recallers.refinerCFGScale,
}),
refinerModel: buildHandlers({
getLabel: () => t('sdxl.refinerModel'),
parser: parsers.refinerModel,
recaller: recallers.refinerModel,
validator: validators.refinerModel,
}),
refinerNegativeAestheticScore: buildHandlers({
getLabel: () => t('sdxl.posAestheticScore'),
parser: parsers.refinerNegativeAestheticScore,
recaller: recallers.refinerNegativeAestheticScore,
}),
refinerPositiveAestheticScore: buildHandlers({
getLabel: () => t('sdxl.negAestheticScore'),
parser: parsers.refinerPositiveAestheticScore,
recaller: recallers.refinerPositiveAestheticScore,
}),
refinerScheduler: buildHandlers({
getLabel: () => t('sdxl.scheduler'),
parser: parsers.refinerScheduler,
recaller: recallers.refinerScheduler,
}),
refinerStart: buildHandlers({
getLabel: () => t('sdxl.refiner_start'),
parser: parsers.refinerStart,
recaller: recallers.refinerStart,
}),
refinerSteps: buildHandlers({
getLabel: () => t('sdxl.refiner_steps'),
parser: parsers.refinerSteps,
recaller: recallers.refinerSteps,
}),
// Models
model: buildHandlers({
getLabel: () => t('metadata.model'),
parser: parsers.mainModel,
recaller: recallers.model,
renderValue: renderModelConfigValue,
}),
vae: buildHandlers({
getLabel: () => t('metadata.vae'),
parser: parsers.vaeModel,
recaller: recallers.vae,
renderValue: renderModelConfigValue,
validator: validators.vaeModel,
}),
// Arrays of models
controlNets: buildHandlers({
getLabel: () => t('common.controlNet'),
parser: parsers.controlNets,
itemParser: parsers.controlNet,
recaller: recallers.controlNets,
itemRecaller: recallers.controlNet,
validator: validators.controlNets,
itemValidator: validators.controlNet,
renderItemValue: renderControlAdapterValue,
}),
ipAdapters: buildHandlers({
getLabel: () => t('common.ipAdapter'),
parser: parsers.ipAdapters,
itemParser: parsers.ipAdapter,
recaller: recallers.ipAdapters,
itemRecaller: recallers.ipAdapter,
validator: validators.ipAdapters,
itemValidator: validators.ipAdapter,
renderItemValue: renderControlAdapterValue,
}),
loras: buildHandlers({
getLabel: () => t('models.lora'),
parser: parsers.loras,
itemParser: parsers.lora,
recaller: recallers.loras,
itemRecaller: recallers.lora,
validator: validators.loras,
itemValidator: validators.lora,
renderItemValue: renderLoRAValue,
}),
t2iAdapters: buildHandlers({
getLabel: () => t('common.t2iAdapter'),
parser: parsers.t2iAdapters,
itemParser: parsers.t2iAdapter,
recaller: recallers.t2iAdapters,
itemRecaller: recallers.t2iAdapter,
validator: validators.t2iAdapters,
itemValidator: validators.t2iAdapter,
renderItemValue: renderControlAdapterValue,
}),
} as const;

View File

@ -0,0 +1,396 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import type { LoRA } from 'features/lora/store/loraSlice';
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
import { MetadataParseError } from 'features/metadata/exceptions';
import type { MetadataParseFunc } from 'features/metadata/types';
import {
zControlField,
zIPAdapterField,
zModelIdentifierWithBase,
zT2IAdapterField,
} from 'features/nodes/types/common';
import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
ParameterHeight,
ParameterHRFEnabled,
ParameterHRFMethod,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterScheduler,
ParameterSDXLRefinerNegativeAestheticScore,
ParameterSDXLRefinerPositiveAestheticScore,
ParameterSDXLRefinerStart,
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterHeight,
isParameterHRFEnabled,
isParameterHRFMethod,
isParameterLoRAWeight,
isParameterNegativePrompt,
isParameterNegativeStylePromptSDXL,
isParameterPositivePrompt,
isParameterPositiveStylePromptSDXL,
isParameterScheduler,
isParameterSDXLRefinerNegativeAestheticScore,
isParameterSDXLRefinerPositiveAestheticScore,
isParameterSDXLRefinerStart,
isParameterSeed,
isParameterSteps,
isParameterStrength,
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
getModelKeyAndBase,
} from 'features/parameters/util/modelFetchingHelpers';
import { get, isArray, isString } from 'lodash-es';
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isT2IAdapterModelConfig,
isVAEModelConfig,
} from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
export const MetadataParsePendingToken = Symbol('pending');
export const MetadataParseFailedToken = Symbol('failed');
/**
* An async function that a property from an object and validates its type using a type guard. If the property is missing
* or invalid, the function should throw a MetadataParseError.
* @param obj The object to get the property from.
* @param property The property to get.
* @param typeGuard A type guard function to check the type of the property. Provide `undefined` to opt out of type
* validation and always return the property value.
* @returns A promise that resolves to the property value if it exists and is of the expected type.
* @throws MetadataParseError if a type guard is provided and the property is not of the expected type.
*/
const getProperty = <T = unknown>(
obj: unknown,
property: string,
typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true
): Promise<T> => {
return new Promise<T>((resolve, reject) => {
const val = get(obj, property) as unknown;
if (typeGuard(val)) {
resolve(val);
}
reject(new MetadataParseError(`Property ${property} is not of expected type`));
});
};
const parseCreatedBy: MetadataParseFunc<string> = (metadata) => getProperty(metadata, 'created_by', isString);
const parseGenerationMode: MetadataParseFunc<string> = (metadata) => getProperty(metadata, 'generation_mode', isString);
const parsePositivePrompt: MetadataParseFunc<ParameterPositivePrompt> = (metadata) =>
getProperty(metadata, 'positive_prompt', isParameterPositivePrompt);
const parseNegativePrompt: MetadataParseFunc<ParameterNegativePrompt> = (metadata) =>
getProperty(metadata, 'negative_prompt', isParameterNegativePrompt);
const parseSDXLPositiveStylePrompt: MetadataParseFunc<ParameterPositiveStylePromptSDXL> = (metadata) =>
getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL);
const parseSDXLNegativeStylePrompt: MetadataParseFunc<ParameterNegativeStylePromptSDXL> = (metadata) =>
getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL);
const parseSeed: MetadataParseFunc<ParameterSeed> = (metadata) => getProperty(metadata, 'seed', isParameterSeed);
const parseCFGScale: MetadataParseFunc<ParameterCFGScale> = (metadata) =>
getProperty(metadata, 'cfg_scale', isParameterCFGScale);
const parseCFGRescaleMultiplier: MetadataParseFunc<ParameterCFGRescaleMultiplier> = (metadata) =>
getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier);
const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
getProperty(metadata, 'scheduler', isParameterScheduler);
const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);
const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>
getProperty(metadata, 'height', isParameterHeight);
const parseSteps: MetadataParseFunc<ParameterSteps> = (metadata) => getProperty(metadata, 'steps', isParameterSteps);
const parseStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
getProperty(metadata, 'strength', isParameterStrength);
const parseHRFEnabled: MetadataParseFunc<ParameterHRFEnabled> = (metadata) =>
getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
const parseHRFStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
getProperty(metadata, 'hrf_strength', isParameterStrength);
const parseHRFMethod: MetadataParseFunc<ParameterHRFMethod> = (metadata) =>
getProperty(metadata, 'hrf_method', isParameterHRFMethod);
const parseRefinerSteps: MetadataParseFunc<ParameterSteps> = (metadata) =>
getProperty(metadata, 'refiner_steps', isParameterSteps);
const parseRefinerCFGScale: MetadataParseFunc<ParameterCFGScale> = (metadata) =>
getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale);
const parseRefinerScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
getProperty(metadata, 'refiner_scheduler', isParameterScheduler);
const parseRefinerPositiveAestheticScore: MetadataParseFunc<ParameterSDXLRefinerPositiveAestheticScore> = (metadata) =>
getProperty(metadata, 'refiner_positive_aesthetic_score', isParameterSDXLRefinerPositiveAestheticScore);
const parseRefinerNegativeAestheticScore: MetadataParseFunc<ParameterSDXLRefinerNegativeAestheticScore> = (metadata) =>
getProperty(metadata, 'refiner_negative_aesthetic_score', isParameterSDXLRefinerNegativeAestheticScore);
const parseRefinerStart: MetadataParseFunc<ParameterSDXLRefinerStart> = (metadata) =>
getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
const parseMainModel: MetadataParseFunc<NonRefinerMainModelConfig> = async (metadata) => {
const model = await getProperty(metadata, 'model', undefined);
const key = await getModelKey(model, 'main');
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
return mainModelConfig;
};
const parseRefinerModel: MetadataParseFunc<RefinerMainModelConfig> = async (metadata) => {
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
const key = await getModelKey(refiner_model, 'main');
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
return refinerModelConfig;
};
const parseVAEModel: MetadataParseFunc<VAEModelConfig> = async (metadata) => {
const vae = await getProperty(metadata, 'vae', undefined);
const key = await getModelKey(vae, 'vae');
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
return vaeModelConfig;
};
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
// Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
const modelV1 = await getProperty(metadataItem, 'lora', undefined);
// Now, the LoRA model is stored in a `model` property of the LoRA metadata: `{model: {key: ...}, weight: 0.75}`
const modelV2 = await getProperty(metadataItem, 'model', undefined);
const weight = await getProperty(metadataItem, 'weight', undefined);
const key = await getModelKey(modelV2 ?? modelV1, 'lora');
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
return {
model: getModelKeyAndBase(loraModelConfig),
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
isEnabled: true,
};
};
const parseAllLoRAs: MetadataParseFunc<LoRA[]> = async (metadata) => {
const lorasRaw = await getProperty(metadata, 'loras', isArray);
const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
const loras = parseResults
.filter((result): result is PromiseFulfilledResult<LoRA> => result.status === 'fulfilled')
.map((result) => result.value);
return loras;
};
const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem) => {
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const control_weight = zControlField.shape.control_weight
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'control_weight'));
const begin_step_percent = zControlField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zControlField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
const control_mode = zControlField.shape.control_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'control_mode'));
const resize_mode = zControlField.shape.resize_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlNet: ControlNetConfig = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return controlNet;
};
const parseAllControlNets: MetadataParseFunc<ControlNetConfig[]> = async (metadata) => {
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
const controlNets = parseResults
.filter((result): result is PromiseFulfilledResult<ControlNetConfig> => result.status === 'fulfilled')
.map((result) => result.value);
return controlNets;
};
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem) => {
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zT2IAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
const resize_mode = zT2IAdapterField.shape.resize_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return t2iAdapter;
};
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfig[]> = async (metadata) => {
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
const t2iAdapters = parseResults
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfig> => result.status === 'fulfilled')
.map((result) => result.value);
return t2iAdapters;
};
const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem) => {
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zIPAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
const ipAdapter: IPAdapterConfig = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};
return ipAdapter;
};
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfig[]> = async (metadata) => {
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
const ipAdapters = parseResults
.filter((result): result is PromiseFulfilledResult<IPAdapterConfig> => result.status === 'fulfilled')
.map((result) => result.value);
return ipAdapters;
};
export const parsers = {
createdBy: parseCreatedBy,
generationMode: parseGenerationMode,
positivePrompt: parsePositivePrompt,
negativePrompt: parseNegativePrompt,
sdxlPositiveStylePrompt: parseSDXLPositiveStylePrompt,
sdxlNegativeStylePrompt: parseSDXLNegativeStylePrompt,
seed: parseSeed,
cfgScale: parseCFGScale,
cfgRescaleMultiplier: parseCFGRescaleMultiplier,
scheduler: parseScheduler,
width: parseWidth,
height: parseHeight,
steps: parseSteps,
strength: parseStrength,
hrfEnabled: parseHRFEnabled,
hrfStrength: parseHRFStrength,
hrfMethod: parseHRFMethod,
refinerSteps: parseRefinerSteps,
refinerCFGScale: parseRefinerCFGScale,
refinerScheduler: parseRefinerScheduler,
refinerPositiveAestheticScore: parseRefinerPositiveAestheticScore,
refinerNegativeAestheticScore: parseRefinerNegativeAestheticScore,
refinerStart: parseRefinerStart,
mainModel: parseMainModel,
refinerModel: parseRefinerModel,
vaeModel: parseVAEModel,
lora: parseLoRA,
loras: parseAllLoRAs,
controlNet: parseControlNet,
controlNets: parseAllControlNets,
t2iAdapter: parseT2IAdapter,
t2iAdapters: parseAllT2IAdapters,
ipAdapter: parseIPAdapter,
ipAdapters: parseAllIPAdapters,
} as const;

View File

@ -0,0 +1,295 @@
import { getStore } from 'app/store/nanostores/store';
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice';
import type { MetadataRecallFunc } from 'features/metadata/types';
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import {
heightRecalled,
setCfgRescaleMultiplier,
setCfgScale,
setImg2imgStrength,
setNegativePrompt,
setPositivePrompt,
setScheduler,
setSeed,
setSteps,
vaeSelected,
widthRecalled,
} from 'features/parameters/store/generationSlice';
import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
ParameterHeight,
ParameterHRFEnabled,
ParameterHRFMethod,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterScheduler,
ParameterSDXLRefinerNegativeAestheticScore,
ParameterSDXLRefinerPositiveAestheticScore,
ParameterSDXLRefinerStart,
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
setPositiveStylePromptSDXL,
setRefinerCFGScale,
setRefinerNegativeAestheticScore,
setRefinerPositiveAestheticScore,
setRefinerScheduler,
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
getStore().dispatch(setPositivePrompt(positivePrompt));
};
const recallNegativePrompt: MetadataRecallFunc<ParameterNegativePrompt> = (negativePrompt) => {
getStore().dispatch(setNegativePrompt(negativePrompt));
};
const recallSDXLPositiveStylePrompt: MetadataRecallFunc<ParameterPositiveStylePromptSDXL> = (positiveStylePrompt) => {
getStore().dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
};
const recallSDXLNegativeStylePrompt: MetadataRecallFunc<ParameterNegativeStylePromptSDXL> = (negativeStylePrompt) => {
getStore().dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
};
const recallSeed: MetadataRecallFunc<ParameterSeed> = (seed) => {
getStore().dispatch(setSeed(seed));
};
const recallCFGScale: MetadataRecallFunc<ParameterCFGScale> = (cfgScale) => {
getStore().dispatch(setCfgScale(cfgScale));
};
const recallCFGRescaleMultiplier: MetadataRecallFunc<ParameterCFGRescaleMultiplier> = (cfgRescaleMultiplier) => {
getStore().dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
};
const recallScheduler: MetadataRecallFunc<ParameterScheduler> = (scheduler) => {
getStore().dispatch(setScheduler(scheduler));
};
const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
getStore().dispatch(widthRecalled(width));
};
const recallHeight: MetadataRecallFunc<ParameterHeight> = (height) => {
getStore().dispatch(heightRecalled(height));
};
const recallSteps: MetadataRecallFunc<ParameterSteps> = (steps) => {
getStore().dispatch(setSteps(steps));
};
const recallStrength: MetadataRecallFunc<ParameterStrength> = (strength) => {
getStore().dispatch(setImg2imgStrength(strength));
};
const recallHRFEnabled: MetadataRecallFunc<ParameterHRFEnabled> = (hrfEnabled) => {
getStore().dispatch(setHrfEnabled(hrfEnabled));
};
const recallHRFStrength: MetadataRecallFunc<ParameterStrength> = (hrfStrength) => {
getStore().dispatch(setHrfStrength(hrfStrength));
};
const recallHRFMethod: MetadataRecallFunc<ParameterHRFMethod> = (hrfMethod) => {
getStore().dispatch(setHrfMethod(hrfMethod));
};
const recallRefinerSteps: MetadataRecallFunc<ParameterSteps> = (refinerSteps) => {
getStore().dispatch(setRefinerSteps(refinerSteps));
};
const recallRefinerCFGScale: MetadataRecallFunc<ParameterCFGScale> = (refinerCFGScale) => {
getStore().dispatch(setRefinerCFGScale(refinerCFGScale));
};
const recallRefinerScheduler: MetadataRecallFunc<ParameterScheduler> = (refinerScheduler) => {
getStore().dispatch(setRefinerScheduler(refinerScheduler));
};
const recallRefinerPositiveAestheticScore: MetadataRecallFunc<ParameterSDXLRefinerPositiveAestheticScore> = (
refinerPositiveAestheticScore
) => {
getStore().dispatch(setRefinerPositiveAestheticScore(refinerPositiveAestheticScore));
};
const recallRefinerNegativeAestheticScore: MetadataRecallFunc<ParameterSDXLRefinerNegativeAestheticScore> = (
refinerNegativeAestheticScore
) => {
getStore().dispatch(setRefinerNegativeAestheticScore(refinerNegativeAestheticScore));
};
const recallRefinerStart: MetadataRecallFunc<ParameterSDXLRefinerStart> = (refinerStart) => {
getStore().dispatch(setRefinerStart(refinerStart));
};
const recallModel: MetadataRecallFunc<NonRefinerMainModelConfig> = (model) => {
const modelIdentifier = zModelIdentifierWithBase.parse(model);
getStore().dispatch(modelSelected(modelIdentifier));
};
const recallRefinerModel: MetadataRecallFunc<RefinerMainModelConfig> = (refinerModel) => {
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel);
getStore().dispatch(refinerModelChanged(modelIdentifier));
};
const recallVAE: MetadataRecallFunc<VAEModelConfig | null | undefined> = (vaeModel) => {
if (!vaeModel) {
getStore().dispatch(vaeSelected(null));
return;
}
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel);
getStore().dispatch(vaeSelected(modelIdentifier));
};
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
getStore().dispatch(loraRecalled(lora));
};
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
const { dispatch } = getStore();
loras.forEach((lora) => {
dispatch(loraRecalled(lora));
});
};
const recallControlNet: MetadataRecallFunc<ControlNetConfig> = (controlNet) => {
getStore().dispatch(controlAdapterRecalled(controlNet));
};
const recallControlNets: MetadataRecallFunc<ControlNetConfig[]> = (controlNets) => {
const { dispatch } = getStore();
controlNets.forEach((controlNet) => {
dispatch(controlAdapterRecalled(controlNet));
});
};
const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfig> = (t2iAdapter) => {
getStore().dispatch(controlAdapterRecalled(t2iAdapter));
};
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
const { dispatch } = getStore();
t2iAdapters.forEach((t2iAdapter) => {
dispatch(controlAdapterRecalled(t2iAdapter));
});
};
const recallIPAdapter: MetadataRecallFunc<IPAdapterConfig> = (ipAdapter) => {
getStore().dispatch(controlAdapterRecalled(ipAdapter));
};
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfig[]> = (ipAdapters) => {
const { dispatch } = getStore();
ipAdapters.forEach((ipAdapter) => {
dispatch(controlAdapterRecalled(ipAdapter));
});
};
export const recallPrompts = (_metadata: unknown) => {
// recallPositivePrompt(metadata);
// recallNegativePrompt(metadata);
// recallSDXLPositiveStylePrompt(metadata);
// recallSDXLNegativeStylePrompt(metadata);
// parameterNotSetToast(t('metadata.allPrompts'));
// parameterSetToast(t('metadata.allPrompts'));
};
export const recallWidthAndHeight = (_metadata: unknown) => {
// recallWidth(metadata);
// recallHeight(metadata);
};
export const recallAll = async (_metadata: unknown) => {
// if (!metadata) {
// allParameterNotSetToast();
// return;
// }
// // Update the main model first, as other parameters may depend on it.
// await recallModel(metadata);
// await Promise.allSettled([
// // Core parameters
// recallCFGScale(metadata),
// recallCFGRescaleMultiplier(metadata),
// recallPositivePrompt(metadata),
// recallNegativePrompt(metadata),
// recallScheduler(metadata),
// recallSeed(metadata),
// recallSteps(metadata),
// recallWidth(metadata),
// recallHeight(metadata),
// recallStrength(metadata),
// recallHRFEnabled(metadata),
// recallHRFMethod(metadata),
// recallHRFStrength(metadata),
// // SDXL parameters
// recallSDXLPositiveStylePrompt(metadata),
// recallSDXLNegativeStylePrompt(metadata),
// recallRefinerSteps(metadata),
// recallRefinerCFGScale(metadata),
// recallRefinerScheduler(metadata),
// recallRefinerPositiveAestheticScore(metadata),
// recallRefinerNegativeAestheticScore(metadata),
// recallRefinerStart(metadata),
// // Models
// recallVAE(metadata),
// recallRefinerModel(metadata),
// recallAllLoRAs(metadata),
// recallControlNets(metadata),
// recallT2IAdapters(metadata),
// recallIPAdapters(metadata),
// ]);
// allParameterSetToast();
};
export const recallers = {
positivePrompt: recallPositivePrompt,
negativePrompt: recallNegativePrompt,
sdxlPositiveStylePrompt: recallSDXLPositiveStylePrompt,
sdxlNegativeStylePrompt: recallSDXLNegativeStylePrompt,
seed: recallSeed,
cfgScale: recallCFGScale,
cfgRescaleMultiplier: recallCFGRescaleMultiplier,
scheduler: recallScheduler,
width: recallWidth,
height: recallHeight,
steps: recallSteps,
strength: recallStrength,
hrfEnabled: recallHRFEnabled,
hrfStrength: recallHRFStrength,
hrfMethod: recallHRFMethod,
refinerSteps: recallRefinerSteps,
refinerCFGScale: recallRefinerCFGScale,
refinerScheduler: recallRefinerScheduler,
refinerPositiveAestheticScore: recallRefinerPositiveAestheticScore,
refinerNegativeAestheticScore: recallRefinerNegativeAestheticScore,
refinerStart: recallRefinerStart,
model: recallModel,
refinerModel: recallRefinerModel,
vae: recallVAE,
lora: recallLoRA,
loras: recallAllLoRAs,
controlNets: recallControlNets,
controlNet: recallControlNet,
t2iAdapters: recallT2IAdapters,
t2iAdapter: recallT2IAdapter,
ipAdapters: recallIPAdapters,
ipAdapter: recallIPAdapter,
} as const;

View File

@ -0,0 +1,117 @@
import { getStore } from 'app/store/nanostores/store';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { MetadataValidateFunc } from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
/**
* Checks the given base model type against the currently-selected model's base type and throws an error if they are
* incompatible.
* @param base The base model type to validate.
* @param message An optional message to use in the error if the base model is incompatible.
*/
const validateBaseCompatibility = (base?: BaseModelType, message?: string) => {
if (!base) {
throw new InvalidModelConfigError(message || 'Missing base');
}
const currentBase = getStore().getState().generation.model?.base;
if (currentBase && base !== currentBase) {
throw new InvalidModelConfigError(message || `Incompatible base models: ${base} and ${currentBase}`);
}
};
const validateRefinerModel: MetadataValidateFunc<RefinerMainModelConfig> = (refinerModel) => {
validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
return new Promise((resolve) => resolve(refinerModel));
};
const validateVAEModel: MetadataValidateFunc<VAEModelConfig> = (vaeModel) => {
validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
return new Promise((resolve) => resolve(vaeModel));
};
const validateLoRA: MetadataValidateFunc<LoRA> = (lora) => {
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
return new Promise((resolve) => resolve(lora));
};
const validateLoRAs: MetadataValidateFunc<LoRA[]> = (loras) => {
const validatedLoRAs: LoRA[] = [];
loras.forEach((lora) => {
try {
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
validatedLoRAs.push(lora);
} catch {
// This is a no-op - we want to continue validating the rest of the LoRAs, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedLoRAs));
};
const validateControlNet: MetadataValidateFunc<ControlNetConfig> = (controlNet) => {
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
return new Promise((resolve) => resolve(controlNet));
};
const validateControlNets: MetadataValidateFunc<ControlNetConfig[]> = (controlNets) => {
const validatedControlNets: ControlNetConfig[] = [];
controlNets.forEach((controlNet) => {
try {
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
validatedControlNets.push(controlNet);
} catch {
// This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedControlNets));
};
const validateT2IAdapter: MetadataValidateFunc<T2IAdapterConfig> = (t2iAdapter) => {
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
return new Promise((resolve) => resolve(t2iAdapter));
};
const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
const validatedT2IAdapters: T2IAdapterConfig[] = [];
t2iAdapters.forEach((t2iAdapter) => {
try {
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
validatedT2IAdapters.push(t2iAdapter);
} catch {
// This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedT2IAdapters));
};
const validateIPAdapter: MetadataValidateFunc<IPAdapterConfig> = (ipAdapter) => {
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
return new Promise((resolve) => resolve(ipAdapter));
};
const validateIPAdapters: MetadataValidateFunc<IPAdapterConfig[]> = (ipAdapters) => {
const validatedIPAdapters: IPAdapterConfig[] = [];
ipAdapters.forEach((ipAdapter) => {
try {
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
validatedIPAdapters.push(ipAdapter);
} catch {
// This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedIPAdapters));
};
export const validators = {
refinerModel: validateRefinerModel,
vaeModel: validateVAEModel,
lora: validateLoRA,
loras: validateLoRAs,
controlNet: validateControlNet,
controlNets: validateControlNets,
t2iAdapter: validateT2IAdapter,
t2iAdapters: validateT2IAdapters,
ipAdapter: validateIPAdapter,
ipAdapters: validateIPAdapters,
} as const;

View File

@ -1,5 +1,8 @@
import { z } from 'zod'; import { z } from 'zod';
import type { ModelIdentifier as ModelIdentifierV2 } from './v2/common';
import { zModelIdentifier as zModelIdentifierV2 } from './v2/common';
// #region Field data schemas // #region Field data schemas
export const zImageField = z.object({ export const zImageField = z.object({
image_name: z.string().trim().min(1), image_name: z.string().trim().min(1),
@ -69,6 +72,8 @@ export const zModelIdentifier = z.object({
}); });
export const isModelIdentifier = (field: unknown): field is ModelIdentifier => export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success; zModelIdentifier.safeParse(field).success;
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
zModelIdentifierV2.safeParse(field).success;
export const zModelFieldBase = zModelIdentifier; export const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>; export type BaseModel = z.infer<typeof zBaseModel>;

View File

@ -1,29 +1,22 @@
import { z } from 'zod'; import { z } from 'zod';
import { import { zControlField, zIPAdapterField, zT2IAdapterField } from './common';
zControlField,
zIPAdapterField,
zMainModelField,
zModelFieldBase,
zSDXLRefinerModelField,
zT2IAdapterField,
zVAEModelField,
} from './common';
export const zLoRAWeight = z.number().nullish();
// #region Metadata-optimized versions of schemas // #region Metadata-optimized versions of schemas
// TODO: It's possible that `deepPartial` will be deprecated: // TODO: It's possible that `deepPartial` will be deprecated:
// - https://github.com/colinhacks/zod/issues/2106 // - https://github.com/colinhacks/zod/issues/2106
// - https://github.com/colinhacks/zod/issues/2854 // - https://github.com/colinhacks/zod/issues/2854
export const zLoRAMetadataItem = z.object({ export const zLoRAMetadataItem = z.object({
lora: zModelFieldBase.deepPartial(), lora: z.unknown(),
weight: z.number(), weight: zLoRAWeight,
}); });
const zControlNetMetadataItem = zControlField.deepPartial(); const zControlNetMetadataItem = zControlField.merge(z.object({ control_model: z.unknown() })).deepPartial();
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); const zIPAdapterMetadataItem = zIPAdapterField.merge(z.object({ ip_adapter_model: z.unknown() })).deepPartial();
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); const zT2IAdapterMetadataItem = zT2IAdapterField.merge(z.object({ t2i_adapter_model: z.unknown() })).deepPartial();
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); const zSDXLRefinerModelMetadataItem = z.unknown();
const zModelMetadataItem = zMainModelField.deepPartial(); const zModelMetadataItem = z.unknown();
const zVAEModelMetadataItem = zVAEModelField.deepPartial(); const zVAEModelMetadataItem = z.unknown();
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>; export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>; export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>; export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;

View File

@ -1,14 +1,7 @@
import { NUMPY_RAND_MAX } from 'app/constants'; import { NUMPY_RAND_MAX } from 'app/constants';
import { import {
zBaseModel,
zControlNetModelField,
zIPAdapterModelField,
zLoRAModelField,
zModelIdentifierWithBase, zModelIdentifierWithBase,
zSchedulerField, zSchedulerField,
zSDXLRefinerModelField,
zT2IAdapterModelField,
zVAEModelField,
} from 'features/nodes/types/common'; } from 'features/nodes/types/common';
import { z } from 'zod'; import { z } from 'zod';
@ -111,42 +104,42 @@ export const isParameterModel = (val: unknown): val is ParameterModel => zParame
// #endregion // #endregion
// #region SDXL Refiner Model // #region SDXL Refiner Model
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel }); export const zParameterSDXLRefinerModel = zModelIdentifierWithBase;
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>; export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel => export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel =>
zParameterSDXLRefinerModel.safeParse(val).success; zParameterSDXLRefinerModel.safeParse(val).success;
// #endregion // #endregion
// #region VAE Model // #region VAE Model
export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel }); export const zParameterVAEModel = zModelIdentifierWithBase;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>; export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel =>
zParameterVAEModel.safeParse(val).success; zParameterVAEModel.safeParse(val).success;
// #endregion // #endregion
// #region LoRA Model // #region LoRA Model
export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel }); export const zParameterLoRAModel = zModelIdentifierWithBase;
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>; export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel =>
zParameterLoRAModel.safeParse(val).success; zParameterLoRAModel.safeParse(val).success;
// #endregion // #endregion
// #region ControlNet Model // #region ControlNet Model
export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel }); export const zParameterControlNetModel = zModelIdentifierWithBase;
export type ParameterControlNetModel = z.infer<typeof zParameterLoRAModel>; export type ParameterControlNetModel = z.infer<typeof zParameterLoRAModel>;
export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel => export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel =>
zParameterControlNetModel.safeParse(val).success; zParameterControlNetModel.safeParse(val).success;
// #endregion // #endregion
// #region IP Adapter Model // #region IP Adapter Model
export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel }); export const zParameterIPAdapterModel = zModelIdentifierWithBase;
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>; export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel => export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel =>
zParameterIPAdapterModel.safeParse(val).success; zParameterIPAdapterModel.safeParse(val).success;
// #endregion // #endregion
// #region T2I Adapter Model // #region T2I Adapter Model
export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel }); export const zParameterT2IAdapterModel = zModelIdentifierWithBase;
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>; export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel => export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel =>
zParameterT2IAdapterModel.safeParse(val).success; zParameterT2IAdapterModel.safeParse(val).success;
@ -218,3 +211,9 @@ export type ParameterCanvasCoherenceMode = z.infer<typeof zParameterCanvasCohere
export const isParameterCanvasCoherenceMode = (val: unknown): val is ParameterCanvasCoherenceMode => export const isParameterCanvasCoherenceMode = (val: unknown): val is ParameterCanvasCoherenceMode =>
zParameterCanvasCoherenceMode.safeParse(val).success; zParameterCanvasCoherenceMode.safeParse(val).success;
// #endregion // #endregion
// #region LoRA weight
export const zLoRAWeight = z.number();
export type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
// #endregion

View File

@ -1,7 +1,8 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import { isModelIdentifier } from 'features/nodes/types/common'; import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType } from 'services/api/types'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
import { import {
isControlNetModelConfig, isControlNetModelConfig,
isIPAdapterModelConfig, isIPAdapterModelConfig,
@ -41,6 +42,12 @@ export class InvalidModelConfigError extends Error {
} }
} }
/**
* Fetches the model config for a given model key.
* @param key The model key.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => { export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
const { dispatch } = getStore(); const { dispatch } = getStore();
try { try {
@ -52,6 +59,37 @@ export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> =>
} }
}; };
/**
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
* for MM1 model identifiers.
* @param name The model name.
* @param base The base model.
* @param type The model type.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfigByAttrs = async (
name: string,
base: BaseModelType,
type: ModelType
): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
req.unsubscribe();
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
}
};
/**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
* @param key The model key.
* @param typeGuard A type guard function that checks if the model config is of the expected type.
* @returns A promise that resolves to the model config. The model config is guaranteed to be of the expected type.
* @throws {InvalidModelConfigError} If the model config is unable to be fetched or is of an unexpected type.
*/
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>( export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
key: string, key: string,
typeGuard: (config: AnyModelConfig) => config is T typeGuard: (config: AnyModelConfig) => config is T
@ -63,15 +101,17 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
return modelConfig; return modelConfig;
}; };
export const fetchMainModel = async (key: string) => { // TODO(psyche): Remove these helpers once `useRecallParameters` is removed
export const fetchMainModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
}; };
export const fetchRefinerModel = async (key: string) => { export const fetchRefinerModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
}; };
export const fetchVAEModel = async (key: string) => { export const fetchVAEModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig); return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
}; };
@ -95,19 +135,39 @@ export const fetchTextualInversionModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig); return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
}; };
export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => { /**
return sourceBase === targetBase; * Raises an error if the source base model is incompatible with the target base model.
}; * @param sourceBase The source base model.
* @param targetBase The target base model.
* @param message An optional custom message to include in the error.
* @throws {InvalidModelConfigError} If the source base model is incompatible with the target base model.
*/
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => { export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
if (targetBase && !isBaseCompatible(sourceBase, targetBase)) { if (targetBase && sourceBase !== targetBase) {
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`); throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
} }
}; };
export const getModelKey = (modelIdentifier: unknown, message?: string): string => { /**
if (!isModelIdentifier(modelIdentifier)) { * Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); * @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
* @param message An optional custom message to include in the error if the model identifier is invalid.
* @returns A promise that resolves to the model key.
* @throws {InvalidModelConfigError} If the model identifier is invalid.
*/
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
if (isModelIdentifier(modelIdentifier)) {
return modelIdentifier.key;
} }
return modelIdentifier.key; if (isModelIdentifierV2(modelIdentifier)) {
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
}
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
}; };
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
key: modelConfig.key,
base: modelConfig.base,
});

View File

@ -1,150 +0,0 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
import type {
ControlNetMetadataItem,
IPAdapterMetadataItem,
LoRAMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/metadata';
import {
fetchControlNetModel,
fetchIPAdapterModel,
fetchLoRAModel,
fetchMainModel,
fetchRefinerModel,
fetchT2IAdapterModel,
fetchVAEModel,
getModelKey,
raiseIfBaseIncompatible,
} from 'features/parameters/util/modelFetchingHelpers';
import type { BaseModelType } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
export const prepareMainModelMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(model);
const mainModel = await fetchMainModel(key);
return zModelIdentifierWithBase.parse(mainModel);
};
export const prepareRefinerMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(model);
const refinerModel = await fetchRefinerModel(key);
return zModelIdentifierWithBase.parse(refinerModel);
};
export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise<ModelIdentifierWithBase> => {
const key = getModelKey(vae);
const vaeModel = await fetchVAEModel(key);
raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
return zModelIdentifierWithBase.parse(vaeModel);
};
export const prepareLoRAMetadataItem = async (
loraMetadataItem: LoRAMetadataItem,
base?: BaseModelType
): Promise<LoRA> => {
const key = getModelKey(loraMetadataItem.lora);
const loraModel = await fetchLoRAModel(key);
raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
};
export const prepareControlNetMetadataItem = async (
controlnetMetadataItem: ControlNetMetadataItem,
base?: BaseModelType
): Promise<ControlNetConfig> => {
const key = getModelKey(controlnetMetadataItem.control_model);
const controlNetModel = await fetchControlNetModel(key);
raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
controlnetMetadataItem;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlnet: ControlNetConfig = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return controlnet;
};
export const prepareT2IAdapterMetadataItem = async (
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
base?: BaseModelType
): Promise<T2IAdapterConfig> => {
const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
const t2iAdapterModel = await fetchT2IAdapterModel(key);
raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return t2iAdapter;
};
export const prepareIPAdapterMetadataItem = async (
ipAdapterMetadataItem: IPAdapterMetadataItem,
base?: BaseModelType
): Promise<IPAdapterConfig> => {
const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
const ipAdapterModel = await fetchIPAdapterModel(key);
raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const ipAdapter: IPAdapterConfig = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
controlImage: image?.image_name ?? null,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};
return ipAdapter;
};

View File

@ -25,7 +25,7 @@ const formLabelProps2: FormLabelProps = {
export const AdvancedSettingsAccordion = memo(() => { export const AdvancedSettingsAccordion = memo(() => {
const vaeKey = useAppSelector((state) => state.generation.vae?.key); const vaeKey = useAppSelector((state) => state.generation.vae?.key);
const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken); const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
const selectBadges = useMemo( const selectBadges = useMemo(
() => () =>
createMemoizedSelector(selectGenerationSlice, (generation) => { createMemoizedSelector(selectGenerationSlice, (generation) => {

View File

@ -1,10 +1,8 @@
import type { EntityState, Update } from '@reduxjs/toolkit'; import type { EntityState, Update } from '@reduxjs/toolkit';
import type { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks'; import type { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
import { logger } from 'app/logging/logger'; import type { JSONObject } from 'common/types';
import type { BoardId } from 'features/gallery/store/types'; import type { BoardId } from 'features/gallery/store/types';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types';
import type { CoreMetadata } from 'features/nodes/types/metadata';
import { zCoreMetadata } from 'features/nodes/types/metadata';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { keyBy } from 'lodash-es'; import { keyBy } from 'lodash-es';
@ -118,22 +116,9 @@ export const imagesApi = api.injectEndpoints({
providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }], providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
getImageMetadata: build.query<CoreMetadata | undefined, string>({ getImageMetadata: build.query<JSONObject | undefined, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }), query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }),
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }], providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
transformResponse: (
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
) => {
if (response) {
const result = zCoreMetadata.safeParse(response);
if (result.success) {
return result.data;
} else {
logger('images').warn('Problem parsing metadata');
}
}
return;
},
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
getImageWorkflow: build.query< getImageWorkflow: build.query<

View File

@ -70,6 +70,8 @@ export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query']; type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
export type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
@ -223,6 +225,18 @@ export const modelsApi = api.injectEndpoints({
return tags; return tags;
}, },
}), }),
getModelConfigByAttrs: build.query<AnyModelConfig, GetByAttrsArg>({
query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`),
providesTags: (result) => {
const tags: ApiTagDescription[] = ['Model'];
if (result) {
tags.push({ type: 'ModelConfig', id: result.key });
}
return tags;
},
}),
syncModels: build.mutation<void, void>({ syncModels: build.mutation<void, void>({
query: () => { query: () => {
return { return {
@ -300,6 +314,7 @@ export const modelsApi = api.injectEndpoints({
}); });
export const { export const {
useGetModelConfigByAttrsQuery,
useGetModelConfigQuery, useGetModelConfigQuery,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,