mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refactor metadata handling (again)
Add concepts for metadata handlers. Handlers include parsers, recallers and validators for different metadata types: - Parsers parse a raw metadata object of any shape to a structured object. - Recallers load the parsed metadata into state. Recallers are optional, as some metadata types don't need to be loaded into state. - Validators provide an additional layer of validation before recalling the metadata. This is needed because a metadata object may be valid, but not able to be recalled due to some other requirement, like base model compatibility. Validators are optional. Sometimes metadata is not a single object but a list of items - like LoRAs. Metadata handlers may implement an optional set of "item" handlers which operate on individual items in the list. Parsers and validators are async to allow fetching additional data, like a model config. Recallers are synchronous. The these handlers are composed into a public API, exported as a `handlers` object. Besides the handlers functions, a metadata handler set includes: - A function to get the label of the metadata type. - An optional function to render the value of the metadata type. - An optional function to render the _item_ value of the metadata type.
This commit is contained in:
parent
90327cb521
commit
d1f4cde8c7
@ -656,6 +656,7 @@
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"allPrompts": "All Prompts",
|
||||
"cfgScale": "CFG scale",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"createdBy": "Created By",
|
||||
@ -664,6 +665,7 @@
|
||||
"height": "Height",
|
||||
"hiresFix": "High Resolution Optimization",
|
||||
"imageDetails": "Image Details",
|
||||
"imageDimensions": "Image Dimensions",
|
||||
"initImage": "Initial image",
|
||||
"metadata": "Metadata",
|
||||
"model": "Model",
|
||||
@ -671,9 +673,11 @@
|
||||
"noImageDetails": "No image details found",
|
||||
"noMetaData": "No metadata found",
|
||||
"noRecallParameters": "No parameters to recall found",
|
||||
"parameterSet": "Parameter {{parameter}} set",
|
||||
"perlin": "Perlin Noise",
|
||||
"positivePrompt": "Positive Prompt",
|
||||
"recallParameters": "Recall Parameters",
|
||||
"recallParameter": "Recall {{label}}",
|
||||
"scheduler": "Scheduler",
|
||||
"seamless": "Seamless",
|
||||
"seed": "Seed",
|
||||
@ -1381,8 +1385,8 @@
|
||||
"nodesNotValidJSON": "Not a valid JSON",
|
||||
"nodesSaved": "Nodes Saved",
|
||||
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
|
||||
"parameterNotSet": "Parameter not set",
|
||||
"parameterSet": "Parameter set",
|
||||
"parameterNotSet": "{{parameter}} not set",
|
||||
"parameterSet": "{{parameter}} set",
|
||||
"parametersFailed": "Problem loading parameters",
|
||||
"parametersFailedDesc": "Unable to load init image.",
|
||||
"parametersNotSet": "Parameters Not Set",
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { toast } from 'common/util/toast';
|
||||
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||
import { t } from 'i18next';
|
||||
import { truncate, upperFirst } from 'lodash-es';
|
||||
@ -8,11 +8,6 @@ import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const { toast } = createStandaloneToast({
|
||||
theme: theme,
|
||||
defaultOptions: TOAST_OPTIONS.defaultOptions,
|
||||
});
|
||||
|
||||
export const addBatchEnqueuedListener = () => {
|
||||
// success
|
||||
startAppListening({
|
||||
|
@ -1,7 +1,8 @@
|
||||
import type { UseToastOptions } from '@invoke-ai/ui-library';
|
||||
import { createStandaloneToast, ExternalLink, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import { ExternalLink } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { startAppListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { toast } from 'common/util/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import {
|
||||
@ -12,11 +13,6 @@ import {
|
||||
|
||||
const log = logger('images');
|
||||
|
||||
const { toast } = createStandaloneToast({
|
||||
theme: theme,
|
||||
defaultOptions: TOAST_OPTIONS.defaultOptions,
|
||||
});
|
||||
|
||||
export const addBulkDownloadListeners = () => {
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
|
||||
|
6
invokeai/frontend/web/src/common/util/toast.ts
Normal file
6
invokeai/frontend/web/src/common/util/toast.ts
Normal file
@ -0,0 +1,6 @@
|
||||
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
|
||||
export const { toast } = createStandaloneToast({
|
||||
theme: theme,
|
||||
defaultOptions: TOAST_OPTIONS.defaultOptions,
|
||||
});
|
@ -1,300 +1,54 @@
|
||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ControlNetMetadataItem,
|
||||
CoreMetadata,
|
||||
IPAdapterMetadataItem,
|
||||
LoRAMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/metadata';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem';
|
||||
import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets';
|
||||
import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters';
|
||||
import { MetadataItem } from 'features/metadata/components/MetadataItem';
|
||||
import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs';
|
||||
import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata?: CoreMetadata;
|
||||
metadata?: unknown;
|
||||
};
|
||||
|
||||
const ImageMetadataActions = (props: Props) => {
|
||||
const { metadata } = props;
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
recallPositivePrompt,
|
||||
recallNegativePrompt,
|
||||
recallSeed,
|
||||
recallCfgScale,
|
||||
recallCfgRescaleMultiplier,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallVaeModel,
|
||||
recallSteps,
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallHrfEnabled,
|
||||
recallHrfStrength,
|
||||
recallHrfMethod,
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
recallIPAdapter,
|
||||
recallT2IAdapter,
|
||||
recallSDXLPositiveStylePrompt,
|
||||
recallSDXLNegativeStylePrompt,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
recallPositivePrompt(metadata?.positive_prompt);
|
||||
}, [metadata?.positive_prompt, recallPositivePrompt]);
|
||||
|
||||
const handleRecallNegativePrompt = useCallback(() => {
|
||||
recallNegativePrompt(metadata?.negative_prompt);
|
||||
}, [metadata?.negative_prompt, recallNegativePrompt]);
|
||||
|
||||
const handleRecallSDXLPositiveStylePrompt = useCallback(() => {
|
||||
recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt);
|
||||
}, [metadata?.positive_style_prompt, recallSDXLPositiveStylePrompt]);
|
||||
|
||||
const handleRecallSDXLNegativeStylePrompt = useCallback(() => {
|
||||
recallSDXLNegativeStylePrompt(metadata?.negative__style_prompt);
|
||||
}, [metadata?.negative__style_prompt, recallSDXLNegativeStylePrompt]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(metadata?.seed);
|
||||
}, [metadata?.seed, recallSeed]);
|
||||
|
||||
const handleRecallModel = useCallback(() => {
|
||||
recallModel(metadata?.model);
|
||||
}, [metadata?.model, recallModel]);
|
||||
|
||||
const handleRecallWidth = useCallback(() => {
|
||||
recallWidth(metadata?.width);
|
||||
}, [metadata?.width, recallWidth]);
|
||||
|
||||
const handleRecallHeight = useCallback(() => {
|
||||
recallHeight(metadata?.height);
|
||||
}, [metadata?.height, recallHeight]);
|
||||
|
||||
const handleRecallScheduler = useCallback(() => {
|
||||
recallScheduler(metadata?.scheduler);
|
||||
}, [metadata?.scheduler, recallScheduler]);
|
||||
|
||||
const handleRecallVaeModel = useCallback(() => {
|
||||
recallVaeModel(metadata?.vae);
|
||||
}, [metadata?.vae, recallVaeModel]);
|
||||
|
||||
const handleRecallSteps = useCallback(() => {
|
||||
recallSteps(metadata?.steps);
|
||||
}, [metadata?.steps, recallSteps]);
|
||||
|
||||
const handleRecallCfgScale = useCallback(() => {
|
||||
recallCfgScale(metadata?.cfg_scale);
|
||||
}, [metadata?.cfg_scale, recallCfgScale]);
|
||||
|
||||
const handleRecallCfgRescaleMultiplier = useCallback(() => {
|
||||
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
|
||||
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
|
||||
|
||||
const handleRecallStrength = useCallback(() => {
|
||||
recallStrength(metadata?.strength);
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
|
||||
const handleRecallHrfEnabled = useCallback(() => {
|
||||
recallHrfEnabled(metadata?.hrf_enabled);
|
||||
}, [metadata?.hrf_enabled, recallHrfEnabled]);
|
||||
|
||||
const handleRecallHrfStrength = useCallback(() => {
|
||||
recallHrfStrength(metadata?.hrf_strength);
|
||||
}, [metadata?.hrf_strength, recallHrfStrength]);
|
||||
|
||||
const handleRecallHrfMethod = useCallback(() => {
|
||||
recallHrfMethod(metadata?.hrf_method);
|
||||
}, [metadata?.hrf_method, recallHrfMethod]);
|
||||
|
||||
const handleRecallLoRA = useCallback(
|
||||
(lora: LoRAMetadataItem) => {
|
||||
recallLoRA(lora);
|
||||
},
|
||||
[recallLoRA]
|
||||
);
|
||||
|
||||
const handleRecallControlNet = useCallback(
|
||||
(controlnet: ControlNetMetadataItem) => {
|
||||
recallControlNet(controlnet);
|
||||
},
|
||||
[recallControlNet]
|
||||
);
|
||||
|
||||
const handleRecallIPAdapter = useCallback(
|
||||
(ipAdapter: IPAdapterMetadataItem) => {
|
||||
recallIPAdapter(ipAdapter);
|
||||
},
|
||||
[recallIPAdapter]
|
||||
);
|
||||
|
||||
const handleRecallT2IAdapter = useCallback(
|
||||
(ipAdapter: T2IAdapterMetadataItem) => {
|
||||
recallT2IAdapter(ipAdapter);
|
||||
},
|
||||
[recallT2IAdapter]
|
||||
);
|
||||
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model))
|
||||
: [];
|
||||
}, [metadata?.controlnets]);
|
||||
|
||||
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.ipAdapters
|
||||
? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model))
|
||||
: [];
|
||||
}, [metadata?.ipAdapters]);
|
||||
|
||||
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.t2iAdapters
|
||||
? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model))
|
||||
: [];
|
||||
}, [metadata?.t2iAdapters]);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{metadata.created_by && <ImageMetadataItem label={t('metadata.createdBy')} value={metadata.created_by} />}
|
||||
{metadata.generation_mode && (
|
||||
<ImageMetadataItem label={t('metadata.generationMode')} value={metadata.generation_mode} />
|
||||
)}
|
||||
{metadata.positive_prompt && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.positivePrompt')}
|
||||
labelPosition="top"
|
||||
value={metadata.positive_prompt}
|
||||
onClick={handleRecallPositivePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.negative_prompt && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.negativePrompt')}
|
||||
labelPosition="top"
|
||||
value={metadata.negative_prompt}
|
||||
onClick={handleRecallNegativePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.positive_style_prompt && (
|
||||
<ImageMetadataItem
|
||||
label={t('sdxl.posStylePrompt')}
|
||||
labelPosition="top"
|
||||
value={metadata.positive_style_prompt}
|
||||
onClick={handleRecallSDXLPositiveStylePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.negative_style_prompt && (
|
||||
<ImageMetadataItem
|
||||
label={t('sdxl.negStylePrompt')}
|
||||
labelPosition="top"
|
||||
value={metadata.negative_style_prompt}
|
||||
onClick={handleRecallSDXLNegativeStylePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.seed !== undefined && metadata.seed !== null && (
|
||||
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
|
||||
)}
|
||||
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
|
||||
<ModelMetadataItem label={t('metadata.model')} modelKey={metadata.model.key} onClick={handleRecallModel} />
|
||||
)}
|
||||
{metadata.width && (
|
||||
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
|
||||
)}
|
||||
{metadata.height && (
|
||||
<ImageMetadataItem label={t('metadata.height')} value={metadata.height} onClick={handleRecallHeight} />
|
||||
)}
|
||||
{metadata.scheduler && (
|
||||
<ImageMetadataItem label={t('metadata.scheduler')} value={metadata.scheduler} onClick={handleRecallScheduler} />
|
||||
)}
|
||||
<VAEMetadataItem label={t('metadata.vae')} modelKey={metadata.vae?.key} onClick={handleRecallVaeModel} />
|
||||
{metadata.steps && (
|
||||
<ImageMetadataItem label={t('metadata.steps')} value={metadata.steps} onClick={handleRecallSteps} />
|
||||
)}
|
||||
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
|
||||
<ImageMetadataItem label={t('metadata.cfgScale')} value={metadata.cfg_scale} onClick={handleRecallCfgScale} />
|
||||
)}
|
||||
{metadata.cfg_rescale_multiplier !== undefined && metadata.cfg_rescale_multiplier !== null && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.cfgRescaleMultiplier')}
|
||||
value={metadata.cfg_rescale_multiplier}
|
||||
onClick={handleRecallCfgRescaleMultiplier}
|
||||
/>
|
||||
)}
|
||||
{metadata.strength && (
|
||||
<ImageMetadataItem label={t('metadata.strength')} value={metadata.strength} onClick={handleRecallStrength} />
|
||||
)}
|
||||
{metadata.hrf_enabled && (
|
||||
<ImageMetadataItem
|
||||
label={t('hrf.metadata.enabled')}
|
||||
value={metadata.hrf_enabled}
|
||||
onClick={handleRecallHrfEnabled}
|
||||
/>
|
||||
)}
|
||||
{metadata.hrf_enabled && metadata.hrf_strength && (
|
||||
<ImageMetadataItem
|
||||
label={t('hrf.metadata.strength')}
|
||||
value={metadata.hrf_strength}
|
||||
onClick={handleRecallHrfStrength}
|
||||
/>
|
||||
)}
|
||||
{metadata.hrf_enabled && metadata.hrf_method && (
|
||||
<ImageMetadataItem
|
||||
label={t('hrf.metadata.method')}
|
||||
value={metadata.hrf_method}
|
||||
onClick={handleRecallHrfMethod}
|
||||
/>
|
||||
)}
|
||||
{metadata.loras &&
|
||||
metadata.loras.map((lora, index) => {
|
||||
if (isModelIdentifier(lora.lora)) {
|
||||
return (
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="LoRA"
|
||||
modelKey={lora.lora.key}
|
||||
extra={` - ${lora.weight}`}
|
||||
onClick={handleRecallLoRA.bind(null, lora)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
})}
|
||||
{validControlNets.map((controlnet, index) => (
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="ControlNet"
|
||||
modelKey={controlnet.control_model?.key}
|
||||
extra={` - ${controlnet.control_weight}`}
|
||||
onClick={handleRecallControlNet.bind(null, controlnet)}
|
||||
/>
|
||||
))}
|
||||
{validIPAdapters.map((ipAdapter, index) => (
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="IP Adapter"
|
||||
modelKey={ipAdapter.ip_adapter_model?.key}
|
||||
extra={` - ${ipAdapter.weight}`}
|
||||
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
|
||||
/>
|
||||
))}
|
||||
{validT2IAdapters.map((t2iAdapter, index) => (
|
||||
<ModelMetadataItem
|
||||
key={index}
|
||||
label="T2I Adapter"
|
||||
modelKey={t2iAdapter.t2i_adapter_model?.key}
|
||||
extra={` - ${t2iAdapter.weight}`}
|
||||
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
|
||||
/>
|
||||
))}
|
||||
<MetadataItem metadata={metadata} handlers={handlers.createdBy} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.generationMode} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.positivePrompt} direction="column" />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.negativePrompt} direction="column" />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.sdxlPositiveStylePrompt} direction="column" />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.sdxlNegativeStylePrompt} direction="column" />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.model} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.vae} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.width} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.height} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.seed} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.steps} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.scheduler} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.strength} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.hrfStrength} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerCFGScale} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerModel} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerNegativeAestheticScore} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerPositiveAestheticScore} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerScheduler} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerStart} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerSteps} />
|
||||
<MetadataLoRAs metadata={metadata} />
|
||||
<MetadataControlNets metadata={metadata} />
|
||||
<MetadataT2IAdapters metadata={metadata} />
|
||||
<MetadataIPAdapters metadata={metadata} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1,28 +1,39 @@
|
||||
import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { get } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { PiCopyBold } from 'react-icons/pi';
|
||||
import { PiArrowBendUpLeftBold } from 'react-icons/pi';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
label: string;
|
||||
onClick?: () => void;
|
||||
value: number | string | boolean;
|
||||
metadata: unknown;
|
||||
propertyName: string;
|
||||
onRecall?: (value: unknown) => void;
|
||||
labelPosition?: string;
|
||||
withCopy?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Component to display an individual metadata item or parameter.
|
||||
*/
|
||||
const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => {
|
||||
const ImageMetadataItem = ({
|
||||
label,
|
||||
metadata,
|
||||
propertyName,
|
||||
onRecall: _onRecall,
|
||||
isLink,
|
||||
labelPosition,
|
||||
}: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const handleCopy = useCallback(() => {
|
||||
navigator.clipboard.writeText(value?.toString());
|
||||
}, [value]);
|
||||
const value = useMemo(() => get(metadata, propertyName), [metadata, propertyName]);
|
||||
const onRecall = useCallback(() => {
|
||||
if (!_onRecall) {
|
||||
return;
|
||||
}
|
||||
_onRecall(value);
|
||||
}, [_onRecall, value]);
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
@ -30,27 +41,15 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onClick && (
|
||||
<Tooltip label={`Recall ${label}`}>
|
||||
{_onRecall && (
|
||||
<Tooltip label={t('metadata.recallParameter', { parameter: label })}>
|
||||
<IconButton
|
||||
aria-label={t('accessibility.useThisParameter')}
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
aria-label={t('metadata.recallParameter', { parameter: label })}
|
||||
icon={<PiArrowBendUpLeftBold />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={20}
|
||||
onClick={onClick}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
{withCopy && (
|
||||
<Tooltip label={`Copy ${label}`}>
|
||||
<IconButton
|
||||
aria-label={`Copy ${label}`}
|
||||
icon={<PiCopyBold />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={handleCopy}
|
||||
onClick={onRecall}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
@ -24,29 +24,29 @@ type LoRACardProps = {
|
||||
export const LoRACard = memo((props: LoRACardProps) => {
|
||||
const { lora } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: loraConfig } = useGetModelConfigQuery(lora.key);
|
||||
const { data: loraConfig } = useGetModelConfigQuery(lora.model.key);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(loraWeightChanged({ key: lora.key, weight: v }));
|
||||
dispatch(loraWeightChanged({ key: lora.model.key, weight: v }));
|
||||
},
|
||||
[dispatch, lora.key]
|
||||
[dispatch, lora.model.key]
|
||||
);
|
||||
|
||||
const handleSetLoraToggle = useCallback(() => {
|
||||
dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.key, lora.isEnabled]);
|
||||
dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.model.key, lora.isEnabled]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
dispatch(loraRemoved(lora.key));
|
||||
}, [dispatch, lora.key]);
|
||||
dispatch(loraRemoved(lora.model.key));
|
||||
}, [dispatch, lora.model.key]);
|
||||
|
||||
return (
|
||||
<Card variant="lora">
|
||||
<CardHeader>
|
||||
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
||||
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
||||
{loraConfig?.name ?? lora.key.substring(0, 8)}
|
||||
{loraConfig?.name ?? lora.model.key.substring(0, 8)}
|
||||
</Text>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
||||
|
@ -2,9 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
export type LoRA = ParameterLoRAModel & {
|
||||
export type LoRA = {
|
||||
model: ParameterLoRAModel;
|
||||
weight: number;
|
||||
isEnabled?: boolean;
|
||||
};
|
||||
@ -29,11 +31,11 @@ export const loraSlice = createSlice({
|
||||
initialState: initialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||
const { key, base } = action.payload;
|
||||
state.loras[key] = { key, base, ...defaultLoRAConfig };
|
||||
const model = getModelKeyAndBase(action.payload);
|
||||
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||
state.loras[action.payload.key] = action.payload;
|
||||
state.loras[action.payload.model.key] = action.payload;
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const key = action.payload;
|
||||
@ -58,7 +60,7 @@ export const loraSlice = createSlice({
|
||||
}
|
||||
lora.weight = defaultLoRAConfig.weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'key' | 'isEnabled'>>) => {
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => {
|
||||
const { key, isEnabled } = action.payload;
|
||||
const lora = state.loras[key];
|
||||
if (!lora) {
|
||||
|
@ -0,0 +1,66 @@
|
||||
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataControlNets = ({ metadata }: Props) => {
|
||||
const [controlNets, setControlNets] = useState<ControlNetConfig[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.controlNets.parse(metadata);
|
||||
setControlNets(parsed);
|
||||
} catch (e) {
|
||||
setControlNets([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.controlNets.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{controlNets.map((controlNet) => (
|
||||
<MetadataViewControlNet
|
||||
key={controlNet.model.key}
|
||||
label={label}
|
||||
controlNet={controlNet}
|
||||
handlers={handlers.controlNets}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewControlNet = ({
|
||||
label,
|
||||
controlNet,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
controlNet: ControlNetConfig;
|
||||
handlers: MetadataHandlers<ControlNetConfig[], ControlNetConfig>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(controlNet, true);
|
||||
}, [handlers, controlNet]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
}
|
||||
return handlers.renderItemValue(controlNet);
|
||||
}, [handlers, controlNet]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -0,0 +1,66 @@
|
||||
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataIPAdapters = ({ metadata }: Props) => {
|
||||
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfig[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.ipAdapters.parse(metadata);
|
||||
setIPAdapters(parsed);
|
||||
} catch (e) {
|
||||
setIPAdapters([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.ipAdapters.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{ipAdapters.map((ipAdapter) => (
|
||||
<MetadataViewIPAdapter
|
||||
key={ipAdapter.model.key}
|
||||
label={label}
|
||||
ipAdapter={ipAdapter}
|
||||
handlers={handlers.ipAdapters}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewIPAdapter = ({
|
||||
label,
|
||||
ipAdapter,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
ipAdapter: IPAdapterConfig;
|
||||
handlers: MetadataHandlers<IPAdapterConfig[], IPAdapterConfig>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(ipAdapter, true);
|
||||
}, [handlers, ipAdapter]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
}
|
||||
return handlers.renderItemValue(ipAdapter);
|
||||
}, [handlers, ipAdapter]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -0,0 +1,33 @@
|
||||
import { typedMemo } from '@invoke-ai/ui-library';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { MetadataParseFailedToken } from 'features/metadata/util/parsers';
|
||||
|
||||
type MetadataItemProps<T> = {
|
||||
metadata: unknown;
|
||||
handlers: MetadataHandlers<T>;
|
||||
direction?: 'row' | 'column';
|
||||
};
|
||||
|
||||
const _MetadataItem = typedMemo(<T,>({ metadata, handlers, direction = 'row' }: MetadataItemProps<T>) => {
|
||||
const { label, isDisabled, value, renderedValue, onRecall } = useMetadataItem(metadata, handlers);
|
||||
|
||||
if (value === MetadataParseFailedToken) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<MetadataItemView
|
||||
label={label}
|
||||
onRecall={onRecall}
|
||||
isDisabled={isDisabled}
|
||||
renderedValue={renderedValue}
|
||||
direction={direction}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
export const MetadataItem = typedMemo(_MetadataItem);
|
||||
|
||||
MetadataItem.displayName = 'MetadataItem';
|
@ -0,0 +1,29 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { RecallButton } from 'features/metadata/components/RecallButton';
|
||||
import { memo } from 'react';
|
||||
|
||||
type MetadataItemViewProps = {
|
||||
onRecall: () => void;
|
||||
label: string;
|
||||
renderedValue: React.ReactNode;
|
||||
isDisabled: boolean;
|
||||
direction?: 'row' | 'column';
|
||||
};
|
||||
|
||||
export const MetadataItemView = memo(
|
||||
({ label, onRecall, isDisabled, renderedValue, direction = 'row' }: MetadataItemViewProps) => {
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
|
||||
<Flex direction={direction}>
|
||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||
{label}:
|
||||
</Text>
|
||||
{renderedValue}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
MetadataItemView.displayName = 'MetadataItemView';
|
@ -0,0 +1,61 @@
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataLoRAs = ({ metadata }: Props) => {
|
||||
const [loras, setLoRAs] = useState<LoRA[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.loras.parse(metadata);
|
||||
setLoRAs(parsed);
|
||||
} catch (e) {
|
||||
setLoRAs([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.loras.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{loras.map((lora) => (
|
||||
<MetadataViewLoRA key={lora.model.key} label={label} lora={lora} handlers={handlers.loras} />
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewLoRA = ({
|
||||
label,
|
||||
lora,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
lora: LoRA;
|
||||
handlers: MetadataHandlers<LoRA[], LoRA>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(lora, true);
|
||||
}, [handlers, lora]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
}
|
||||
return handlers.renderItemValue(lora);
|
||||
}, [handlers, lora]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -0,0 +1,66 @@
|
||||
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataT2IAdapters = ({ metadata }: Props) => {
|
||||
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfig[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.t2iAdapters.parse(metadata);
|
||||
setT2IAdapters(parsed);
|
||||
} catch (e) {
|
||||
setT2IAdapters([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.t2iAdapters.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{t2iAdapters.map((t2iAdapter) => (
|
||||
<MetadataViewT2IAdapter
|
||||
key={t2iAdapter.model.key}
|
||||
label={label}
|
||||
t2iAdapter={t2iAdapter}
|
||||
handlers={handlers.t2iAdapters}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewT2IAdapter = ({
|
||||
label,
|
||||
t2iAdapter,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
t2iAdapter: T2IAdapterConfig;
|
||||
handlers: MetadataHandlers<T2IAdapterConfig[], T2IAdapterConfig>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(t2iAdapter, true);
|
||||
}, [handlers, t2iAdapter]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
if (!handlers.renderItemValue) {
|
||||
return null;
|
||||
}
|
||||
return handlers.renderItemValue(t2iAdapter);
|
||||
}, [handlers, t2iAdapter]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -0,0 +1,27 @@
|
||||
import type { IconButtonProps} from '@invoke-ai/ui-library';
|
||||
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowBendUpLeftBold } from 'react-icons/pi';
|
||||
|
||||
type MetadataItemProps = Omit<IconButtonProps, 'aria-label'> & {
|
||||
label: string;
|
||||
};
|
||||
|
||||
export const RecallButton = memo(({ label, ...rest }: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Tooltip label={t('metadata.recallParameter', { label })}>
|
||||
<IconButton
|
||||
aria-label={t('metadata.recallParameter', { label })}
|
||||
icon={<PiArrowBendUpLeftBold />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
{...rest}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
RecallButton.displayName = 'RecallButton';
|
27
invokeai/frontend/web/src/features/metadata/exceptions.ts
Normal file
27
invokeai/frontend/web/src/features/metadata/exceptions.ts
Normal file
@ -0,0 +1,27 @@
|
||||
/**
|
||||
* Raised when metadata parsing fails.
|
||||
*/
|
||||
export class MetadataParseError extends Error {
|
||||
/**
|
||||
* Create MetadataParseError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Raised when metadata recall fails.
|
||||
*/
|
||||
export class MetadataRecallError extends Error {
|
||||
/**
|
||||
* Create MetadataRecallError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandlers<T>) => {
|
||||
const [value, setValue] = useState<T | typeof MetadataParsePendingToken | typeof MetadataParseFailedToken>(
|
||||
MetadataParsePendingToken
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const _parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.parse(metadata);
|
||||
setValue(parsed);
|
||||
} catch (e) {
|
||||
setValue(MetadataParseFailedToken);
|
||||
}
|
||||
};
|
||||
_parse();
|
||||
}, [handlers, metadata]);
|
||||
|
||||
const isDisabled = useMemo(() => value === MetadataParsePendingToken || value === MetadataParseFailedToken, [value]);
|
||||
|
||||
const label = useMemo(() => handlers.getLabel(), [handlers]);
|
||||
|
||||
const renderedValue = useMemo(() => {
|
||||
if (value === MetadataParsePendingToken) {
|
||||
return <Text>Loading</Text>;
|
||||
}
|
||||
if (value === MetadataParseFailedToken) {
|
||||
return <Text>Parsing Failed</Text>;
|
||||
}
|
||||
|
||||
const rendered = handlers.renderValue(value);
|
||||
|
||||
if (typeof rendered === 'string') {
|
||||
return <Text>{rendered}</Text>;
|
||||
}
|
||||
return rendered;
|
||||
}, [handlers, value]);
|
||||
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recall || value === MetadataParsePendingToken || value === MetadataParseFailedToken) {
|
||||
return null;
|
||||
}
|
||||
handlers.recall(value, true);
|
||||
}, [handlers, value]);
|
||||
|
||||
return { label, isDisabled, value, renderedValue, onRecall };
|
||||
};
|
136
invokeai/frontend/web/src/features/metadata/types.ts
Normal file
136
invokeai/frontend/web/src/features/metadata/types.ts
Normal file
@ -0,0 +1,136 @@
|
||||
/**
|
||||
* Renders a value of type T as a React node.
|
||||
*/
|
||||
export type MetadataRenderValueFunc<T> = (value: T) => React.ReactNode;
|
||||
|
||||
/**
|
||||
* Gets the label of the current metadata item as a string.
|
||||
*/
|
||||
export type MetadataGetLabelFunc = () => string;
|
||||
|
||||
export type MetadataParseOptions = {
|
||||
toastOnFailure?: boolean;
|
||||
toastOnSuccess?: boolean;
|
||||
};
|
||||
|
||||
export type MetadataRecallOptions = MetadataParseOptions;
|
||||
|
||||
/**
|
||||
* A function that recalls a parsed and validated metadata value.
|
||||
*
|
||||
* @param value The value to recall.
|
||||
* @throws MetadataRecallError if the value cannot be recalled.
|
||||
*/
|
||||
export type MetadataRecallFunc<T> = (value: T) => void;
|
||||
|
||||
/**
|
||||
* An async function that receives metadata and returns a parsed value, throwing if the value is invalid or missing.
|
||||
*
|
||||
* The function receives an object of unknown type. It is responsible for extracting the relevant data from the metadata
|
||||
* and returning a value of type T.
|
||||
*
|
||||
* The function should throw a MetadataParseError if the metadata is invalid or missing.
|
||||
*
|
||||
* @param metadata The metadata to parse.
|
||||
* @returns A promise that resolves to the parsed value.
|
||||
* @throws MetadataParseError if the metadata is invalid or missing.
|
||||
*/
|
||||
export type MetadataParseFunc<T = unknown> = (metadata: unknown) => Promise<T>;
|
||||
|
||||
/**
|
||||
* A function that performs additional validation logic before recalling a metadata value. It is called with a parsed
|
||||
* value and should throw if the validation logic fails.
|
||||
*
|
||||
* This function is used in cases where some additional logic is required before recalling. For example, when recalling
|
||||
* a LoRA, we need to check if it is compatible with the current base model.
|
||||
*
|
||||
* @param value The value to validate.
|
||||
* @returns A promise that resolves to the validated value.
|
||||
* @throws MetadataRecallError if the value is invalid.
|
||||
*/
|
||||
export type MetadataValidateFunc<T> = (value: T) => Promise<T>;
|
||||
|
||||
export type MetadataHandlers<TValue = unknown, TItem = unknown> = {
|
||||
/**
|
||||
* Gets the label of the current metadata item as a string.
|
||||
*
|
||||
* @returns The label of the current metadata item.
|
||||
*/
|
||||
getLabel: MetadataGetLabelFunc;
|
||||
/**
|
||||
* An async function that receives metadata and returns a parsed metadata value.
|
||||
*
|
||||
* @param metadata The metadata to parse.
|
||||
* @param withToast Whether to show a toast on success or failure.
|
||||
* @returns A promise that resolves to the parsed value.
|
||||
* @throws MetadataParseError if the metadata is invalid or missing.
|
||||
*/
|
||||
parse: (metadata: unknown, withToast?: boolean) => Promise<TValue>;
|
||||
/**
|
||||
* An async function that receives a metadata item and returns a parsed metadata item value.
|
||||
*
|
||||
* This is only provided if the metadata value is an array.
|
||||
*
|
||||
* @param item The item to parse. It should be an item from the array.
|
||||
* @param withToast Whether to show a toast on success or failure.
|
||||
* @returns A promise that resolves to the parsed value.
|
||||
* @throws MetadataParseError if the metadata is invalid or missing.
|
||||
*/
|
||||
parseItem?: (item: unknown, withToast?: boolean) => Promise<TItem>;
|
||||
/**
|
||||
* An async function that recalls a parsed metadata value.
|
||||
*
|
||||
* This function is only provided if the metadata value can be recalled.
|
||||
*
|
||||
* @param value The value to recall.
|
||||
* @param withToast Whether to show a toast on success or failure.
|
||||
* @returns A promise that resolves when the recall operation is complete.
|
||||
* @throws MetadataRecallError if the value cannot be recalled.
|
||||
*/
|
||||
recall?: (value: TValue, withToast?: boolean) => Promise<void>;
|
||||
/**
|
||||
* An async function that recalls a parsed metadata item value.
|
||||
*
|
||||
* This function is only provided if the metadata value is an array and the items can be recalled.
|
||||
*
|
||||
* @param item The item to recall. It should be an item from the array.
|
||||
* @param withToast Whether to show a toast on success or failure.
|
||||
* @returns A promise that resolves when the recall operation is complete.
|
||||
* @throws MetadataRecallError if the value cannot be recalled.
|
||||
*/
|
||||
recallItem?: (item: TItem, withToast?: boolean) => Promise<void>;
|
||||
/**
|
||||
* Renders a parsed metadata value as a React node.
|
||||
*
|
||||
* @param value The value to render.
|
||||
* @returns The rendered value.
|
||||
*/
|
||||
renderValue: MetadataRenderValueFunc<TValue>;
|
||||
/**
|
||||
* Renders a parsed metadata item value as a React node.
|
||||
*
|
||||
* @param item The item to render.
|
||||
* @returns The rendered item.
|
||||
*/
|
||||
renderItemValue?: MetadataRenderValueFunc<TItem>;
|
||||
};
|
||||
|
||||
// TODO(psyche): The types for item handlers should be able to be inferred from the type of the value:
|
||||
// type MetadataHandlersInferItem<TValue> = TValue extends Array<infer TItem> ? MetadataParseFunc<TItem> : never
|
||||
// While this works for the types as expected, I couldn't satisfy TS in the implementations of the handlers.
|
||||
|
||||
export type BuildMetadataHandlersArg<TValue, TItem> = {
|
||||
parser: MetadataParseFunc<TValue>;
|
||||
itemParser?: MetadataParseFunc<TItem>;
|
||||
recaller?: MetadataRecallFunc<TValue>;
|
||||
itemRecaller?: MetadataRecallFunc<TItem>;
|
||||
validator?: MetadataValidateFunc<TValue>;
|
||||
itemValidator?: MetadataValidateFunc<TItem>;
|
||||
getLabel: MetadataGetLabelFunc;
|
||||
renderValue?: MetadataRenderValueFunc<TValue>;
|
||||
renderItemValue?: MetadataRenderValueFunc<TItem>;
|
||||
};
|
||||
|
||||
export type BuildMetadataHandlers = <TValue, TItem>(
|
||||
arg: BuildMetadataHandlersArg<TValue, TItem>
|
||||
) => MetadataHandlers<TValue, TItem>;
|
316
invokeai/frontend/web/src/features/metadata/util/handlers.tsx
Normal file
316
invokeai/frontend/web/src/features/metadata/util/handlers.tsx
Normal file
@ -0,0 +1,316 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { toast } from 'common/util/toast';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
BuildMetadataHandlers,
|
||||
MetadataGetLabelFunc,
|
||||
MetadataHandlers,
|
||||
MetadataParseFunc,
|
||||
MetadataRecallFunc,
|
||||
MetadataRenderValueFunc,
|
||||
MetadataValidateFunc,
|
||||
} from 'features/metadata/types';
|
||||
import { validators } from 'features/metadata/util/validators';
|
||||
import { t } from 'i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { parsers } from './parsers';
|
||||
import { recallers } from './recallers';
|
||||
|
||||
const renderModelConfigValue: MetadataRenderValueFunc<AnyModelConfig> = (value) =>
|
||||
`${value.name} (${value.base.toUpperCase()}, ${value.key})`;
|
||||
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = (value) => <Text>{`${value.model.key} (${value.weight})`}</Text>;
|
||||
const renderControlAdapterValue: MetadataRenderValueFunc<ControlNetConfig | T2IAdapterConfig | IPAdapterConfig> = (
|
||||
value
|
||||
) => <Text>{`${value.model?.key} (${value.weight})`}</Text>;
|
||||
|
||||
const parameterSetToast = (parameter: string, description?: string) => {
|
||||
toast({
|
||||
title: t('toast.parameterSet', { parameter }),
|
||||
description,
|
||||
status: 'info',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const parameterNotSetToast = (parameter: string, description?: string) => {
|
||||
toast({
|
||||
title: t('toast.parameterNotSet', { parameter }),
|
||||
description,
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
// const allParameterSetToast = (description?: string) => {
|
||||
// toast({
|
||||
// title: t('toast.parametersSet'),
|
||||
// status: 'info',
|
||||
// description,
|
||||
// duration: 2500,
|
||||
// isClosable: true,
|
||||
// });
|
||||
// };
|
||||
|
||||
// const allParameterNotSetToast = (description?: string) => {
|
||||
// toast({
|
||||
// title: t('toast.parametersNotSet'),
|
||||
// status: 'warning',
|
||||
// description,
|
||||
// duration: 2500,
|
||||
// isClosable: true,
|
||||
// });
|
||||
// };
|
||||
|
||||
const buildParse =
|
||||
<TValue, TItem>(arg: {
|
||||
parser: MetadataParseFunc<TValue>;
|
||||
getLabel: MetadataGetLabelFunc;
|
||||
}): MetadataHandlers<TValue, TItem>['parse'] =>
|
||||
async (metadata, withToast = false) => {
|
||||
try {
|
||||
const parsed = await arg.parser(metadata);
|
||||
withToast && parameterSetToast(arg.getLabel());
|
||||
return parsed;
|
||||
} catch (e) {
|
||||
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
const buildParseItem =
|
||||
<TValue, TItem>(arg: {
|
||||
itemParser: MetadataParseFunc<TItem>;
|
||||
getLabel: MetadataGetLabelFunc;
|
||||
}): MetadataHandlers<TValue, TItem>['parseItem'] =>
|
||||
async (item, withToast = false) => {
|
||||
try {
|
||||
const parsed = await arg.itemParser(item);
|
||||
withToast && parameterSetToast(arg.getLabel());
|
||||
return parsed;
|
||||
} catch (e) {
|
||||
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
const buildRecall =
|
||||
<TValue, TItem>(arg: {
|
||||
recaller: MetadataRecallFunc<TValue>;
|
||||
validator?: MetadataValidateFunc<TValue>;
|
||||
getLabel: MetadataGetLabelFunc;
|
||||
}): NonNullable<MetadataHandlers<TValue, TItem>['recall']> =>
|
||||
async (value, withToast = false) => {
|
||||
try {
|
||||
arg.validator && (await arg.validator(value));
|
||||
await arg.recaller(value);
|
||||
withToast && parameterSetToast(arg.getLabel());
|
||||
} catch (e) {
|
||||
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
const buildRecallItem =
|
||||
<TValue, TItem>(arg: {
|
||||
itemRecaller: MetadataRecallFunc<TItem>;
|
||||
itemValidator?: MetadataValidateFunc<TItem>;
|
||||
getLabel: MetadataGetLabelFunc;
|
||||
}): NonNullable<MetadataHandlers<TValue, TItem>['recallItem']> =>
|
||||
async (item, withToast = false) => {
|
||||
try {
|
||||
arg.itemValidator && (await arg.itemValidator(item));
|
||||
await arg.itemRecaller(item);
|
||||
withToast && parameterSetToast(arg.getLabel());
|
||||
} catch (e) {
|
||||
withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
const buildHandlers: BuildMetadataHandlers = ({
|
||||
getLabel,
|
||||
parser,
|
||||
itemParser,
|
||||
recaller,
|
||||
itemRecaller,
|
||||
validator,
|
||||
itemValidator,
|
||||
renderValue,
|
||||
renderItemValue,
|
||||
}) => ({
|
||||
parse: buildParse({ parser, getLabel }),
|
||||
parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined,
|
||||
recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
|
||||
recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
|
||||
getLabel,
|
||||
renderValue: renderValue ?? String,
|
||||
renderItemValue: renderItemValue ?? String,
|
||||
});
|
||||
|
||||
export const handlers = {
|
||||
// Misc
|
||||
createdBy: buildHandlers({ getLabel: () => t('metadata.createdBy'), parser: parsers.createdBy }),
|
||||
generationMode: buildHandlers({ getLabel: () => t('metadata.generationMode'), parser: parsers.generationMode }),
|
||||
|
||||
// Core parameters
|
||||
cfgRescaleMultiplier: buildHandlers({
|
||||
getLabel: () => t('metadata.cfgRescaleMultiplier'),
|
||||
parser: parsers.cfgRescaleMultiplier,
|
||||
recaller: recallers.cfgRescaleMultiplier,
|
||||
}),
|
||||
cfgScale: buildHandlers({
|
||||
getLabel: () => t('metadata.cfgScale'),
|
||||
parser: parsers.cfgScale,
|
||||
recaller: recallers.cfgScale,
|
||||
}),
|
||||
height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
|
||||
negativePrompt: buildHandlers({
|
||||
getLabel: () => t('metadata.negativePrompt'),
|
||||
parser: parsers.negativePrompt,
|
||||
recaller: recallers.negativePrompt,
|
||||
}),
|
||||
positivePrompt: buildHandlers({
|
||||
getLabel: () => t('metadata.positivePrompt'),
|
||||
parser: parsers.positivePrompt,
|
||||
recaller: recallers.positivePrompt,
|
||||
}),
|
||||
scheduler: buildHandlers({
|
||||
getLabel: () => t('metadata.scheduler'),
|
||||
parser: parsers.scheduler,
|
||||
recaller: recallers.scheduler,
|
||||
}),
|
||||
sdxlNegativeStylePrompt: buildHandlers({
|
||||
getLabel: () => t('sdxl.negStylePrompt'),
|
||||
parser: parsers.sdxlNegativeStylePrompt,
|
||||
recaller: recallers.sdxlNegativeStylePrompt,
|
||||
}),
|
||||
sdxlPositiveStylePrompt: buildHandlers({
|
||||
getLabel: () => t('sdxl.posStylePrompt'),
|
||||
parser: parsers.sdxlPositiveStylePrompt,
|
||||
recaller: recallers.sdxlPositiveStylePrompt,
|
||||
}),
|
||||
seed: buildHandlers({ getLabel: () => t('metadata.seed'), parser: parsers.seed, recaller: recallers.seed }),
|
||||
steps: buildHandlers({ getLabel: () => t('metadata.steps'), parser: parsers.steps, recaller: recallers.steps }),
|
||||
strength: buildHandlers({
|
||||
getLabel: () => t('metadata.strength'),
|
||||
parser: parsers.strength,
|
||||
recaller: recallers.strength,
|
||||
}),
|
||||
width: buildHandlers({ getLabel: () => t('metadata.width'), parser: parsers.width, recaller: recallers.width }),
|
||||
|
||||
// HRF
|
||||
hrfEnabled: buildHandlers({
|
||||
getLabel: () => t('hrf.metadata.enabled'),
|
||||
parser: parsers.hrfEnabled,
|
||||
recaller: recallers.hrfEnabled,
|
||||
}),
|
||||
hrfMethod: buildHandlers({
|
||||
getLabel: () => t('hrf.metadata.method'),
|
||||
parser: parsers.hrfMethod,
|
||||
recaller: recallers.hrfMethod,
|
||||
}),
|
||||
hrfStrength: buildHandlers({
|
||||
getLabel: () => t('hrf.metadata.strength'),
|
||||
parser: parsers.hrfStrength,
|
||||
recaller: recallers.hrfStrength,
|
||||
}),
|
||||
|
||||
// Refiner
|
||||
refinerCFGScale: buildHandlers({
|
||||
getLabel: () => t('sdxl.cfgScale'),
|
||||
parser: parsers.refinerCFGScale,
|
||||
recaller: recallers.refinerCFGScale,
|
||||
}),
|
||||
refinerModel: buildHandlers({
|
||||
getLabel: () => t('sdxl.refinerModel'),
|
||||
parser: parsers.refinerModel,
|
||||
recaller: recallers.refinerModel,
|
||||
validator: validators.refinerModel,
|
||||
}),
|
||||
refinerNegativeAestheticScore: buildHandlers({
|
||||
getLabel: () => t('sdxl.posAestheticScore'),
|
||||
parser: parsers.refinerNegativeAestheticScore,
|
||||
recaller: recallers.refinerNegativeAestheticScore,
|
||||
}),
|
||||
refinerPositiveAestheticScore: buildHandlers({
|
||||
getLabel: () => t('sdxl.negAestheticScore'),
|
||||
parser: parsers.refinerPositiveAestheticScore,
|
||||
recaller: recallers.refinerPositiveAestheticScore,
|
||||
}),
|
||||
refinerScheduler: buildHandlers({
|
||||
getLabel: () => t('sdxl.scheduler'),
|
||||
parser: parsers.refinerScheduler,
|
||||
recaller: recallers.refinerScheduler,
|
||||
}),
|
||||
refinerStart: buildHandlers({
|
||||
getLabel: () => t('sdxl.refiner_start'),
|
||||
parser: parsers.refinerStart,
|
||||
recaller: recallers.refinerStart,
|
||||
}),
|
||||
refinerSteps: buildHandlers({
|
||||
getLabel: () => t('sdxl.refiner_steps'),
|
||||
parser: parsers.refinerSteps,
|
||||
recaller: recallers.refinerSteps,
|
||||
}),
|
||||
|
||||
// Models
|
||||
model: buildHandlers({
|
||||
getLabel: () => t('metadata.model'),
|
||||
parser: parsers.mainModel,
|
||||
recaller: recallers.model,
|
||||
renderValue: renderModelConfigValue,
|
||||
}),
|
||||
vae: buildHandlers({
|
||||
getLabel: () => t('metadata.vae'),
|
||||
parser: parsers.vaeModel,
|
||||
recaller: recallers.vae,
|
||||
renderValue: renderModelConfigValue,
|
||||
validator: validators.vaeModel,
|
||||
}),
|
||||
|
||||
// Arrays of models
|
||||
controlNets: buildHandlers({
|
||||
getLabel: () => t('common.controlNet'),
|
||||
parser: parsers.controlNets,
|
||||
itemParser: parsers.controlNet,
|
||||
recaller: recallers.controlNets,
|
||||
itemRecaller: recallers.controlNet,
|
||||
validator: validators.controlNets,
|
||||
itemValidator: validators.controlNet,
|
||||
renderItemValue: renderControlAdapterValue,
|
||||
}),
|
||||
ipAdapters: buildHandlers({
|
||||
getLabel: () => t('common.ipAdapter'),
|
||||
parser: parsers.ipAdapters,
|
||||
itemParser: parsers.ipAdapter,
|
||||
recaller: recallers.ipAdapters,
|
||||
itemRecaller: recallers.ipAdapter,
|
||||
validator: validators.ipAdapters,
|
||||
itemValidator: validators.ipAdapter,
|
||||
renderItemValue: renderControlAdapterValue,
|
||||
}),
|
||||
loras: buildHandlers({
|
||||
getLabel: () => t('models.lora'),
|
||||
parser: parsers.loras,
|
||||
itemParser: parsers.lora,
|
||||
recaller: recallers.loras,
|
||||
itemRecaller: recallers.lora,
|
||||
validator: validators.loras,
|
||||
itemValidator: validators.lora,
|
||||
renderItemValue: renderLoRAValue,
|
||||
}),
|
||||
t2iAdapters: buildHandlers({
|
||||
getLabel: () => t('common.t2iAdapter'),
|
||||
parser: parsers.t2iAdapters,
|
||||
itemParser: parsers.t2iAdapter,
|
||||
recaller: recallers.t2iAdapters,
|
||||
itemRecaller: recallers.t2iAdapter,
|
||||
validator: validators.t2iAdapters,
|
||||
itemValidator: validators.t2iAdapter,
|
||||
renderItemValue: renderControlAdapterValue,
|
||||
}),
|
||||
} as const;
|
396
invokeai/frontend/web/src/features/metadata/util/parsers.ts
Normal file
396
invokeai/frontend/web/src/features/metadata/util/parsers.ts
Normal file
@ -0,0 +1,396 @@
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
||||
import { MetadataParseError } from 'features/metadata/exceptions';
|
||||
import type { MetadataParseFunc } from 'features/metadata/types';
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zModelIdentifierWithBase,
|
||||
zT2IAdapterField,
|
||||
} from 'features/nodes/types/common';
|
||||
import type {
|
||||
ParameterCFGRescaleMultiplier,
|
||||
ParameterCFGScale,
|
||||
ParameterHeight,
|
||||
ParameterHRFEnabled,
|
||||
ParameterHRFMethod,
|
||||
ParameterNegativePrompt,
|
||||
ParameterNegativeStylePromptSDXL,
|
||||
ParameterPositivePrompt,
|
||||
ParameterPositiveStylePromptSDXL,
|
||||
ParameterScheduler,
|
||||
ParameterSDXLRefinerNegativeAestheticScore,
|
||||
ParameterSDXLRefinerPositiveAestheticScore,
|
||||
ParameterSDXLRefinerStart,
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
isParameterCFGRescaleMultiplier,
|
||||
isParameterCFGScale,
|
||||
isParameterHeight,
|
||||
isParameterHRFEnabled,
|
||||
isParameterHRFMethod,
|
||||
isParameterLoRAWeight,
|
||||
isParameterNegativePrompt,
|
||||
isParameterNegativeStylePromptSDXL,
|
||||
isParameterPositivePrompt,
|
||||
isParameterPositiveStylePromptSDXL,
|
||||
isParameterScheduler,
|
||||
isParameterSDXLRefinerNegativeAestheticScore,
|
||||
isParameterSDXLRefinerPositiveAestheticScore,
|
||||
isParameterSDXLRefinerStart,
|
||||
isParameterSeed,
|
||||
isParameterSteps,
|
||||
isParameterStrength,
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
fetchModelConfigWithTypeGuard,
|
||||
getModelKey,
|
||||
getModelKeyAndBase,
|
||||
} from 'features/parameters/util/modelFetchingHelpers';
|
||||
import { get, isArray, isString } from 'lodash-es';
|
||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const MetadataParsePendingToken = Symbol('pending');
|
||||
export const MetadataParseFailedToken = Symbol('failed');
|
||||
|
||||
/**
|
||||
* An async function that a property from an object and validates its type using a type guard. If the property is missing
|
||||
* or invalid, the function should throw a MetadataParseError.
|
||||
* @param obj The object to get the property from.
|
||||
* @param property The property to get.
|
||||
* @param typeGuard A type guard function to check the type of the property. Provide `undefined` to opt out of type
|
||||
* validation and always return the property value.
|
||||
* @returns A promise that resolves to the property value if it exists and is of the expected type.
|
||||
* @throws MetadataParseError if a type guard is provided and the property is not of the expected type.
|
||||
*/
|
||||
const getProperty = <T = unknown>(
|
||||
obj: unknown,
|
||||
property: string,
|
||||
typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true
|
||||
): Promise<T> => {
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
const val = get(obj, property) as unknown;
|
||||
if (typeGuard(val)) {
|
||||
resolve(val);
|
||||
}
|
||||
reject(new MetadataParseError(`Property ${property} is not of expected type`));
|
||||
});
|
||||
};
|
||||
|
||||
const parseCreatedBy: MetadataParseFunc<string> = (metadata) => getProperty(metadata, 'created_by', isString);
|
||||
|
||||
const parseGenerationMode: MetadataParseFunc<string> = (metadata) => getProperty(metadata, 'generation_mode', isString);
|
||||
|
||||
const parsePositivePrompt: MetadataParseFunc<ParameterPositivePrompt> = (metadata) =>
|
||||
getProperty(metadata, 'positive_prompt', isParameterPositivePrompt);
|
||||
|
||||
const parseNegativePrompt: MetadataParseFunc<ParameterNegativePrompt> = (metadata) =>
|
||||
getProperty(metadata, 'negative_prompt', isParameterNegativePrompt);
|
||||
|
||||
const parseSDXLPositiveStylePrompt: MetadataParseFunc<ParameterPositiveStylePromptSDXL> = (metadata) =>
|
||||
getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL);
|
||||
|
||||
const parseSDXLNegativeStylePrompt: MetadataParseFunc<ParameterNegativeStylePromptSDXL> = (metadata) =>
|
||||
getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL);
|
||||
|
||||
const parseSeed: MetadataParseFunc<ParameterSeed> = (metadata) => getProperty(metadata, 'seed', isParameterSeed);
|
||||
|
||||
const parseCFGScale: MetadataParseFunc<ParameterCFGScale> = (metadata) =>
|
||||
getProperty(metadata, 'cfg_scale', isParameterCFGScale);
|
||||
|
||||
const parseCFGRescaleMultiplier: MetadataParseFunc<ParameterCFGRescaleMultiplier> = (metadata) =>
|
||||
getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier);
|
||||
|
||||
const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
|
||||
getProperty(metadata, 'scheduler', isParameterScheduler);
|
||||
|
||||
const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);
|
||||
|
||||
const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>
|
||||
getProperty(metadata, 'height', isParameterHeight);
|
||||
|
||||
const parseSteps: MetadataParseFunc<ParameterSteps> = (metadata) => getProperty(metadata, 'steps', isParameterSteps);
|
||||
|
||||
const parseStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
|
||||
getProperty(metadata, 'strength', isParameterStrength);
|
||||
|
||||
const parseHRFEnabled: MetadataParseFunc<ParameterHRFEnabled> = (metadata) =>
|
||||
getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
|
||||
|
||||
const parseHRFStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
|
||||
getProperty(metadata, 'hrf_strength', isParameterStrength);
|
||||
|
||||
const parseHRFMethod: MetadataParseFunc<ParameterHRFMethod> = (metadata) =>
|
||||
getProperty(metadata, 'hrf_method', isParameterHRFMethod);
|
||||
|
||||
const parseRefinerSteps: MetadataParseFunc<ParameterSteps> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_steps', isParameterSteps);
|
||||
|
||||
const parseRefinerCFGScale: MetadataParseFunc<ParameterCFGScale> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale);
|
||||
|
||||
const parseRefinerScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_scheduler', isParameterScheduler);
|
||||
|
||||
const parseRefinerPositiveAestheticScore: MetadataParseFunc<ParameterSDXLRefinerPositiveAestheticScore> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_positive_aesthetic_score', isParameterSDXLRefinerPositiveAestheticScore);
|
||||
|
||||
const parseRefinerNegativeAestheticScore: MetadataParseFunc<ParameterSDXLRefinerNegativeAestheticScore> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_negative_aesthetic_score', isParameterSDXLRefinerNegativeAestheticScore);
|
||||
|
||||
const parseRefinerStart: MetadataParseFunc<ParameterSDXLRefinerStart> = (metadata) =>
|
||||
getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
|
||||
|
||||
const parseMainModel: MetadataParseFunc<NonRefinerMainModelConfig> = async (metadata) => {
|
||||
const model = await getProperty(metadata, 'model', undefined);
|
||||
const key = await getModelKey(model, 'main');
|
||||
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||
return mainModelConfig;
|
||||
};
|
||||
|
||||
const parseRefinerModel: MetadataParseFunc<RefinerMainModelConfig> = async (metadata) => {
|
||||
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
||||
const key = await getModelKey(refiner_model, 'main');
|
||||
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||
return refinerModelConfig;
|
||||
};
|
||||
|
||||
const parseVAEModel: MetadataParseFunc<VAEModelConfig> = async (metadata) => {
|
||||
const vae = await getProperty(metadata, 'vae', undefined);
|
||||
const key = await getModelKey(vae, 'vae');
|
||||
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||
return vaeModelConfig;
|
||||
};
|
||||
|
||||
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||
// Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
|
||||
const modelV1 = await getProperty(metadataItem, 'lora', undefined);
|
||||
// Now, the LoRA model is stored in a `model` property of the LoRA metadata: `{model: {key: ...}, weight: 0.75}`
|
||||
const modelV2 = await getProperty(metadataItem, 'model', undefined);
|
||||
const weight = await getProperty(metadataItem, 'weight', undefined);
|
||||
const key = await getModelKey(modelV2 ?? modelV1, 'lora');
|
||||
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
return {
|
||||
model: getModelKeyAndBase(loraModelConfig),
|
||||
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
||||
isEnabled: true,
|
||||
};
|
||||
};
|
||||
|
||||
const parseAllLoRAs: MetadataParseFunc<LoRA[]> = async (metadata) => {
|
||||
const lorasRaw = await getProperty(metadata, 'loras', isArray);
|
||||
const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
|
||||
const loras = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<LoRA> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return loras;
|
||||
};
|
||||
|
||||
const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem) => {
|
||||
const control_model = await getProperty(metadataItem, 'control_model');
|
||||
const key = await getModelKey(control_model, 'controlnet');
|
||||
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||
|
||||
const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
|
||||
const control_weight = zControlField.shape.control_weight
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'control_weight'));
|
||||
const begin_step_percent = zControlField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'begin_step_percent'));
|
||||
const end_step_percent = zControlField.shape.end_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
||||
const control_mode = zControlField.shape.control_mode
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'control_mode'));
|
||||
const resize_mode = zControlField.shape.resize_mode
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
||||
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const controlNet: ControlNetConfig = {
|
||||
type: 'controlnet',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
|
||||
controlMode: control_mode ?? initialControlNet.controlMode,
|
||||
resizeMode: resize_mode ?? initialControlNet.resizeMode,
|
||||
controlImage: image?.image_name ?? null,
|
||||
processedControlImage: image?.image_name ?? null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return controlNet;
|
||||
};
|
||||
|
||||
const parseAllControlNets: MetadataParseFunc<ControlNetConfig[]> = async (metadata) => {
|
||||
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
|
||||
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
|
||||
const controlNets = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<ControlNetConfig> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return controlNets;
|
||||
};
|
||||
|
||||
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem) => {
|
||||
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
|
||||
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
|
||||
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||
|
||||
const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
|
||||
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
|
||||
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'begin_step_percent'));
|
||||
const end_step_percent = zT2IAdapterField.shape.end_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
||||
const resize_mode = zT2IAdapterField.shape.resize_mode
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
||||
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
||||
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
|
||||
controlImage: image?.image_name ?? null,
|
||||
processedControlImage: image?.image_name ?? null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return t2iAdapter;
|
||||
};
|
||||
|
||||
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfig[]> = async (metadata) => {
|
||||
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
|
||||
const t2iAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfig> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return t2iAdapters;
|
||||
};
|
||||
|
||||
const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem) => {
|
||||
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
|
||||
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
|
||||
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||
|
||||
const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
|
||||
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
|
||||
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'begin_step_percent'));
|
||||
const end_step_percent = zIPAdapterField.shape.end_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
||||
|
||||
const ipAdapter: IPAdapterConfig = {
|
||||
id: uuidv4(),
|
||||
type: 'ip_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
||||
controlImage: image?.image_name ?? null,
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||
};
|
||||
|
||||
return ipAdapter;
|
||||
};
|
||||
|
||||
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfig[]> = async (metadata) => {
|
||||
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
|
||||
const ipAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<IPAdapterConfig> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return ipAdapters;
|
||||
};
|
||||
|
||||
export const parsers = {
|
||||
createdBy: parseCreatedBy,
|
||||
generationMode: parseGenerationMode,
|
||||
positivePrompt: parsePositivePrompt,
|
||||
negativePrompt: parseNegativePrompt,
|
||||
sdxlPositiveStylePrompt: parseSDXLPositiveStylePrompt,
|
||||
sdxlNegativeStylePrompt: parseSDXLNegativeStylePrompt,
|
||||
seed: parseSeed,
|
||||
cfgScale: parseCFGScale,
|
||||
cfgRescaleMultiplier: parseCFGRescaleMultiplier,
|
||||
scheduler: parseScheduler,
|
||||
width: parseWidth,
|
||||
height: parseHeight,
|
||||
steps: parseSteps,
|
||||
strength: parseStrength,
|
||||
hrfEnabled: parseHRFEnabled,
|
||||
hrfStrength: parseHRFStrength,
|
||||
hrfMethod: parseHRFMethod,
|
||||
refinerSteps: parseRefinerSteps,
|
||||
refinerCFGScale: parseRefinerCFGScale,
|
||||
refinerScheduler: parseRefinerScheduler,
|
||||
refinerPositiveAestheticScore: parseRefinerPositiveAestheticScore,
|
||||
refinerNegativeAestheticScore: parseRefinerNegativeAestheticScore,
|
||||
refinerStart: parseRefinerStart,
|
||||
mainModel: parseMainModel,
|
||||
refinerModel: parseRefinerModel,
|
||||
vaeModel: parseVAEModel,
|
||||
lora: parseLoRA,
|
||||
loras: parseAllLoRAs,
|
||||
controlNet: parseControlNet,
|
||||
controlNets: parseAllControlNets,
|
||||
t2iAdapter: parseT2IAdapter,
|
||||
t2iAdapters: parseAllT2IAdapters,
|
||||
ipAdapter: parseIPAdapter,
|
||||
ipAdapters: parseAllIPAdapters,
|
||||
} as const;
|
295
invokeai/frontend/web/src/features/metadata/util/recallers.ts
Normal file
295
invokeai/frontend/web/src/features/metadata/util/recallers.ts
Normal file
@ -0,0 +1,295 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraRecalled } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataRecallFunc } from 'features/metadata/types';
|
||||
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
heightRecalled,
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setImg2imgStrength,
|
||||
setNegativePrompt,
|
||||
setPositivePrompt,
|
||||
setScheduler,
|
||||
setSeed,
|
||||
setSteps,
|
||||
vaeSelected,
|
||||
widthRecalled,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import type {
|
||||
ParameterCFGRescaleMultiplier,
|
||||
ParameterCFGScale,
|
||||
ParameterHeight,
|
||||
ParameterHRFEnabled,
|
||||
ParameterHRFMethod,
|
||||
ParameterNegativePrompt,
|
||||
ParameterNegativeStylePromptSDXL,
|
||||
ParameterPositivePrompt,
|
||||
ParameterPositiveStylePromptSDXL,
|
||||
ParameterScheduler,
|
||||
ParameterSDXLRefinerNegativeAestheticScore,
|
||||
ParameterSDXLRefinerPositiveAestheticScore,
|
||||
ParameterSDXLRefinerStart,
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
setPositiveStylePromptSDXL,
|
||||
setRefinerCFGScale,
|
||||
setRefinerNegativeAestheticScore,
|
||||
setRefinerPositiveAestheticScore,
|
||||
setRefinerScheduler,
|
||||
setRefinerStart,
|
||||
setRefinerSteps,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
||||
getStore().dispatch(setPositivePrompt(positivePrompt));
|
||||
};
|
||||
|
||||
const recallNegativePrompt: MetadataRecallFunc<ParameterNegativePrompt> = (negativePrompt) => {
|
||||
getStore().dispatch(setNegativePrompt(negativePrompt));
|
||||
};
|
||||
|
||||
const recallSDXLPositiveStylePrompt: MetadataRecallFunc<ParameterPositiveStylePromptSDXL> = (positiveStylePrompt) => {
|
||||
getStore().dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
|
||||
};
|
||||
|
||||
const recallSDXLNegativeStylePrompt: MetadataRecallFunc<ParameterNegativeStylePromptSDXL> = (negativeStylePrompt) => {
|
||||
getStore().dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
|
||||
};
|
||||
|
||||
const recallSeed: MetadataRecallFunc<ParameterSeed> = (seed) => {
|
||||
getStore().dispatch(setSeed(seed));
|
||||
};
|
||||
|
||||
const recallCFGScale: MetadataRecallFunc<ParameterCFGScale> = (cfgScale) => {
|
||||
getStore().dispatch(setCfgScale(cfgScale));
|
||||
};
|
||||
|
||||
const recallCFGRescaleMultiplier: MetadataRecallFunc<ParameterCFGRescaleMultiplier> = (cfgRescaleMultiplier) => {
|
||||
getStore().dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
|
||||
};
|
||||
|
||||
const recallScheduler: MetadataRecallFunc<ParameterScheduler> = (scheduler) => {
|
||||
getStore().dispatch(setScheduler(scheduler));
|
||||
};
|
||||
|
||||
const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
|
||||
getStore().dispatch(widthRecalled(width));
|
||||
};
|
||||
|
||||
const recallHeight: MetadataRecallFunc<ParameterHeight> = (height) => {
|
||||
getStore().dispatch(heightRecalled(height));
|
||||
};
|
||||
|
||||
const recallSteps: MetadataRecallFunc<ParameterSteps> = (steps) => {
|
||||
getStore().dispatch(setSteps(steps));
|
||||
};
|
||||
|
||||
const recallStrength: MetadataRecallFunc<ParameterStrength> = (strength) => {
|
||||
getStore().dispatch(setImg2imgStrength(strength));
|
||||
};
|
||||
|
||||
const recallHRFEnabled: MetadataRecallFunc<ParameterHRFEnabled> = (hrfEnabled) => {
|
||||
getStore().dispatch(setHrfEnabled(hrfEnabled));
|
||||
};
|
||||
|
||||
const recallHRFStrength: MetadataRecallFunc<ParameterStrength> = (hrfStrength) => {
|
||||
getStore().dispatch(setHrfStrength(hrfStrength));
|
||||
};
|
||||
|
||||
const recallHRFMethod: MetadataRecallFunc<ParameterHRFMethod> = (hrfMethod) => {
|
||||
getStore().dispatch(setHrfMethod(hrfMethod));
|
||||
};
|
||||
|
||||
const recallRefinerSteps: MetadataRecallFunc<ParameterSteps> = (refinerSteps) => {
|
||||
getStore().dispatch(setRefinerSteps(refinerSteps));
|
||||
};
|
||||
|
||||
const recallRefinerCFGScale: MetadataRecallFunc<ParameterCFGScale> = (refinerCFGScale) => {
|
||||
getStore().dispatch(setRefinerCFGScale(refinerCFGScale));
|
||||
};
|
||||
|
||||
const recallRefinerScheduler: MetadataRecallFunc<ParameterScheduler> = (refinerScheduler) => {
|
||||
getStore().dispatch(setRefinerScheduler(refinerScheduler));
|
||||
};
|
||||
|
||||
const recallRefinerPositiveAestheticScore: MetadataRecallFunc<ParameterSDXLRefinerPositiveAestheticScore> = (
|
||||
refinerPositiveAestheticScore
|
||||
) => {
|
||||
getStore().dispatch(setRefinerPositiveAestheticScore(refinerPositiveAestheticScore));
|
||||
};
|
||||
|
||||
const recallRefinerNegativeAestheticScore: MetadataRecallFunc<ParameterSDXLRefinerNegativeAestheticScore> = (
|
||||
refinerNegativeAestheticScore
|
||||
) => {
|
||||
getStore().dispatch(setRefinerNegativeAestheticScore(refinerNegativeAestheticScore));
|
||||
};
|
||||
|
||||
const recallRefinerStart: MetadataRecallFunc<ParameterSDXLRefinerStart> = (refinerStart) => {
|
||||
getStore().dispatch(setRefinerStart(refinerStart));
|
||||
};
|
||||
|
||||
const recallModel: MetadataRecallFunc<NonRefinerMainModelConfig> = (model) => {
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(model);
|
||||
getStore().dispatch(modelSelected(modelIdentifier));
|
||||
};
|
||||
|
||||
const recallRefinerModel: MetadataRecallFunc<RefinerMainModelConfig> = (refinerModel) => {
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel);
|
||||
getStore().dispatch(refinerModelChanged(modelIdentifier));
|
||||
};
|
||||
|
||||
const recallVAE: MetadataRecallFunc<VAEModelConfig | null | undefined> = (vaeModel) => {
|
||||
if (!vaeModel) {
|
||||
getStore().dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel);
|
||||
getStore().dispatch(vaeSelected(modelIdentifier));
|
||||
};
|
||||
|
||||
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
||||
getStore().dispatch(loraRecalled(lora));
|
||||
};
|
||||
|
||||
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
|
||||
const { dispatch } = getStore();
|
||||
loras.forEach((lora) => {
|
||||
dispatch(loraRecalled(lora));
|
||||
});
|
||||
};
|
||||
|
||||
const recallControlNet: MetadataRecallFunc<ControlNetConfig> = (controlNet) => {
|
||||
getStore().dispatch(controlAdapterRecalled(controlNet));
|
||||
};
|
||||
|
||||
const recallControlNets: MetadataRecallFunc<ControlNetConfig[]> = (controlNets) => {
|
||||
const { dispatch } = getStore();
|
||||
controlNets.forEach((controlNet) => {
|
||||
dispatch(controlAdapterRecalled(controlNet));
|
||||
});
|
||||
};
|
||||
|
||||
const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfig> = (t2iAdapter) => {
|
||||
getStore().dispatch(controlAdapterRecalled(t2iAdapter));
|
||||
};
|
||||
|
||||
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
t2iAdapters.forEach((t2iAdapter) => {
|
||||
dispatch(controlAdapterRecalled(t2iAdapter));
|
||||
});
|
||||
};
|
||||
|
||||
const recallIPAdapter: MetadataRecallFunc<IPAdapterConfig> = (ipAdapter) => {
|
||||
getStore().dispatch(controlAdapterRecalled(ipAdapter));
|
||||
};
|
||||
|
||||
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfig[]> = (ipAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
dispatch(controlAdapterRecalled(ipAdapter));
|
||||
});
|
||||
};
|
||||
|
||||
export const recallPrompts = (_metadata: unknown) => {
|
||||
// recallPositivePrompt(metadata);
|
||||
// recallNegativePrompt(metadata);
|
||||
// recallSDXLPositiveStylePrompt(metadata);
|
||||
// recallSDXLNegativeStylePrompt(metadata);
|
||||
// parameterNotSetToast(t('metadata.allPrompts'));
|
||||
// parameterSetToast(t('metadata.allPrompts'));
|
||||
};
|
||||
|
||||
export const recallWidthAndHeight = (_metadata: unknown) => {
|
||||
// recallWidth(metadata);
|
||||
// recallHeight(metadata);
|
||||
};
|
||||
|
||||
export const recallAll = async (_metadata: unknown) => {
|
||||
// if (!metadata) {
|
||||
// allParameterNotSetToast();
|
||||
// return;
|
||||
// }
|
||||
// // Update the main model first, as other parameters may depend on it.
|
||||
// await recallModel(metadata);
|
||||
// await Promise.allSettled([
|
||||
// // Core parameters
|
||||
// recallCFGScale(metadata),
|
||||
// recallCFGRescaleMultiplier(metadata),
|
||||
// recallPositivePrompt(metadata),
|
||||
// recallNegativePrompt(metadata),
|
||||
// recallScheduler(metadata),
|
||||
// recallSeed(metadata),
|
||||
// recallSteps(metadata),
|
||||
// recallWidth(metadata),
|
||||
// recallHeight(metadata),
|
||||
// recallStrength(metadata),
|
||||
// recallHRFEnabled(metadata),
|
||||
// recallHRFMethod(metadata),
|
||||
// recallHRFStrength(metadata),
|
||||
// // SDXL parameters
|
||||
// recallSDXLPositiveStylePrompt(metadata),
|
||||
// recallSDXLNegativeStylePrompt(metadata),
|
||||
// recallRefinerSteps(metadata),
|
||||
// recallRefinerCFGScale(metadata),
|
||||
// recallRefinerScheduler(metadata),
|
||||
// recallRefinerPositiveAestheticScore(metadata),
|
||||
// recallRefinerNegativeAestheticScore(metadata),
|
||||
// recallRefinerStart(metadata),
|
||||
// // Models
|
||||
// recallVAE(metadata),
|
||||
// recallRefinerModel(metadata),
|
||||
// recallAllLoRAs(metadata),
|
||||
// recallControlNets(metadata),
|
||||
// recallT2IAdapters(metadata),
|
||||
// recallIPAdapters(metadata),
|
||||
// ]);
|
||||
// allParameterSetToast();
|
||||
};
|
||||
|
||||
export const recallers = {
|
||||
positivePrompt: recallPositivePrompt,
|
||||
negativePrompt: recallNegativePrompt,
|
||||
sdxlPositiveStylePrompt: recallSDXLPositiveStylePrompt,
|
||||
sdxlNegativeStylePrompt: recallSDXLNegativeStylePrompt,
|
||||
seed: recallSeed,
|
||||
cfgScale: recallCFGScale,
|
||||
cfgRescaleMultiplier: recallCFGRescaleMultiplier,
|
||||
scheduler: recallScheduler,
|
||||
width: recallWidth,
|
||||
height: recallHeight,
|
||||
steps: recallSteps,
|
||||
strength: recallStrength,
|
||||
hrfEnabled: recallHRFEnabled,
|
||||
hrfStrength: recallHRFStrength,
|
||||
hrfMethod: recallHRFMethod,
|
||||
refinerSteps: recallRefinerSteps,
|
||||
refinerCFGScale: recallRefinerCFGScale,
|
||||
refinerScheduler: recallRefinerScheduler,
|
||||
refinerPositiveAestheticScore: recallRefinerPositiveAestheticScore,
|
||||
refinerNegativeAestheticScore: recallRefinerNegativeAestheticScore,
|
||||
refinerStart: recallRefinerStart,
|
||||
model: recallModel,
|
||||
refinerModel: recallRefinerModel,
|
||||
vae: recallVAE,
|
||||
lora: recallLoRA,
|
||||
loras: recallAllLoRAs,
|
||||
controlNets: recallControlNets,
|
||||
controlNet: recallControlNet,
|
||||
t2iAdapters: recallT2IAdapters,
|
||||
t2iAdapter: recallT2IAdapter,
|
||||
ipAdapters: recallIPAdapters,
|
||||
ipAdapter: recallIPAdapter,
|
||||
} as const;
|
117
invokeai/frontend/web/src/features/metadata/util/validators.ts
Normal file
117
invokeai/frontend/web/src/features/metadata/util/validators.ts
Normal file
@ -0,0 +1,117 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataValidateFunc } from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
|
||||
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Checks the given base model type against the currently-selected model's base type and throws an error if they are
|
||||
* incompatible.
|
||||
* @param base The base model type to validate.
|
||||
* @param message An optional message to use in the error if the base model is incompatible.
|
||||
*/
|
||||
const validateBaseCompatibility = (base?: BaseModelType, message?: string) => {
|
||||
if (!base) {
|
||||
throw new InvalidModelConfigError(message || 'Missing base');
|
||||
}
|
||||
const currentBase = getStore().getState().generation.model?.base;
|
||||
if (currentBase && base !== currentBase) {
|
||||
throw new InvalidModelConfigError(message || `Incompatible base models: ${base} and ${currentBase}`);
|
||||
}
|
||||
};
|
||||
|
||||
const validateRefinerModel: MetadataValidateFunc<RefinerMainModelConfig> = (refinerModel) => {
|
||||
validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(refinerModel));
|
||||
};
|
||||
|
||||
const validateVAEModel: MetadataValidateFunc<VAEModelConfig> = (vaeModel) => {
|
||||
validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(vaeModel));
|
||||
};
|
||||
|
||||
const validateLoRA: MetadataValidateFunc<LoRA> = (lora) => {
|
||||
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(lora));
|
||||
};
|
||||
|
||||
const validateLoRAs: MetadataValidateFunc<LoRA[]> = (loras) => {
|
||||
const validatedLoRAs: LoRA[] = [];
|
||||
loras.forEach((lora) => {
|
||||
try {
|
||||
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
|
||||
validatedLoRAs.push(lora);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the LoRAs, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedLoRAs));
|
||||
};
|
||||
|
||||
const validateControlNet: MetadataValidateFunc<ControlNetConfig> = (controlNet) => {
|
||||
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(controlNet));
|
||||
};
|
||||
|
||||
const validateControlNets: MetadataValidateFunc<ControlNetConfig[]> = (controlNets) => {
|
||||
const validatedControlNets: ControlNetConfig[] = [];
|
||||
controlNets.forEach((controlNet) => {
|
||||
try {
|
||||
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
|
||||
validatedControlNets.push(controlNet);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedControlNets));
|
||||
};
|
||||
|
||||
const validateT2IAdapter: MetadataValidateFunc<T2IAdapterConfig> = (t2iAdapter) => {
|
||||
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(t2iAdapter));
|
||||
};
|
||||
|
||||
const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
|
||||
const validatedT2IAdapters: T2IAdapterConfig[] = [];
|
||||
t2iAdapters.forEach((t2iAdapter) => {
|
||||
try {
|
||||
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
|
||||
validatedT2IAdapters.push(t2iAdapter);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedT2IAdapters));
|
||||
};
|
||||
|
||||
const validateIPAdapter: MetadataValidateFunc<IPAdapterConfig> = (ipAdapter) => {
|
||||
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(ipAdapter));
|
||||
};
|
||||
|
||||
const validateIPAdapters: MetadataValidateFunc<IPAdapterConfig[]> = (ipAdapters) => {
|
||||
const validatedIPAdapters: IPAdapterConfig[] = [];
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
try {
|
||||
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
|
||||
validatedIPAdapters.push(ipAdapter);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedIPAdapters));
|
||||
};
|
||||
|
||||
export const validators = {
|
||||
refinerModel: validateRefinerModel,
|
||||
vaeModel: validateVAEModel,
|
||||
lora: validateLoRA,
|
||||
loras: validateLoRAs,
|
||||
controlNet: validateControlNet,
|
||||
controlNets: validateControlNets,
|
||||
t2iAdapter: validateT2IAdapter,
|
||||
t2iAdapters: validateT2IAdapters,
|
||||
ipAdapter: validateIPAdapter,
|
||||
ipAdapters: validateIPAdapters,
|
||||
} as const;
|
@ -1,5 +1,8 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { ModelIdentifier as ModelIdentifierV2 } from './v2/common';
|
||||
import { zModelIdentifier as zModelIdentifierV2 } from './v2/common';
|
||||
|
||||
// #region Field data schemas
|
||||
export const zImageField = z.object({
|
||||
image_name: z.string().trim().min(1),
|
||||
@ -69,6 +72,8 @@ export const zModelIdentifier = z.object({
|
||||
});
|
||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||
zModelIdentifier.safeParse(field).success;
|
||||
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
|
||||
zModelIdentifierV2.safeParse(field).success;
|
||||
export const zModelFieldBase = zModelIdentifier;
|
||||
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
|
@ -1,29 +1,22 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zMainModelField,
|
||||
zModelFieldBase,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterField,
|
||||
zVAEModelField,
|
||||
} from './common';
|
||||
import { zControlField, zIPAdapterField, zT2IAdapterField } from './common';
|
||||
|
||||
export const zLoRAWeight = z.number().nullish();
|
||||
// #region Metadata-optimized versions of schemas
|
||||
// TODO: It's possible that `deepPartial` will be deprecated:
|
||||
// - https://github.com/colinhacks/zod/issues/2106
|
||||
// - https://github.com/colinhacks/zod/issues/2854
|
||||
export const zLoRAMetadataItem = z.object({
|
||||
lora: zModelFieldBase.deepPartial(),
|
||||
weight: z.number(),
|
||||
lora: z.unknown(),
|
||||
weight: zLoRAWeight,
|
||||
});
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
||||
const zModelMetadataItem = zMainModelField.deepPartial();
|
||||
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
||||
const zControlNetMetadataItem = zControlField.merge(z.object({ control_model: z.unknown() })).deepPartial();
|
||||
const zIPAdapterMetadataItem = zIPAdapterField.merge(z.object({ ip_adapter_model: z.unknown() })).deepPartial();
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.merge(z.object({ t2i_adapter_model: z.unknown() })).deepPartial();
|
||||
const zSDXLRefinerModelMetadataItem = z.unknown();
|
||||
const zModelMetadataItem = z.unknown();
|
||||
const zVAEModelMetadataItem = z.unknown();
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,14 +1,7 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import {
|
||||
zBaseModel,
|
||||
zControlNetModelField,
|
||||
zIPAdapterModelField,
|
||||
zLoRAModelField,
|
||||
zModelIdentifierWithBase,
|
||||
zSchedulerField,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterModelField,
|
||||
zVAEModelField,
|
||||
} from 'features/nodes/types/common';
|
||||
import { z } from 'zod';
|
||||
|
||||
@ -111,42 +104,42 @@ export const isParameterModel = (val: unknown): val is ParameterModel => zParame
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Model
|
||||
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel });
|
||||
export const zParameterSDXLRefinerModel = zModelIdentifierWithBase;
|
||||
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
|
||||
export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel =>
|
||||
zParameterSDXLRefinerModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region VAE Model
|
||||
export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel });
|
||||
export const zParameterVAEModel = zModelIdentifierWithBase;
|
||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||
export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel =>
|
||||
zParameterVAEModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA Model
|
||||
export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel });
|
||||
export const zParameterLoRAModel = zModelIdentifierWithBase;
|
||||
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
||||
export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel =>
|
||||
zParameterLoRAModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ControlNet Model
|
||||
export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel });
|
||||
export const zParameterControlNetModel = zModelIdentifierWithBase;
|
||||
export type ParameterControlNetModel = z.infer<typeof zParameterLoRAModel>;
|
||||
export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel =>
|
||||
zParameterControlNetModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region IP Adapter Model
|
||||
export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel });
|
||||
export const zParameterIPAdapterModel = zModelIdentifierWithBase;
|
||||
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
|
||||
export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel =>
|
||||
zParameterIPAdapterModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region T2I Adapter Model
|
||||
export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel });
|
||||
export const zParameterT2IAdapterModel = zModelIdentifierWithBase;
|
||||
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
|
||||
export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel =>
|
||||
zParameterT2IAdapterModel.safeParse(val).success;
|
||||
@ -218,3 +211,9 @@ export type ParameterCanvasCoherenceMode = z.infer<typeof zParameterCanvasCohere
|
||||
export const isParameterCanvasCoherenceMode = (val: unknown): val is ParameterCanvasCoherenceMode =>
|
||||
zParameterCanvasCoherenceMode.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA weight
|
||||
export const zLoRAWeight = z.number();
|
||||
export type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
||||
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
|
||||
// #endregion
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
|
||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
@ -41,6 +42,12 @@ export class InvalidModelConfigError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches the model config for a given model key.
|
||||
* @param key The model key.
|
||||
* @returns A promise that resolves to the model config.
|
||||
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
|
||||
*/
|
||||
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
@ -52,6 +59,37 @@ export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> =>
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
|
||||
* for MM1 model identifiers.
|
||||
* @param name The model name.
|
||||
* @param base The base model.
|
||||
* @param type The model type.
|
||||
* @returns A promise that resolves to the model config.
|
||||
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
|
||||
*/
|
||||
export const fetchModelConfigByAttrs = async (
|
||||
name: string,
|
||||
base: BaseModelType,
|
||||
type: ModelType
|
||||
): Promise<AnyModelConfig> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
|
||||
req.unsubscribe();
|
||||
return await req.unwrap();
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
|
||||
* @param key The model key.
|
||||
* @param typeGuard A type guard function that checks if the model config is of the expected type.
|
||||
* @returns A promise that resolves to the model config. The model config is guaranteed to be of the expected type.
|
||||
* @throws {InvalidModelConfigError} If the model config is unable to be fetched or is of an unexpected type.
|
||||
*/
|
||||
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
||||
key: string,
|
||||
typeGuard: (config: AnyModelConfig) => config is T
|
||||
@ -63,15 +101,17 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
||||
return modelConfig;
|
||||
};
|
||||
|
||||
export const fetchMainModel = async (key: string) => {
|
||||
// TODO(psyche): Remove these helpers once `useRecallParameters` is removed
|
||||
|
||||
export const fetchMainModelConfig = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||
};
|
||||
|
||||
export const fetchRefinerModel = async (key: string) => {
|
||||
export const fetchRefinerModelConfig = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||
};
|
||||
|
||||
export const fetchVAEModel = async (key: string) => {
|
||||
export const fetchVAEModelConfig = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||
};
|
||||
|
||||
@ -95,19 +135,39 @@ export const fetchTextualInversionModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
|
||||
};
|
||||
|
||||
export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => {
|
||||
return sourceBase === targetBase;
|
||||
};
|
||||
|
||||
/**
|
||||
* Raises an error if the source base model is incompatible with the target base model.
|
||||
* @param sourceBase The source base model.
|
||||
* @param targetBase The target base model.
|
||||
* @param message An optional custom message to include in the error.
|
||||
* @throws {InvalidModelConfigError} If the source base model is incompatible with the target base model.
|
||||
*/
|
||||
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
|
||||
if (targetBase && !isBaseCompatible(sourceBase, targetBase)) {
|
||||
if (targetBase && sourceBase !== targetBase) {
|
||||
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
|
||||
}
|
||||
};
|
||||
|
||||
export const getModelKey = (modelIdentifier: unknown, message?: string): string => {
|
||||
if (!isModelIdentifier(modelIdentifier)) {
|
||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||
/**
|
||||
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
|
||||
* @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
|
||||
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
|
||||
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
|
||||
* @param message An optional custom message to include in the error if the model identifier is invalid.
|
||||
* @returns A promise that resolves to the model key.
|
||||
* @throws {InvalidModelConfigError} If the model identifier is invalid.
|
||||
*/
|
||||
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
|
||||
if (isModelIdentifier(modelIdentifier)) {
|
||||
return modelIdentifier.key;
|
||||
}
|
||||
return modelIdentifier.key;
|
||||
if (isModelIdentifierV2(modelIdentifier)) {
|
||||
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
|
||||
}
|
||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||
};
|
||||
|
||||
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
|
||||
key: modelConfig.key,
|
||||
base: modelConfig.base,
|
||||
});
|
||||
|
@ -1,150 +0,0 @@
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ControlNetMetadataItem,
|
||||
IPAdapterMetadataItem,
|
||||
LoRAMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/metadata';
|
||||
import {
|
||||
fetchControlNetModel,
|
||||
fetchIPAdapterModel,
|
||||
fetchLoRAModel,
|
||||
fetchMainModel,
|
||||
fetchRefinerModel,
|
||||
fetchT2IAdapterModel,
|
||||
fetchVAEModel,
|
||||
getModelKey,
|
||||
raiseIfBaseIncompatible,
|
||||
} from 'features/parameters/util/modelFetchingHelpers';
|
||||
import type { BaseModelType } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const prepareMainModelMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
|
||||
const key = getModelKey(model);
|
||||
const mainModel = await fetchMainModel(key);
|
||||
return zModelIdentifierWithBase.parse(mainModel);
|
||||
};
|
||||
|
||||
export const prepareRefinerMetadataItem = async (model: unknown): Promise<ModelIdentifierWithBase> => {
|
||||
const key = getModelKey(model);
|
||||
const refinerModel = await fetchRefinerModel(key);
|
||||
return zModelIdentifierWithBase.parse(refinerModel);
|
||||
};
|
||||
|
||||
export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise<ModelIdentifierWithBase> => {
|
||||
const key = getModelKey(vae);
|
||||
const vaeModel = await fetchVAEModel(key);
|
||||
raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model');
|
||||
return zModelIdentifierWithBase.parse(vaeModel);
|
||||
};
|
||||
|
||||
export const prepareLoRAMetadataItem = async (
|
||||
loraMetadataItem: LoRAMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<LoRA> => {
|
||||
const key = getModelKey(loraMetadataItem.lora);
|
||||
const loraModel = await fetchLoRAModel(key);
|
||||
raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model');
|
||||
return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true };
|
||||
};
|
||||
|
||||
export const prepareControlNetMetadataItem = async (
|
||||
controlnetMetadataItem: ControlNetMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<ControlNetConfig> => {
|
||||
const key = getModelKey(controlnetMetadataItem.control_model);
|
||||
const controlNetModel = await fetchControlNetModel(key);
|
||||
raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model');
|
||||
|
||||
const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } =
|
||||
controlnetMetadataItem;
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const controlnet: ControlNetConfig = {
|
||||
type: 'controlnet',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||
controlMode: control_mode || initialControlNet.controlMode,
|
||||
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return controlnet;
|
||||
};
|
||||
|
||||
export const prepareT2IAdapterMetadataItem = async (
|
||||
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<T2IAdapterConfig> => {
|
||||
const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model);
|
||||
const t2iAdapterModel = await fetchT2IAdapterModel(key);
|
||||
raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
|
||||
|
||||
const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem;
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
|
||||
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return t2iAdapter;
|
||||
};
|
||||
|
||||
export const prepareIPAdapterMetadataItem = async (
|
||||
ipAdapterMetadataItem: IPAdapterMetadataItem,
|
||||
base?: BaseModelType
|
||||
): Promise<IPAdapterConfig> => {
|
||||
const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model);
|
||||
const ipAdapterModel = await fetchIPAdapterModel(key);
|
||||
raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model');
|
||||
|
||||
const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
|
||||
|
||||
const ipAdapter: IPAdapterConfig = {
|
||||
id: uuidv4(),
|
||||
type: 'ip_adapter',
|
||||
isEnabled: true,
|
||||
controlImage: image?.image_name ?? null,
|
||||
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||
};
|
||||
|
||||
return ipAdapter;
|
||||
};
|
@ -25,7 +25,7 @@ const formLabelProps2: FormLabelProps = {
|
||||
|
||||
export const AdvancedSettingsAccordion = memo(() => {
|
||||
const vaeKey = useAppSelector((state) => state.generation.vae?.key);
|
||||
const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||
const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
|
@ -1,10 +1,8 @@
|
||||
import type { EntityState, Update } from '@reduxjs/toolkit';
|
||||
import type { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { JSONObject } from 'common/types';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types';
|
||||
import type { CoreMetadata } from 'features/nodes/types/metadata';
|
||||
import { zCoreMetadata } from 'features/nodes/types/metadata';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { keyBy } from 'lodash-es';
|
||||
@ -118,22 +116,9 @@ export const imagesApi = api.injectEndpoints({
|
||||
providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageMetadata: build.query<CoreMetadata | undefined, string>({
|
||||
getImageMetadata: build.query<JSONObject | undefined, string>({
|
||||
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }),
|
||||
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
|
||||
transformResponse: (
|
||||
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
|
||||
) => {
|
||||
if (response) {
|
||||
const result = zCoreMetadata.safeParse(response);
|
||||
if (result.success) {
|
||||
return result.data;
|
||||
} else {
|
||||
logger('images').warn('Problem parsing metadata');
|
||||
}
|
||||
}
|
||||
return;
|
||||
},
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageWorkflow: build.query<
|
||||
|
@ -70,6 +70,8 @@ export type ScanFolderResponse =
|
||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
||||
|
||||
export type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
||||
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@ -223,6 +225,18 @@ export const modelsApi = api.injectEndpoints({
|
||||
return tags;
|
||||
},
|
||||
}),
|
||||
getModelConfigByAttrs: build.query<AnyModelConfig, GetByAttrsArg>({
|
||||
query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = ['Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push({ type: 'ModelConfig', id: result.key });
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
}),
|
||||
syncModels: build.mutation<void, void>({
|
||||
query: () => {
|
||||
return {
|
||||
@ -300,6 +314,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
});
|
||||
|
||||
export const {
|
||||
useGetModelConfigByAttrsQuery,
|
||||
useGetModelConfigQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
|
Loading…
Reference in New Issue
Block a user