diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index b31fa5d93c..12d06ee224 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -656,6 +656,7 @@ } }, "metadata": { + "allPrompts": "All Prompts", "cfgScale": "CFG scale", "cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)", "createdBy": "Created By", @@ -664,6 +665,7 @@ "height": "Height", "hiresFix": "High Resolution Optimization", "imageDetails": "Image Details", + "imageDimensions": "Image Dimensions", "initImage": "Initial image", "metadata": "Metadata", "model": "Model", @@ -671,9 +673,11 @@ "noImageDetails": "No image details found", "noMetaData": "No metadata found", "noRecallParameters": "No parameters to recall found", + "parameterSet": "Parameter {{parameter}} set", "perlin": "Perlin Noise", "positivePrompt": "Positive Prompt", "recallParameters": "Recall Parameters", + "recallParameter": "Recall {{label}}", "scheduler": "Scheduler", "seamless": "Seamless", "seed": "Seed", @@ -1381,8 +1385,8 @@ "nodesNotValidJSON": "Not a valid JSON", "nodesSaved": "Nodes Saved", "nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types", - "parameterNotSet": "Parameter not set", - "parameterSet": "Parameter set", + "parameterNotSet": "{{parameter}} not set", + "parameterSet": "{{parameter}} set", "parametersFailed": "Problem loading parameters", "parametersFailedDesc": "Unable to load init image.", "parametersNotSet": "Parameters Not Set", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts index 6419f840ec..07a1039bef 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts @@ -1,6 +1,6 @@ -import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; +import { toast } from 'common/util/toast'; import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { t } from 'i18next'; import { truncate, upperFirst } from 'lodash-es'; @@ -8,11 +8,6 @@ import { queueApi } from 'services/api/endpoints/queue'; import { startAppListening } from '..'; -const { toast } = createStandaloneToast({ - theme: theme, - defaultOptions: TOAST_OPTIONS.defaultOptions, -}); - export const addBatchEnqueuedListener = () => { // success startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx index 69adfb5b67..fa5c962ff5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -1,7 +1,8 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; -import { createStandaloneToast, ExternalLink, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; +import { ExternalLink } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { startAppListening } from 'app/store/middleware/listenerMiddleware'; +import { toast } from 'common/util/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { @@ -12,11 +13,6 @@ import { const log = logger('images'); -const { toast } = createStandaloneToast({ - theme: theme, - defaultOptions: TOAST_OPTIONS.defaultOptions, -}); - export const addBulkDownloadListeners = () => { startAppListening({ matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled, diff --git a/invokeai/frontend/web/src/common/util/toast.ts b/invokeai/frontend/web/src/common/util/toast.ts new file mode 100644 index 0000000000..ac61a4a12d --- /dev/null +++ b/invokeai/frontend/web/src/common/util/toast.ts @@ -0,0 +1,6 @@ +import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; + +export const { toast } = createStandaloneToast({ + theme: theme, + defaultOptions: TOAST_OPTIONS.defaultOptions, +}); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 7eec7e1875..cfa679c805 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,300 +1,54 @@ -import { isModelIdentifier } from 'features/nodes/types/common'; -import type { - ControlNetMetadataItem, - CoreMetadata, - IPAdapterMetadataItem, - LoRAMetadataItem, - T2IAdapterMetadataItem, -} from 'features/nodes/types/metadata'; -import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { memo, useCallback, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; - -import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem'; +import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets'; +import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters'; +import { MetadataItem } from 'features/metadata/components/MetadataItem'; +import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs'; +import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters'; +import { handlers } from 'features/metadata/util/handlers'; +import { memo } from 'react'; type Props = { - metadata?: CoreMetadata; + metadata?: unknown; }; const ImageMetadataActions = (props: Props) => { const { metadata } = props; - const { t } = useTranslation(); - - const { - recallPositivePrompt, - recallNegativePrompt, - recallSeed, - recallCfgScale, - recallCfgRescaleMultiplier, - recallModel, - recallScheduler, - recallVaeModel, - recallSteps, - recallWidth, - recallHeight, - recallStrength, - recallHrfEnabled, - recallHrfStrength, - recallHrfMethod, - recallLoRA, - recallControlNet, - recallIPAdapter, - recallT2IAdapter, - recallSDXLPositiveStylePrompt, - recallSDXLNegativeStylePrompt, - } = useRecallParameters(); - - const handleRecallPositivePrompt = useCallback(() => { - recallPositivePrompt(metadata?.positive_prompt); - }, [metadata?.positive_prompt, recallPositivePrompt]); - - const handleRecallNegativePrompt = useCallback(() => { - recallNegativePrompt(metadata?.negative_prompt); - }, [metadata?.negative_prompt, recallNegativePrompt]); - - const handleRecallSDXLPositiveStylePrompt = useCallback(() => { - recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt); - }, [metadata?.positive_style_prompt, recallSDXLPositiveStylePrompt]); - - const handleRecallSDXLNegativeStylePrompt = useCallback(() => { - recallSDXLNegativeStylePrompt(metadata?.negative__style_prompt); - }, [metadata?.negative__style_prompt, recallSDXLNegativeStylePrompt]); - - const handleRecallSeed = useCallback(() => { - recallSeed(metadata?.seed); - }, [metadata?.seed, recallSeed]); - - const handleRecallModel = useCallback(() => { - recallModel(metadata?.model); - }, [metadata?.model, recallModel]); - - const handleRecallWidth = useCallback(() => { - recallWidth(metadata?.width); - }, [metadata?.width, recallWidth]); - - const handleRecallHeight = useCallback(() => { - recallHeight(metadata?.height); - }, [metadata?.height, recallHeight]); - - const handleRecallScheduler = useCallback(() => { - recallScheduler(metadata?.scheduler); - }, [metadata?.scheduler, recallScheduler]); - - const handleRecallVaeModel = useCallback(() => { - recallVaeModel(metadata?.vae); - }, [metadata?.vae, recallVaeModel]); - - const handleRecallSteps = useCallback(() => { - recallSteps(metadata?.steps); - }, [metadata?.steps, recallSteps]); - - const handleRecallCfgScale = useCallback(() => { - recallCfgScale(metadata?.cfg_scale); - }, [metadata?.cfg_scale, recallCfgScale]); - - const handleRecallCfgRescaleMultiplier = useCallback(() => { - recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier); - }, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]); - - const handleRecallStrength = useCallback(() => { - recallStrength(metadata?.strength); - }, [metadata?.strength, recallStrength]); - - const handleRecallHrfEnabled = useCallback(() => { - recallHrfEnabled(metadata?.hrf_enabled); - }, [metadata?.hrf_enabled, recallHrfEnabled]); - - const handleRecallHrfStrength = useCallback(() => { - recallHrfStrength(metadata?.hrf_strength); - }, [metadata?.hrf_strength, recallHrfStrength]); - - const handleRecallHrfMethod = useCallback(() => { - recallHrfMethod(metadata?.hrf_method); - }, [metadata?.hrf_method, recallHrfMethod]); - - const handleRecallLoRA = useCallback( - (lora: LoRAMetadataItem) => { - recallLoRA(lora); - }, - [recallLoRA] - ); - - const handleRecallControlNet = useCallback( - (controlnet: ControlNetMetadataItem) => { - recallControlNet(controlnet); - }, - [recallControlNet] - ); - - const handleRecallIPAdapter = useCallback( - (ipAdapter: IPAdapterMetadataItem) => { - recallIPAdapter(ipAdapter); - }, - [recallIPAdapter] - ); - - const handleRecallT2IAdapter = useCallback( - (ipAdapter: T2IAdapterMetadataItem) => { - recallT2IAdapter(ipAdapter); - }, - [recallT2IAdapter] - ); - - const validControlNets: ControlNetMetadataItem[] = useMemo(() => { - return metadata?.controlnets - ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model)) - : []; - }, [metadata?.controlnets]); - - const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { - return metadata?.ipAdapters - ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model)) - : []; - }, [metadata?.ipAdapters]); - - const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { - return metadata?.t2iAdapters - ? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model)) - : []; - }, [metadata?.t2iAdapters]); - if (!metadata || Object.keys(metadata).length === 0) { return null; } return ( <> - {metadata.created_by && } - {metadata.generation_mode && ( - - )} - {metadata.positive_prompt && ( - - )} - {metadata.negative_prompt && ( - - )} - {metadata.positive_style_prompt && ( - - )} - {metadata.negative_style_prompt && ( - - )} - {metadata.seed !== undefined && metadata.seed !== null && ( - - )} - {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( - - )} - {metadata.width && ( - - )} - {metadata.height && ( - - )} - {metadata.scheduler && ( - - )} - - {metadata.steps && ( - - )} - {metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && ( - - )} - {metadata.cfg_rescale_multiplier !== undefined && metadata.cfg_rescale_multiplier !== null && ( - - )} - {metadata.strength && ( - - )} - {metadata.hrf_enabled && ( - - )} - {metadata.hrf_enabled && metadata.hrf_strength && ( - - )} - {metadata.hrf_enabled && metadata.hrf_method && ( - - )} - {metadata.loras && - metadata.loras.map((lora, index) => { - if (isModelIdentifier(lora.lora)) { - return ( - - ); - } - })} - {validControlNets.map((controlnet, index) => ( - - ))} - {validIPAdapters.map((ipAdapter, index) => ( - - ))} - {validT2IAdapters.map((t2iAdapter, index) => ( - - ))} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx index ce25b54e59..146c1437d1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx @@ -1,28 +1,39 @@ import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; +import { get } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { IoArrowUndoCircleOutline } from 'react-icons/io5'; -import { PiCopyBold } from 'react-icons/pi'; +import { PiArrowBendUpLeftBold } from 'react-icons/pi'; import { useGetModelConfigQuery } from 'services/api/endpoints/models'; type MetadataItemProps = { isLink?: boolean; label: string; - onClick?: () => void; - value: number | string | boolean; + metadata: unknown; + propertyName: string; + onRecall?: (value: unknown) => void; labelPosition?: string; - withCopy?: boolean; }; /** * Component to display an individual metadata item or parameter. */ -const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => { +const ImageMetadataItem = ({ + label, + metadata, + propertyName, + onRecall: _onRecall, + isLink, + labelPosition, +}: MetadataItemProps) => { const { t } = useTranslation(); - const handleCopy = useCallback(() => { - navigator.clipboard.writeText(value?.toString()); - }, [value]); + const value = useMemo(() => get(metadata, propertyName), [metadata, propertyName]); + const onRecall = useCallback(() => { + if (!_onRecall) { + return; + } + _onRecall(value); + }, [_onRecall, value]); if (!value) { return null; @@ -30,27 +41,15 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC return ( - {onClick && ( - + {_onRecall && ( + } + aria-label={t('metadata.recallParameter', { parameter: label })} + icon={} size="xs" variant="ghost" fontSize={20} - onClick={onClick} - /> - - )} - {withCopy && ( - - } - size="xs" - variant="ghost" - fontSize={14} - onClick={handleCopy} + onClick={onRecall} /> )} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 579d45054b..ddcdd58e75 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -24,29 +24,29 @@ type LoRACardProps = { export const LoRACard = memo((props: LoRACardProps) => { const { lora } = props; const dispatch = useAppDispatch(); - const { data: loraConfig } = useGetModelConfigQuery(lora.key); + const { data: loraConfig } = useGetModelConfigQuery(lora.model.key); const handleChange = useCallback( (v: number) => { - dispatch(loraWeightChanged({ key: lora.key, weight: v })); + dispatch(loraWeightChanged({ key: lora.model.key, weight: v })); }, - [dispatch, lora.key] + [dispatch, lora.model.key] ); const handleSetLoraToggle = useCallback(() => { - dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled })); - }, [dispatch, lora.key, lora.isEnabled]); + dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.model.key, lora.isEnabled]); const handleRemoveLora = useCallback(() => { - dispatch(loraRemoved(lora.key)); - }, [dispatch, lora.key]); + dispatch(loraRemoved(lora.model.key)); + }, [dispatch, lora.model.key]); return ( - {loraConfig?.name ?? lora.key.substring(0, 8)} + {loraConfig?.name ?? lora.model.key.substring(0, 8)} diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 377406b3e5..32bef5fd9b 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,9 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; +import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers'; import type { LoRAModelConfig } from 'services/api/types'; -export type LoRA = ParameterLoRAModel & { +export type LoRA = { + model: ParameterLoRAModel; weight: number; isEnabled?: boolean; }; @@ -29,11 +31,11 @@ export const loraSlice = createSlice({ initialState: initialLoraState, reducers: { loraAdded: (state, action: PayloadAction) => { - const { key, base } = action.payload; - state.loras[key] = { key, base, ...defaultLoRAConfig }; + const model = getModelKeyAndBase(action.payload); + state.loras[model.key] = { ...defaultLoRAConfig, model }; }, loraRecalled: (state, action: PayloadAction) => { - state.loras[action.payload.key] = action.payload; + state.loras[action.payload.model.key] = action.payload; }, loraRemoved: (state, action: PayloadAction) => { const key = action.payload; @@ -58,7 +60,7 @@ export const loraSlice = createSlice({ } lora.weight = defaultLoRAConfig.weight; }, - loraIsEnabledChanged: (state, action: PayloadAction>) => { + loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => { const { key, isEnabled } = action.payload; const lora = state.loras[key]; if (!lora) { diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx new file mode 100644 index 0000000000..2e8e3b6f9a --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx @@ -0,0 +1,66 @@ +import type { ControlNetConfig } from 'features/controlAdapters/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataControlNets = ({ metadata }: Props) => { + const [controlNets, setControlNets] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.controlNets.parse(metadata); + setControlNets(parsed); + } catch (e) { + setControlNets([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.controlNets.getLabel(), []); + + return ( + <> + {controlNets.map((controlNet) => ( + + ))} + + ); +}; + +const MetadataViewControlNet = ({ + label, + controlNet, + handlers, +}: { + label: string; + controlNet: ControlNetConfig; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(controlNet, true); + }, [handlers, controlNet]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(controlNet); + }, [handlers, controlNet]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx new file mode 100644 index 0000000000..fef281cd09 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx @@ -0,0 +1,66 @@ +import type { IPAdapterConfig } from 'features/controlAdapters/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataIPAdapters = ({ metadata }: Props) => { + const [ipAdapters, setIPAdapters] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.ipAdapters.parse(metadata); + setIPAdapters(parsed); + } catch (e) { + setIPAdapters([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.ipAdapters.getLabel(), []); + + return ( + <> + {ipAdapters.map((ipAdapter) => ( + + ))} + + ); +}; + +const MetadataViewIPAdapter = ({ + label, + ipAdapter, + handlers, +}: { + label: string; + ipAdapter: IPAdapterConfig; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(ipAdapter, true); + }, [handlers, ipAdapter]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(ipAdapter); + }, [handlers, ipAdapter]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx new file mode 100644 index 0000000000..66d101f458 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx @@ -0,0 +1,33 @@ +import { typedMemo } from '@invoke-ai/ui-library'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { MetadataParseFailedToken } from 'features/metadata/util/parsers'; + +type MetadataItemProps = { + metadata: unknown; + handlers: MetadataHandlers; + direction?: 'row' | 'column'; +}; + +const _MetadataItem = typedMemo(({ metadata, handlers, direction = 'row' }: MetadataItemProps) => { + const { label, isDisabled, value, renderedValue, onRecall } = useMetadataItem(metadata, handlers); + + if (value === MetadataParseFailedToken) { + return null; + } + + return ( + + ); +}); + +export const MetadataItem = typedMemo(_MetadataItem); + +MetadataItem.displayName = 'MetadataItem'; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx new file mode 100644 index 0000000000..14ccc4ae2b --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx @@ -0,0 +1,29 @@ +import { Flex, Text } from '@invoke-ai/ui-library'; +import { RecallButton } from 'features/metadata/components/RecallButton'; +import { memo } from 'react'; + +type MetadataItemViewProps = { + onRecall: () => void; + label: string; + renderedValue: React.ReactNode; + isDisabled: boolean; + direction?: 'row' | 'column'; +}; + +export const MetadataItemView = memo( + ({ label, onRecall, isDisabled, renderedValue, direction = 'row' }: MetadataItemViewProps) => { + return ( + + {onRecall && } + + + {label}: + + {renderedValue} + + + ); + } +); + +MetadataItemView.displayName = 'MetadataItemView'; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx new file mode 100644 index 0000000000..3225a048d4 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx @@ -0,0 +1,61 @@ +import type { LoRA } from 'features/lora/store/loraSlice'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataLoRAs = ({ metadata }: Props) => { + const [loras, setLoRAs] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.loras.parse(metadata); + setLoRAs(parsed); + } catch (e) { + setLoRAs([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.loras.getLabel(), []); + + return ( + <> + {loras.map((lora) => ( + + ))} + + ); +}; + +const MetadataViewLoRA = ({ + label, + lora, + handlers, +}: { + label: string; + lora: LoRA; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(lora, true); + }, [handlers, lora]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(lora); + }, [handlers, lora]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx new file mode 100644 index 0000000000..4a73fb3ccb --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx @@ -0,0 +1,66 @@ +import type { T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataT2IAdapters = ({ metadata }: Props) => { + const [t2iAdapters, setT2IAdapters] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.t2iAdapters.parse(metadata); + setT2IAdapters(parsed); + } catch (e) { + setT2IAdapters([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.t2iAdapters.getLabel(), []); + + return ( + <> + {t2iAdapters.map((t2iAdapter) => ( + + ))} + + ); +}; + +const MetadataViewT2IAdapter = ({ + label, + t2iAdapter, + handlers, +}: { + label: string; + t2iAdapter: T2IAdapterConfig; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(t2iAdapter, true); + }, [handlers, t2iAdapter]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(t2iAdapter); + }, [handlers, t2iAdapter]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx b/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx new file mode 100644 index 0000000000..8803945b06 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx @@ -0,0 +1,27 @@ +import type { IconButtonProps} from '@invoke-ai/ui-library'; +import { IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiArrowBendUpLeftBold } from 'react-icons/pi'; + +type MetadataItemProps = Omit & { + label: string; +}; + +export const RecallButton = memo(({ label, ...rest }: MetadataItemProps) => { + const { t } = useTranslation(); + + return ( + + } + size="xs" + variant="ghost" + {...rest} + /> + + ); +}); + +RecallButton.displayName = 'RecallButton'; diff --git a/invokeai/frontend/web/src/features/metadata/exceptions.ts b/invokeai/frontend/web/src/features/metadata/exceptions.ts new file mode 100644 index 0000000000..ffe4e98c79 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/exceptions.ts @@ -0,0 +1,27 @@ +/** + * Raised when metadata parsing fails. + */ +export class MetadataParseError extends Error { + /** + * Create MetadataParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Raised when metadata recall fails. + */ +export class MetadataRecallError extends Error { + /** + * Create MetadataRecallError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx new file mode 100644 index 0000000000..178ac5155a --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx @@ -0,0 +1,51 @@ +import { Text } from '@invoke-ai/ui-library'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +export const useMetadataItem = (metadata: unknown, handlers: MetadataHandlers) => { + const [value, setValue] = useState( + MetadataParsePendingToken + ); + + useEffect(() => { + const _parse = async () => { + try { + const parsed = await handlers.parse(metadata); + setValue(parsed); + } catch (e) { + setValue(MetadataParseFailedToken); + } + }; + _parse(); + }, [handlers, metadata]); + + const isDisabled = useMemo(() => value === MetadataParsePendingToken || value === MetadataParseFailedToken, [value]); + + const label = useMemo(() => handlers.getLabel(), [handlers]); + + const renderedValue = useMemo(() => { + if (value === MetadataParsePendingToken) { + return Loading; + } + if (value === MetadataParseFailedToken) { + return Parsing Failed; + } + + const rendered = handlers.renderValue(value); + + if (typeof rendered === 'string') { + return {rendered}; + } + return rendered; + }, [handlers, value]); + + const onRecall = useCallback(() => { + if (!handlers.recall || value === MetadataParsePendingToken || value === MetadataParseFailedToken) { + return null; + } + handlers.recall(value, true); + }, [handlers, value]); + + return { label, isDisabled, value, renderedValue, onRecall }; +}; diff --git a/invokeai/frontend/web/src/features/metadata/types.ts b/invokeai/frontend/web/src/features/metadata/types.ts new file mode 100644 index 0000000000..366f11701e --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/types.ts @@ -0,0 +1,136 @@ +/** + * Renders a value of type T as a React node. + */ +export type MetadataRenderValueFunc = (value: T) => React.ReactNode; + +/** + * Gets the label of the current metadata item as a string. + */ +export type MetadataGetLabelFunc = () => string; + +export type MetadataParseOptions = { + toastOnFailure?: boolean; + toastOnSuccess?: boolean; +}; + +export type MetadataRecallOptions = MetadataParseOptions; + +/** + * A function that recalls a parsed and validated metadata value. + * + * @param value The value to recall. + * @throws MetadataRecallError if the value cannot be recalled. + */ +export type MetadataRecallFunc = (value: T) => void; + +/** + * An async function that receives metadata and returns a parsed value, throwing if the value is invalid or missing. + * + * The function receives an object of unknown type. It is responsible for extracting the relevant data from the metadata + * and returning a value of type T. + * + * The function should throw a MetadataParseError if the metadata is invalid or missing. + * + * @param metadata The metadata to parse. + * @returns A promise that resolves to the parsed value. + * @throws MetadataParseError if the metadata is invalid or missing. + */ +export type MetadataParseFunc = (metadata: unknown) => Promise; + +/** + * A function that performs additional validation logic before recalling a metadata value. It is called with a parsed + * value and should throw if the validation logic fails. + * + * This function is used in cases where some additional logic is required before recalling. For example, when recalling + * a LoRA, we need to check if it is compatible with the current base model. + * + * @param value The value to validate. + * @returns A promise that resolves to the validated value. + * @throws MetadataRecallError if the value is invalid. + */ +export type MetadataValidateFunc = (value: T) => Promise; + +export type MetadataHandlers = { + /** + * Gets the label of the current metadata item as a string. + * + * @returns The label of the current metadata item. + */ + getLabel: MetadataGetLabelFunc; + /** + * An async function that receives metadata and returns a parsed metadata value. + * + * @param metadata The metadata to parse. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves to the parsed value. + * @throws MetadataParseError if the metadata is invalid or missing. + */ + parse: (metadata: unknown, withToast?: boolean) => Promise; + /** + * An async function that receives a metadata item and returns a parsed metadata item value. + * + * This is only provided if the metadata value is an array. + * + * @param item The item to parse. It should be an item from the array. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves to the parsed value. + * @throws MetadataParseError if the metadata is invalid or missing. + */ + parseItem?: (item: unknown, withToast?: boolean) => Promise; + /** + * An async function that recalls a parsed metadata value. + * + * This function is only provided if the metadata value can be recalled. + * + * @param value The value to recall. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves when the recall operation is complete. + * @throws MetadataRecallError if the value cannot be recalled. + */ + recall?: (value: TValue, withToast?: boolean) => Promise; + /** + * An async function that recalls a parsed metadata item value. + * + * This function is only provided if the metadata value is an array and the items can be recalled. + * + * @param item The item to recall. It should be an item from the array. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves when the recall operation is complete. + * @throws MetadataRecallError if the value cannot be recalled. + */ + recallItem?: (item: TItem, withToast?: boolean) => Promise; + /** + * Renders a parsed metadata value as a React node. + * + * @param value The value to render. + * @returns The rendered value. + */ + renderValue: MetadataRenderValueFunc; + /** + * Renders a parsed metadata item value as a React node. + * + * @param item The item to render. + * @returns The rendered item. + */ + renderItemValue?: MetadataRenderValueFunc; +}; + +// TODO(psyche): The types for item handlers should be able to be inferred from the type of the value: +// type MetadataHandlersInferItem = TValue extends Array ? MetadataParseFunc : never +// While this works for the types as expected, I couldn't satisfy TS in the implementations of the handlers. + +export type BuildMetadataHandlersArg = { + parser: MetadataParseFunc; + itemParser?: MetadataParseFunc; + recaller?: MetadataRecallFunc; + itemRecaller?: MetadataRecallFunc; + validator?: MetadataValidateFunc; + itemValidator?: MetadataValidateFunc; + getLabel: MetadataGetLabelFunc; + renderValue?: MetadataRenderValueFunc; + renderItemValue?: MetadataRenderValueFunc; +}; + +export type BuildMetadataHandlers = ( + arg: BuildMetadataHandlersArg +) => MetadataHandlers; diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.tsx b/invokeai/frontend/web/src/features/metadata/util/handlers.tsx new file mode 100644 index 0000000000..1f86257d42 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.tsx @@ -0,0 +1,316 @@ +import { Text } from '@invoke-ai/ui-library'; +import { toast } from 'common/util/toast'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import type { + BuildMetadataHandlers, + MetadataGetLabelFunc, + MetadataHandlers, + MetadataParseFunc, + MetadataRecallFunc, + MetadataRenderValueFunc, + MetadataValidateFunc, +} from 'features/metadata/types'; +import { validators } from 'features/metadata/util/validators'; +import { t } from 'i18next'; +import type { AnyModelConfig } from 'services/api/types'; + +import { parsers } from './parsers'; +import { recallers } from './recallers'; + +const renderModelConfigValue: MetadataRenderValueFunc = (value) => + `${value.name} (${value.base.toUpperCase()}, ${value.key})`; +const renderLoRAValue: MetadataRenderValueFunc = (value) => {`${value.model.key} (${value.weight})`}; +const renderControlAdapterValue: MetadataRenderValueFunc = ( + value +) => {`${value.model?.key} (${value.weight})`}; + +const parameterSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterSet', { parameter }), + description, + status: 'info', + duration: 2500, + isClosable: true, + }); +}; + +const parameterNotSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterNotSet', { parameter }), + description, + status: 'warning', + duration: 2500, + isClosable: true, + }); +}; + +// const allParameterSetToast = (description?: string) => { +// toast({ +// title: t('toast.parametersSet'), +// status: 'info', +// description, +// duration: 2500, +// isClosable: true, +// }); +// }; + +// const allParameterNotSetToast = (description?: string) => { +// toast({ +// title: t('toast.parametersNotSet'), +// status: 'warning', +// description, +// duration: 2500, +// isClosable: true, +// }); +// }; + +const buildParse = + (arg: { + parser: MetadataParseFunc; + getLabel: MetadataGetLabelFunc; + }): MetadataHandlers['parse'] => + async (metadata, withToast = false) => { + try { + const parsed = await arg.parser(metadata); + withToast && parameterSetToast(arg.getLabel()); + return parsed; + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildParseItem = + (arg: { + itemParser: MetadataParseFunc; + getLabel: MetadataGetLabelFunc; + }): MetadataHandlers['parseItem'] => + async (item, withToast = false) => { + try { + const parsed = await arg.itemParser(item); + withToast && parameterSetToast(arg.getLabel()); + return parsed; + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildRecall = + (arg: { + recaller: MetadataRecallFunc; + validator?: MetadataValidateFunc; + getLabel: MetadataGetLabelFunc; + }): NonNullable['recall']> => + async (value, withToast = false) => { + try { + arg.validator && (await arg.validator(value)); + await arg.recaller(value); + withToast && parameterSetToast(arg.getLabel()); + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildRecallItem = + (arg: { + itemRecaller: MetadataRecallFunc; + itemValidator?: MetadataValidateFunc; + getLabel: MetadataGetLabelFunc; + }): NonNullable['recallItem']> => + async (item, withToast = false) => { + try { + arg.itemValidator && (await arg.itemValidator(item)); + await arg.itemRecaller(item); + withToast && parameterSetToast(arg.getLabel()); + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildHandlers: BuildMetadataHandlers = ({ + getLabel, + parser, + itemParser, + recaller, + itemRecaller, + validator, + itemValidator, + renderValue, + renderItemValue, +}) => ({ + parse: buildParse({ parser, getLabel }), + parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined, + recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined, + recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined, + getLabel, + renderValue: renderValue ?? String, + renderItemValue: renderItemValue ?? String, +}); + +export const handlers = { + // Misc + createdBy: buildHandlers({ getLabel: () => t('metadata.createdBy'), parser: parsers.createdBy }), + generationMode: buildHandlers({ getLabel: () => t('metadata.generationMode'), parser: parsers.generationMode }), + + // Core parameters + cfgRescaleMultiplier: buildHandlers({ + getLabel: () => t('metadata.cfgRescaleMultiplier'), + parser: parsers.cfgRescaleMultiplier, + recaller: recallers.cfgRescaleMultiplier, + }), + cfgScale: buildHandlers({ + getLabel: () => t('metadata.cfgScale'), + parser: parsers.cfgScale, + recaller: recallers.cfgScale, + }), + height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }), + negativePrompt: buildHandlers({ + getLabel: () => t('metadata.negativePrompt'), + parser: parsers.negativePrompt, + recaller: recallers.negativePrompt, + }), + positivePrompt: buildHandlers({ + getLabel: () => t('metadata.positivePrompt'), + parser: parsers.positivePrompt, + recaller: recallers.positivePrompt, + }), + scheduler: buildHandlers({ + getLabel: () => t('metadata.scheduler'), + parser: parsers.scheduler, + recaller: recallers.scheduler, + }), + sdxlNegativeStylePrompt: buildHandlers({ + getLabel: () => t('sdxl.negStylePrompt'), + parser: parsers.sdxlNegativeStylePrompt, + recaller: recallers.sdxlNegativeStylePrompt, + }), + sdxlPositiveStylePrompt: buildHandlers({ + getLabel: () => t('sdxl.posStylePrompt'), + parser: parsers.sdxlPositiveStylePrompt, + recaller: recallers.sdxlPositiveStylePrompt, + }), + seed: buildHandlers({ getLabel: () => t('metadata.seed'), parser: parsers.seed, recaller: recallers.seed }), + steps: buildHandlers({ getLabel: () => t('metadata.steps'), parser: parsers.steps, recaller: recallers.steps }), + strength: buildHandlers({ + getLabel: () => t('metadata.strength'), + parser: parsers.strength, + recaller: recallers.strength, + }), + width: buildHandlers({ getLabel: () => t('metadata.width'), parser: parsers.width, recaller: recallers.width }), + + // HRF + hrfEnabled: buildHandlers({ + getLabel: () => t('hrf.metadata.enabled'), + parser: parsers.hrfEnabled, + recaller: recallers.hrfEnabled, + }), + hrfMethod: buildHandlers({ + getLabel: () => t('hrf.metadata.method'), + parser: parsers.hrfMethod, + recaller: recallers.hrfMethod, + }), + hrfStrength: buildHandlers({ + getLabel: () => t('hrf.metadata.strength'), + parser: parsers.hrfStrength, + recaller: recallers.hrfStrength, + }), + + // Refiner + refinerCFGScale: buildHandlers({ + getLabel: () => t('sdxl.cfgScale'), + parser: parsers.refinerCFGScale, + recaller: recallers.refinerCFGScale, + }), + refinerModel: buildHandlers({ + getLabel: () => t('sdxl.refinerModel'), + parser: parsers.refinerModel, + recaller: recallers.refinerModel, + validator: validators.refinerModel, + }), + refinerNegativeAestheticScore: buildHandlers({ + getLabel: () => t('sdxl.posAestheticScore'), + parser: parsers.refinerNegativeAestheticScore, + recaller: recallers.refinerNegativeAestheticScore, + }), + refinerPositiveAestheticScore: buildHandlers({ + getLabel: () => t('sdxl.negAestheticScore'), + parser: parsers.refinerPositiveAestheticScore, + recaller: recallers.refinerPositiveAestheticScore, + }), + refinerScheduler: buildHandlers({ + getLabel: () => t('sdxl.scheduler'), + parser: parsers.refinerScheduler, + recaller: recallers.refinerScheduler, + }), + refinerStart: buildHandlers({ + getLabel: () => t('sdxl.refiner_start'), + parser: parsers.refinerStart, + recaller: recallers.refinerStart, + }), + refinerSteps: buildHandlers({ + getLabel: () => t('sdxl.refiner_steps'), + parser: parsers.refinerSteps, + recaller: recallers.refinerSteps, + }), + + // Models + model: buildHandlers({ + getLabel: () => t('metadata.model'), + parser: parsers.mainModel, + recaller: recallers.model, + renderValue: renderModelConfigValue, + }), + vae: buildHandlers({ + getLabel: () => t('metadata.vae'), + parser: parsers.vaeModel, + recaller: recallers.vae, + renderValue: renderModelConfigValue, + validator: validators.vaeModel, + }), + + // Arrays of models + controlNets: buildHandlers({ + getLabel: () => t('common.controlNet'), + parser: parsers.controlNets, + itemParser: parsers.controlNet, + recaller: recallers.controlNets, + itemRecaller: recallers.controlNet, + validator: validators.controlNets, + itemValidator: validators.controlNet, + renderItemValue: renderControlAdapterValue, + }), + ipAdapters: buildHandlers({ + getLabel: () => t('common.ipAdapter'), + parser: parsers.ipAdapters, + itemParser: parsers.ipAdapter, + recaller: recallers.ipAdapters, + itemRecaller: recallers.ipAdapter, + validator: validators.ipAdapters, + itemValidator: validators.ipAdapter, + renderItemValue: renderControlAdapterValue, + }), + loras: buildHandlers({ + getLabel: () => t('models.lora'), + parser: parsers.loras, + itemParser: parsers.lora, + recaller: recallers.loras, + itemRecaller: recallers.lora, + validator: validators.loras, + itemValidator: validators.lora, + renderItemValue: renderLoRAValue, + }), + t2iAdapters: buildHandlers({ + getLabel: () => t('common.t2iAdapter'), + parser: parsers.t2iAdapters, + itemParser: parsers.t2iAdapter, + recaller: recallers.t2iAdapters, + itemRecaller: recallers.t2iAdapter, + validator: validators.t2iAdapters, + itemValidator: validators.t2iAdapter, + renderItemValue: renderControlAdapterValue, + }), +} as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts new file mode 100644 index 0000000000..c7f2a6d09f --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -0,0 +1,396 @@ +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { + initialControlNet, + initialIPAdapter, + initialT2IAdapter, +} from 'features/controlAdapters/util/buildControlAdapter'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import { defaultLoRAConfig } from 'features/lora/store/loraSlice'; +import { MetadataParseError } from 'features/metadata/exceptions'; +import type { MetadataParseFunc } from 'features/metadata/types'; +import { + zControlField, + zIPAdapterField, + zModelIdentifierWithBase, + zT2IAdapterField, +} from 'features/nodes/types/common'; +import type { + ParameterCFGRescaleMultiplier, + ParameterCFGScale, + ParameterHeight, + ParameterHRFEnabled, + ParameterHRFMethod, + ParameterNegativePrompt, + ParameterNegativeStylePromptSDXL, + ParameterPositivePrompt, + ParameterPositiveStylePromptSDXL, + ParameterScheduler, + ParameterSDXLRefinerNegativeAestheticScore, + ParameterSDXLRefinerPositiveAestheticScore, + ParameterSDXLRefinerStart, + ParameterSeed, + ParameterSteps, + ParameterStrength, + ParameterWidth, +} from 'features/parameters/types/parameterSchemas'; +import { + isParameterCFGRescaleMultiplier, + isParameterCFGScale, + isParameterHeight, + isParameterHRFEnabled, + isParameterHRFMethod, + isParameterLoRAWeight, + isParameterNegativePrompt, + isParameterNegativeStylePromptSDXL, + isParameterPositivePrompt, + isParameterPositiveStylePromptSDXL, + isParameterScheduler, + isParameterSDXLRefinerNegativeAestheticScore, + isParameterSDXLRefinerPositiveAestheticScore, + isParameterSDXLRefinerStart, + isParameterSeed, + isParameterSteps, + isParameterStrength, + isParameterWidth, +} from 'features/parameters/types/parameterSchemas'; +import { + fetchModelConfigWithTypeGuard, + getModelKey, + getModelKeyAndBase, +} from 'features/parameters/util/modelFetchingHelpers'; +import { get, isArray, isString } from 'lodash-es'; +import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; +import { + isControlNetModelConfig, + isIPAdapterModelConfig, + isLoRAModelConfig, + isNonRefinerMainModelConfig, + isRefinerMainModelModelConfig, + isT2IAdapterModelConfig, + isVAEModelConfig, +} from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; + +export const MetadataParsePendingToken = Symbol('pending'); +export const MetadataParseFailedToken = Symbol('failed'); + +/** + * An async function that a property from an object and validates its type using a type guard. If the property is missing + * or invalid, the function should throw a MetadataParseError. + * @param obj The object to get the property from. + * @param property The property to get. + * @param typeGuard A type guard function to check the type of the property. Provide `undefined` to opt out of type + * validation and always return the property value. + * @returns A promise that resolves to the property value if it exists and is of the expected type. + * @throws MetadataParseError if a type guard is provided and the property is not of the expected type. + */ +const getProperty = ( + obj: unknown, + property: string, + typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true +): Promise => { + return new Promise((resolve, reject) => { + const val = get(obj, property) as unknown; + if (typeGuard(val)) { + resolve(val); + } + reject(new MetadataParseError(`Property ${property} is not of expected type`)); + }); +}; + +const parseCreatedBy: MetadataParseFunc = (metadata) => getProperty(metadata, 'created_by', isString); + +const parseGenerationMode: MetadataParseFunc = (metadata) => getProperty(metadata, 'generation_mode', isString); + +const parsePositivePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'positive_prompt', isParameterPositivePrompt); + +const parseNegativePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'negative_prompt', isParameterNegativePrompt); + +const parseSDXLPositiveStylePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL); + +const parseSDXLNegativeStylePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL); + +const parseSeed: MetadataParseFunc = (metadata) => getProperty(metadata, 'seed', isParameterSeed); + +const parseCFGScale: MetadataParseFunc = (metadata) => + getProperty(metadata, 'cfg_scale', isParameterCFGScale); + +const parseCFGRescaleMultiplier: MetadataParseFunc = (metadata) => + getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier); + +const parseScheduler: MetadataParseFunc = (metadata) => + getProperty(metadata, 'scheduler', isParameterScheduler); + +const parseWidth: MetadataParseFunc = (metadata) => getProperty(metadata, 'width', isParameterWidth); + +const parseHeight: MetadataParseFunc = (metadata) => + getProperty(metadata, 'height', isParameterHeight); + +const parseSteps: MetadataParseFunc = (metadata) => getProperty(metadata, 'steps', isParameterSteps); + +const parseStrength: MetadataParseFunc = (metadata) => + getProperty(metadata, 'strength', isParameterStrength); + +const parseHRFEnabled: MetadataParseFunc = (metadata) => + getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled); + +const parseHRFStrength: MetadataParseFunc = (metadata) => + getProperty(metadata, 'hrf_strength', isParameterStrength); + +const parseHRFMethod: MetadataParseFunc = (metadata) => + getProperty(metadata, 'hrf_method', isParameterHRFMethod); + +const parseRefinerSteps: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_steps', isParameterSteps); + +const parseRefinerCFGScale: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale); + +const parseRefinerScheduler: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_scheduler', isParameterScheduler); + +const parseRefinerPositiveAestheticScore: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_positive_aesthetic_score', isParameterSDXLRefinerPositiveAestheticScore); + +const parseRefinerNegativeAestheticScore: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_negative_aesthetic_score', isParameterSDXLRefinerNegativeAestheticScore); + +const parseRefinerStart: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart); + +const parseMainModel: MetadataParseFunc = async (metadata) => { + const model = await getProperty(metadata, 'model', undefined); + const key = await getModelKey(model, 'main'); + const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); + return mainModelConfig; +}; + +const parseRefinerModel: MetadataParseFunc = async (metadata) => { + const refiner_model = await getProperty(metadata, 'refiner_model', undefined); + const key = await getModelKey(refiner_model, 'main'); + const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); + return refinerModelConfig; +}; + +const parseVAEModel: MetadataParseFunc = async (metadata) => { + const vae = await getProperty(metadata, 'vae', undefined); + const key = await getModelKey(vae, 'vae'); + const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig); + return vaeModelConfig; +}; + +const parseLoRA: MetadataParseFunc = async (metadataItem) => { + // Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}` + const modelV1 = await getProperty(metadataItem, 'lora', undefined); + // Now, the LoRA model is stored in a `model` property of the LoRA metadata: `{model: {key: ...}, weight: 0.75}` + const modelV2 = await getProperty(metadataItem, 'model', undefined); + const weight = await getProperty(metadataItem, 'weight', undefined); + const key = await getModelKey(modelV2 ?? modelV1, 'lora'); + const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); + + return { + model: getModelKeyAndBase(loraModelConfig), + weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, + isEnabled: true, + }; +}; + +const parseAllLoRAs: MetadataParseFunc = async (metadata) => { + const lorasRaw = await getProperty(metadata, 'loras', isArray); + const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora))); + const loras = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return loras; +}; + +const parseControlNet: MetadataParseFunc = async (metadataItem) => { + const control_model = await getProperty(metadataItem, 'control_model'); + const key = await getModelKey(control_model, 'controlnet'); + const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig); + + const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const control_weight = zControlField.shape.control_weight + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_weight')); + const begin_step_percent = zControlField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zControlField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const control_mode = zControlField.shape.control_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_mode')); + const resize_mode = zControlField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const controlNet: ControlNetConfig = { + type: 'controlnet', + isEnabled: true, + model: zModelIdentifierWithBase.parse(controlNetModel), + weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, + beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct, + endStepPct: end_step_percent ?? initialControlNet.endStepPct, + controlMode: control_mode ?? initialControlNet.controlMode, + resizeMode: resize_mode ?? initialControlNet.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return controlNet; +}; + +const parseAllControlNets: MetadataParseFunc = async (metadata) => { + const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray); + const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn))); + const controlNets = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return controlNets; +}; + +const parseT2IAdapter: MetadataParseFunc = async (metadataItem) => { + const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model'); + const key = await getModelKey(t2i_adapter_model, 't2i_adapter'); + const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig); + + const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zT2IAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zT2IAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const resize_mode = zT2IAdapterField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const t2iAdapter: T2IAdapterConfig = { + type: 't2i_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(t2iAdapterModel), + weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, + beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, + resizeMode: resize_mode ?? initialT2IAdapter.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return t2iAdapter; +}; + +const parseAllT2IAdapters: MetadataParseFunc = async (metadata) => { + const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray); + const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter))); + const t2iAdapters = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return t2iAdapters; +}; + +const parseIPAdapter: MetadataParseFunc = async (metadataItem) => { + const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model'); + const key = await getModelKey(ip_adapter_model, 'ip_adapter'); + const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig); + + const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zIPAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zIPAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + + const ipAdapter: IPAdapterConfig = { + id: uuidv4(), + type: 'ip_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(ipAdapterModel), + controlImage: image?.image_name ?? null, + weight: weight ?? initialIPAdapter.weight, + beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, + }; + + return ipAdapter; +}; + +const parseAllIPAdapters: MetadataParseFunc = async (metadata) => { + const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray); + const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter))); + const ipAdapters = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return ipAdapters; +}; + +export const parsers = { + createdBy: parseCreatedBy, + generationMode: parseGenerationMode, + positivePrompt: parsePositivePrompt, + negativePrompt: parseNegativePrompt, + sdxlPositiveStylePrompt: parseSDXLPositiveStylePrompt, + sdxlNegativeStylePrompt: parseSDXLNegativeStylePrompt, + seed: parseSeed, + cfgScale: parseCFGScale, + cfgRescaleMultiplier: parseCFGRescaleMultiplier, + scheduler: parseScheduler, + width: parseWidth, + height: parseHeight, + steps: parseSteps, + strength: parseStrength, + hrfEnabled: parseHRFEnabled, + hrfStrength: parseHRFStrength, + hrfMethod: parseHRFMethod, + refinerSteps: parseRefinerSteps, + refinerCFGScale: parseRefinerCFGScale, + refinerScheduler: parseRefinerScheduler, + refinerPositiveAestheticScore: parseRefinerPositiveAestheticScore, + refinerNegativeAestheticScore: parseRefinerNegativeAestheticScore, + refinerStart: parseRefinerStart, + mainModel: parseMainModel, + refinerModel: parseRefinerModel, + vaeModel: parseVAEModel, + lora: parseLoRA, + loras: parseAllLoRAs, + controlNet: parseControlNet, + controlNets: parseAllControlNets, + t2iAdapter: parseT2IAdapter, + t2iAdapters: parseAllT2IAdapters, + ipAdapter: parseIPAdapter, + ipAdapters: parseAllIPAdapters, +} as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts new file mode 100644 index 0000000000..e33589d5ce --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -0,0 +1,295 @@ +import { getStore } from 'app/store/nanostores/store'; +import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import { loraRecalled } from 'features/lora/store/loraSlice'; +import type { MetadataRecallFunc } from 'features/metadata/types'; +import { zModelIdentifierWithBase } from 'features/nodes/types/common'; +import { modelSelected } from 'features/parameters/store/actions'; +import { + heightRecalled, + setCfgRescaleMultiplier, + setCfgScale, + setImg2imgStrength, + setNegativePrompt, + setPositivePrompt, + setScheduler, + setSeed, + setSteps, + vaeSelected, + widthRecalled, +} from 'features/parameters/store/generationSlice'; +import type { + ParameterCFGRescaleMultiplier, + ParameterCFGScale, + ParameterHeight, + ParameterHRFEnabled, + ParameterHRFMethod, + ParameterNegativePrompt, + ParameterNegativeStylePromptSDXL, + ParameterPositivePrompt, + ParameterPositiveStylePromptSDXL, + ParameterScheduler, + ParameterSDXLRefinerNegativeAestheticScore, + ParameterSDXLRefinerPositiveAestheticScore, + ParameterSDXLRefinerStart, + ParameterSeed, + ParameterSteps, + ParameterStrength, + ParameterWidth, +} from 'features/parameters/types/parameterSchemas'; +import { + refinerModelChanged, + setNegativeStylePromptSDXL, + setPositiveStylePromptSDXL, + setRefinerCFGScale, + setRefinerNegativeAestheticScore, + setRefinerPositiveAestheticScore, + setRefinerScheduler, + setRefinerStart, + setRefinerSteps, +} from 'features/sdxl/store/sdxlSlice'; +import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; + +const recallPositivePrompt: MetadataRecallFunc = (positivePrompt) => { + getStore().dispatch(setPositivePrompt(positivePrompt)); +}; + +const recallNegativePrompt: MetadataRecallFunc = (negativePrompt) => { + getStore().dispatch(setNegativePrompt(negativePrompt)); +}; + +const recallSDXLPositiveStylePrompt: MetadataRecallFunc = (positiveStylePrompt) => { + getStore().dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); +}; + +const recallSDXLNegativeStylePrompt: MetadataRecallFunc = (negativeStylePrompt) => { + getStore().dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); +}; + +const recallSeed: MetadataRecallFunc = (seed) => { + getStore().dispatch(setSeed(seed)); +}; + +const recallCFGScale: MetadataRecallFunc = (cfgScale) => { + getStore().dispatch(setCfgScale(cfgScale)); +}; + +const recallCFGRescaleMultiplier: MetadataRecallFunc = (cfgRescaleMultiplier) => { + getStore().dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier)); +}; + +const recallScheduler: MetadataRecallFunc = (scheduler) => { + getStore().dispatch(setScheduler(scheduler)); +}; + +const recallWidth: MetadataRecallFunc = (width) => { + getStore().dispatch(widthRecalled(width)); +}; + +const recallHeight: MetadataRecallFunc = (height) => { + getStore().dispatch(heightRecalled(height)); +}; + +const recallSteps: MetadataRecallFunc = (steps) => { + getStore().dispatch(setSteps(steps)); +}; + +const recallStrength: MetadataRecallFunc = (strength) => { + getStore().dispatch(setImg2imgStrength(strength)); +}; + +const recallHRFEnabled: MetadataRecallFunc = (hrfEnabled) => { + getStore().dispatch(setHrfEnabled(hrfEnabled)); +}; + +const recallHRFStrength: MetadataRecallFunc = (hrfStrength) => { + getStore().dispatch(setHrfStrength(hrfStrength)); +}; + +const recallHRFMethod: MetadataRecallFunc = (hrfMethod) => { + getStore().dispatch(setHrfMethod(hrfMethod)); +}; + +const recallRefinerSteps: MetadataRecallFunc = (refinerSteps) => { + getStore().dispatch(setRefinerSteps(refinerSteps)); +}; + +const recallRefinerCFGScale: MetadataRecallFunc = (refinerCFGScale) => { + getStore().dispatch(setRefinerCFGScale(refinerCFGScale)); +}; + +const recallRefinerScheduler: MetadataRecallFunc = (refinerScheduler) => { + getStore().dispatch(setRefinerScheduler(refinerScheduler)); +}; + +const recallRefinerPositiveAestheticScore: MetadataRecallFunc = ( + refinerPositiveAestheticScore +) => { + getStore().dispatch(setRefinerPositiveAestheticScore(refinerPositiveAestheticScore)); +}; + +const recallRefinerNegativeAestheticScore: MetadataRecallFunc = ( + refinerNegativeAestheticScore +) => { + getStore().dispatch(setRefinerNegativeAestheticScore(refinerNegativeAestheticScore)); +}; + +const recallRefinerStart: MetadataRecallFunc = (refinerStart) => { + getStore().dispatch(setRefinerStart(refinerStart)); +}; + +const recallModel: MetadataRecallFunc = (model) => { + const modelIdentifier = zModelIdentifierWithBase.parse(model); + getStore().dispatch(modelSelected(modelIdentifier)); +}; + +const recallRefinerModel: MetadataRecallFunc = (refinerModel) => { + const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel); + getStore().dispatch(refinerModelChanged(modelIdentifier)); +}; + +const recallVAE: MetadataRecallFunc = (vaeModel) => { + if (!vaeModel) { + getStore().dispatch(vaeSelected(null)); + return; + } + const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel); + getStore().dispatch(vaeSelected(modelIdentifier)); +}; + +const recallLoRA: MetadataRecallFunc = (lora) => { + getStore().dispatch(loraRecalled(lora)); +}; + +const recallAllLoRAs: MetadataRecallFunc = (loras) => { + const { dispatch } = getStore(); + loras.forEach((lora) => { + dispatch(loraRecalled(lora)); + }); +}; + +const recallControlNet: MetadataRecallFunc = (controlNet) => { + getStore().dispatch(controlAdapterRecalled(controlNet)); +}; + +const recallControlNets: MetadataRecallFunc = (controlNets) => { + const { dispatch } = getStore(); + controlNets.forEach((controlNet) => { + dispatch(controlAdapterRecalled(controlNet)); + }); +}; + +const recallT2IAdapter: MetadataRecallFunc = (t2iAdapter) => { + getStore().dispatch(controlAdapterRecalled(t2iAdapter)); +}; + +const recallT2IAdapters: MetadataRecallFunc = (t2iAdapters) => { + const { dispatch } = getStore(); + t2iAdapters.forEach((t2iAdapter) => { + dispatch(controlAdapterRecalled(t2iAdapter)); + }); +}; + +const recallIPAdapter: MetadataRecallFunc = (ipAdapter) => { + getStore().dispatch(controlAdapterRecalled(ipAdapter)); +}; + +const recallIPAdapters: MetadataRecallFunc = (ipAdapters) => { + const { dispatch } = getStore(); + ipAdapters.forEach((ipAdapter) => { + dispatch(controlAdapterRecalled(ipAdapter)); + }); +}; + +export const recallPrompts = (_metadata: unknown) => { + // recallPositivePrompt(metadata); + // recallNegativePrompt(metadata); + // recallSDXLPositiveStylePrompt(metadata); + // recallSDXLNegativeStylePrompt(metadata); + // parameterNotSetToast(t('metadata.allPrompts')); + // parameterSetToast(t('metadata.allPrompts')); +}; + +export const recallWidthAndHeight = (_metadata: unknown) => { + // recallWidth(metadata); + // recallHeight(metadata); +}; + +export const recallAll = async (_metadata: unknown) => { + // if (!metadata) { + // allParameterNotSetToast(); + // return; + // } + // // Update the main model first, as other parameters may depend on it. + // await recallModel(metadata); + // await Promise.allSettled([ + // // Core parameters + // recallCFGScale(metadata), + // recallCFGRescaleMultiplier(metadata), + // recallPositivePrompt(metadata), + // recallNegativePrompt(metadata), + // recallScheduler(metadata), + // recallSeed(metadata), + // recallSteps(metadata), + // recallWidth(metadata), + // recallHeight(metadata), + // recallStrength(metadata), + // recallHRFEnabled(metadata), + // recallHRFMethod(metadata), + // recallHRFStrength(metadata), + // // SDXL parameters + // recallSDXLPositiveStylePrompt(metadata), + // recallSDXLNegativeStylePrompt(metadata), + // recallRefinerSteps(metadata), + // recallRefinerCFGScale(metadata), + // recallRefinerScheduler(metadata), + // recallRefinerPositiveAestheticScore(metadata), + // recallRefinerNegativeAestheticScore(metadata), + // recallRefinerStart(metadata), + // // Models + // recallVAE(metadata), + // recallRefinerModel(metadata), + // recallAllLoRAs(metadata), + // recallControlNets(metadata), + // recallT2IAdapters(metadata), + // recallIPAdapters(metadata), + // ]); + // allParameterSetToast(); +}; + +export const recallers = { + positivePrompt: recallPositivePrompt, + negativePrompt: recallNegativePrompt, + sdxlPositiveStylePrompt: recallSDXLPositiveStylePrompt, + sdxlNegativeStylePrompt: recallSDXLNegativeStylePrompt, + seed: recallSeed, + cfgScale: recallCFGScale, + cfgRescaleMultiplier: recallCFGRescaleMultiplier, + scheduler: recallScheduler, + width: recallWidth, + height: recallHeight, + steps: recallSteps, + strength: recallStrength, + hrfEnabled: recallHRFEnabled, + hrfStrength: recallHRFStrength, + hrfMethod: recallHRFMethod, + refinerSteps: recallRefinerSteps, + refinerCFGScale: recallRefinerCFGScale, + refinerScheduler: recallRefinerScheduler, + refinerPositiveAestheticScore: recallRefinerPositiveAestheticScore, + refinerNegativeAestheticScore: recallRefinerNegativeAestheticScore, + refinerStart: recallRefinerStart, + model: recallModel, + refinerModel: recallRefinerModel, + vae: recallVAE, + lora: recallLoRA, + loras: recallAllLoRAs, + controlNets: recallControlNets, + controlNet: recallControlNet, + t2iAdapters: recallT2IAdapters, + t2iAdapter: recallT2IAdapter, + ipAdapters: recallIPAdapters, + ipAdapter: recallIPAdapter, +} as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/validators.ts b/invokeai/frontend/web/src/features/metadata/util/validators.ts new file mode 100644 index 0000000000..e627afeda3 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/validators.ts @@ -0,0 +1,117 @@ +import { getStore } from 'app/store/nanostores/store'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import type { MetadataValidateFunc } from 'features/metadata/types'; +import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers'; +import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; + +/** + * Checks the given base model type against the currently-selected model's base type and throws an error if they are + * incompatible. + * @param base The base model type to validate. + * @param message An optional message to use in the error if the base model is incompatible. + */ +const validateBaseCompatibility = (base?: BaseModelType, message?: string) => { + if (!base) { + throw new InvalidModelConfigError(message || 'Missing base'); + } + const currentBase = getStore().getState().generation.model?.base; + if (currentBase && base !== currentBase) { + throw new InvalidModelConfigError(message || `Incompatible base models: ${base} and ${currentBase}`); + } +}; + +const validateRefinerModel: MetadataValidateFunc = (refinerModel) => { + validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model'); + return new Promise((resolve) => resolve(refinerModel)); +}; + +const validateVAEModel: MetadataValidateFunc = (vaeModel) => { + validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model'); + return new Promise((resolve) => resolve(vaeModel)); +}; + +const validateLoRA: MetadataValidateFunc = (lora) => { + validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model'); + return new Promise((resolve) => resolve(lora)); +}; + +const validateLoRAs: MetadataValidateFunc = (loras) => { + const validatedLoRAs: LoRA[] = []; + loras.forEach((lora) => { + try { + validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model'); + validatedLoRAs.push(lora); + } catch { + // This is a no-op - we want to continue validating the rest of the LoRAs, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedLoRAs)); +}; + +const validateControlNet: MetadataValidateFunc = (controlNet) => { + validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model'); + return new Promise((resolve) => resolve(controlNet)); +}; + +const validateControlNets: MetadataValidateFunc = (controlNets) => { + const validatedControlNets: ControlNetConfig[] = []; + controlNets.forEach((controlNet) => { + try { + validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model'); + validatedControlNets.push(controlNet); + } catch { + // This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedControlNets)); +}; + +const validateT2IAdapter: MetadataValidateFunc = (t2iAdapter) => { + validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model'); + return new Promise((resolve) => resolve(t2iAdapter)); +}; + +const validateT2IAdapters: MetadataValidateFunc = (t2iAdapters) => { + const validatedT2IAdapters: T2IAdapterConfig[] = []; + t2iAdapters.forEach((t2iAdapter) => { + try { + validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model'); + validatedT2IAdapters.push(t2iAdapter); + } catch { + // This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedT2IAdapters)); +}; + +const validateIPAdapter: MetadataValidateFunc = (ipAdapter) => { + validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model'); + return new Promise((resolve) => resolve(ipAdapter)); +}; + +const validateIPAdapters: MetadataValidateFunc = (ipAdapters) => { + const validatedIPAdapters: IPAdapterConfig[] = []; + ipAdapters.forEach((ipAdapter) => { + try { + validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model'); + validatedIPAdapters.push(ipAdapter); + } catch { + // This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedIPAdapters)); +}; + +export const validators = { + refinerModel: validateRefinerModel, + vaeModel: validateVAEModel, + lora: validateLoRA, + loras: validateLoRAs, + controlNet: validateControlNet, + controlNets: validateControlNets, + t2iAdapter: validateT2IAdapter, + t2iAdapters: validateT2IAdapters, + ipAdapter: validateIPAdapter, + ipAdapters: validateIPAdapters, +} as const; diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index b195ce4434..501286c785 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -1,5 +1,8 @@ import { z } from 'zod'; +import type { ModelIdentifier as ModelIdentifierV2 } from './v2/common'; +import { zModelIdentifier as zModelIdentifierV2 } from './v2/common'; + // #region Field data schemas export const zImageField = z.object({ image_name: z.string().trim().min(1), @@ -69,6 +72,8 @@ export const zModelIdentifier = z.object({ }); export const isModelIdentifier = (field: unknown): field is ModelIdentifier => zModelIdentifier.safeParse(field).success; +export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 => + zModelIdentifierV2.safeParse(field).success; export const zModelFieldBase = zModelIdentifier; export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts index 493a0464b3..2789abaaca 100644 --- a/invokeai/frontend/web/src/features/nodes/types/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts @@ -1,29 +1,22 @@ import { z } from 'zod'; -import { - zControlField, - zIPAdapterField, - zMainModelField, - zModelFieldBase, - zSDXLRefinerModelField, - zT2IAdapterField, - zVAEModelField, -} from './common'; +import { zControlField, zIPAdapterField, zT2IAdapterField } from './common'; +export const zLoRAWeight = z.number().nullish(); // #region Metadata-optimized versions of schemas // TODO: It's possible that `deepPartial` will be deprecated: // - https://github.com/colinhacks/zod/issues/2106 // - https://github.com/colinhacks/zod/issues/2854 export const zLoRAMetadataItem = z.object({ - lora: zModelFieldBase.deepPartial(), - weight: z.number(), + lora: z.unknown(), + weight: zLoRAWeight, }); -const zControlNetMetadataItem = zControlField.deepPartial(); -const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); -const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); -const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); -const zModelMetadataItem = zMainModelField.deepPartial(); -const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +const zControlNetMetadataItem = zControlField.merge(z.object({ control_model: z.unknown() })).deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.merge(z.object({ ip_adapter_model: z.unknown() })).deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.merge(z.object({ t2i_adapter_model: z.unknown() })).deepPartial(); +const zSDXLRefinerModelMetadataItem = z.unknown(); +const zModelMetadataItem = z.unknown(); +const zVAEModelMetadataItem = z.unknown(); export type LoRAMetadataItem = z.infer; export type ControlNetMetadataItem = z.infer; export type IPAdapterMetadataItem = z.infer; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 0929fc1dc3..8f9a76feca 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,17 +1,25 @@ -import { useAppToaster } from 'app/components/Toaster'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { getStore } from 'app/store/nanostores/store'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { toast } from 'common/util/toast'; +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { + initialControlNet, + initialIPAdapter, + initialT2IAdapter, +} from 'features/controlAdapters/util/buildControlAdapter'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; -import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; -import { isModelIdentifier } from 'features/nodes/types/common'; -import type { - ControlNetMetadataItem, - CoreMetadata, - IPAdapterMetadataItem, - LoRAMetadataItem, - T2IAdapterMetadataItem, -} from 'features/nodes/types/metadata'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import { defaultLoRAConfig, loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { + zControlField, + zIPAdapterField, + zModelIdentifierWithBase, + zT2IAdapterField, +} from 'features/nodes/types/common'; import { initialImageSelected, modelSelected } from 'features/parameters/store/actions'; import { heightRecalled, @@ -27,19 +35,18 @@ import { vaeSelected, widthRecalled, } from 'features/parameters/store/generationSlice'; -import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { isParameterCFGRescaleMultiplier, isParameterCFGScale, isParameterHeight, isParameterHRFEnabled, isParameterHRFMethod, + isParameterLoRAWeight, isParameterNegativePrompt, isParameterNegativeStylePromptSDXL, isParameterPositivePrompt, isParameterPositiveStylePromptSDXL, isParameterScheduler, - isParameterSDXLRefinerModel, isParameterSDXLRefinerNegativeAestheticScore, isParameterSDXLRefinerPositiveAestheticScore, isParameterSDXLRefinerStart, @@ -49,13 +56,16 @@ import { isParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { - prepareControlNetMetadataItem, - prepareIPAdapterMetadataItem, - prepareLoRAMetadataItem, - prepareMainModelMetadataItem, - prepareT2IAdapterMetadataItem, - prepareVAEMetadataItem, -} from 'features/parameters/util/modelMetadataHelpers'; + fetchControlNetModel, + fetchIPAdapterModel, + fetchLoRAModel, + fetchMainModelConfig, + fetchRefinerModelConfig, + fetchT2IAdapterModel, + fetchVAEModelConfig, + getModelKey, + raiseIfBaseIncompatible, +} from 'features/parameters/util/modelFetchingHelpers'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -67,229 +77,663 @@ import { setRefinerStart, setRefinerSteps, } from 'features/sdxl/store/sdxlSlice'; -import { isNil } from 'lodash-es'; +import { t } from 'i18next'; +import { get, isArray, isNil } from 'lodash-es'; import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import type { ImageDTO } from 'services/api/types'; +import type { BaseModelType, ImageDTO } from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); +/** + * A function that recalls from metadata from the full metadata object. + */ +type MetadataRecallFunc = (metadata: unknown, withToast?: boolean) => void; + +/** + * A function that recalls metadata from a specific metadata item. + */ +type MetadataItemRecallFunc = (metadataItem: unknown, withToast?: boolean) => void; + +/** + * Raised when metadata recall fails. + */ +export class MetadataRecallError extends Error { + /** + * Create MetadataRecallError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +export class InvalidMetadataPropertyType extends MetadataRecallError {} + +const getProperty = ( + obj: unknown, + property: string, + typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true +): T => { + const val = get(obj, property) as unknown; + if (typeGuard(val)) { + return val; + } + throw new InvalidMetadataPropertyType(`Property ${property} is not of expected type`); +}; + +const getCurrentBase = () => selectModel(getStore().getState())?.base; + +const parameterSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterSet', { parameter }), + description, + status: 'info', + duration: 2500, + isClosable: true, + }); +}; + +const parameterNotSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterNotSet', { parameter }), + description, + status: 'warning', + duration: 2500, + isClosable: true, + }); +}; + +const allParameterSetToast = (description?: string) => { + toast({ + title: t('toast.parametersSet'), + status: 'info', + description, + duration: 2500, + isClosable: true, + }); +}; + +const allParameterNotSetToast = (description?: string) => { + toast({ + title: t('toast.parametersNotSet'), + status: 'warning', + description, + duration: 2500, + isClosable: true, + }); +}; + +const recall = (callback: () => void, parameter: string, withToast = true) => { + try { + callback(); + withToast && parameterSetToast(parameter); + } catch (e) { + withToast && parameterNotSetToast(parameter, (e as Error).message); + } +}; + +const recallAsync = async (callback: () => Promise, parameter: string, withToast = true) => { + try { + await callback(); + withToast && parameterSetToast(parameter); + } catch (e) { + withToast && parameterNotSetToast(parameter, (e as Error).message); + } +}; + +export const recallPositivePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const positive_prompt = getProperty(metadata, 'positive_prompt', isParameterPositivePrompt); + getStore().dispatch(setPositivePrompt(positive_prompt)); + }, + t('metadata.positivePrompt'), + withToast + ); +}; + +export const recallNegativePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const negative_prompt = getProperty(metadata, 'negative_prompt', isParameterNegativePrompt); + getStore().dispatch(setNegativePrompt(negative_prompt)); + }, + t('metadata.negativePrompt'), + withToast + ); +}; + +export const recallSDXLPositiveStylePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const positive_style_prompt = getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL); + getStore().dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); + }, + t('sdxl.posStylePrompt'), + withToast + ); +}; + +export const recallSDXLNegativeStylePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const negative_style_prompt = getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL); + getStore().dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); + }, + t('sdxl.negStylePrompt'), + withToast + ); +}; + +export const recallSeed: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const seed = getProperty(metadata, 'seed', isParameterSeed); + getStore().dispatch(setSeed(seed)); + }, + t('metadata.seed'), + withToast + ); +}; + +export const recallCFGScale: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const cfg_scale = getProperty(metadata, 'cfg_scale', isParameterCFGScale); + getStore().dispatch(setCfgScale(cfg_scale)); + }, + t('metadata.cfgScale'), + withToast + ); +}; + +export const recallCFGRescaleMultiplier: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const cfg_rescale_multiplier = getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier); + getStore().dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); + }, + t('metadata.cfgRescaleMultiplier'), + withToast + ); +}; + +export const recallScheduler: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const scheduler = getProperty(metadata, 'scheduler', isParameterScheduler); + getStore().dispatch(setScheduler(scheduler)); + }, + t('metadata.scheduler'), + withToast + ); +}; + +export const recallWidth: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const width = getProperty(metadata, 'width', isParameterWidth); + getStore().dispatch(widthRecalled(width)); + }, + t('metadata.width'), + withToast + ); +}; + +export const recallHeight: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const height = getProperty(metadata, 'height', isParameterHeight); + getStore().dispatch(heightRecalled(height)); + }, + t('metadata.height'), + withToast + ); +}; + +export const recallSteps: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const steps = getProperty(metadata, 'steps', isParameterSteps); + getStore().dispatch(setSteps(steps)); + }, + t('metadata.steps'), + withToast + ); +}; + +export const recallStrength: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const strength = getProperty(metadata, 'strength', isParameterStrength); + getStore().dispatch(setImg2imgStrength(strength)); + }, + t('metadata.strength'), + withToast + ); +}; + +export const recallHRFEnabled: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const hrf_enabled = getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled); + getStore().dispatch(setHrfEnabled(hrf_enabled)); + }, + t('hrf.metadata.enabled'), + withToast + ); +}; + +export const recallHRFStrength: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const hrf_strength = getProperty(metadata, 'hrf_strength', isParameterStrength); + getStore().dispatch(setHrfStrength(hrf_strength)); + }, + t('hrf.metadata.strength'), + withToast + ); +}; + +export const recallHRFMethod: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const hrf_method = getProperty(metadata, 'hrf_method', isParameterHRFMethod); + getStore().dispatch(setHrfMethod(hrf_method)); + }, + t('hrf.metadata.method'), + withToast + ); +}; + +export const recallRefinerSteps: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_steps = getProperty(metadata, 'refiner_steps', isParameterSteps); + getStore().dispatch(setRefinerSteps(refiner_steps)); + }, + t('sdxl.steps'), + withToast + ); +}; + +export const recallRefinerCFGScale: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_cfg_scale = getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale); + getStore().dispatch(setRefinerCFGScale(refiner_cfg_scale)); + }, + t('sdxl.cfgScale'), + withToast + ); +}; + +export const recallRefinerScheduler: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_scheduler = getProperty(metadata, 'refiner_scheduler', isParameterScheduler); + getStore().dispatch(setRefinerScheduler(refiner_scheduler)); + }, + t('sdxl.cfgScale'), + withToast + ); +}; + +export const recallRefinerPositiveAestheticScore: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_positive_aesthetic_score = getProperty( + metadata, + 'refiner_positive_aesthetic_score', + isParameterSDXLRefinerPositiveAestheticScore + ); + getStore().dispatch(setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)); + }, + t('sdxl.posAestheticScore'), + withToast + ); +}; + +export const recallRefinerNegativeAestheticScore: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_negative_aesthetic_score = getProperty( + metadata, + 'refiner_negative_aesthetic_score', + isParameterSDXLRefinerNegativeAestheticScore + ); + getStore().dispatch(setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)); + }, + t('sdxl.negAestheticScore'), + withToast + ); +}; + +export const recallRefinerStart: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_start = getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart); + getStore().dispatch(setRefinerStart(refiner_start)); + }, + t('sdxl.refinerStart'), + withToast + ); +}; + +export const prepareMainModelMetadataItem = async (model: unknown): Promise => { + const key = await getModelKey(model, 'main'); + const mainModel = await fetchMainModelConfig(key); + return zModelIdentifierWithBase.parse(mainModel); +}; + +const recallModelAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => { + await recallAsync( + async () => { + const modelMetadataItem = getProperty(metadata, 'model'); + const model = await prepareMainModelMetadataItem(modelMetadataItem); + getStore().dispatch(modelSelected(model)); + }, + t('metadata.model'), + withToast + ); +}; + +export const prepareRefinerMetadataItem = async ( + model: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const key = await getModelKey(model, 'main'); + const refinerModel = await fetchRefinerModelConfig(key); + raiseIfBaseIncompatible('sdxl-refiner', currentBase, 'Refiner incompatible with currently-selected model'); + return zModelIdentifierWithBase.parse(refinerModel); +}; + +const recallRefinerModelAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => { + await recallAsync( + async () => { + const refinerMetadataItem = getProperty(metadata, 'refiner_model'); + const refiner = await prepareRefinerMetadataItem(refinerMetadataItem, getCurrentBase()); + getStore().dispatch(refinerModelChanged(refiner)); + }, + t('sdxl.refinerModel'), + withToast + ); +}; + +export const prepareVAEMetadataItem = async ( + vae: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const key = await getModelKey(vae, 'vae'); + const vaeModel = await fetchVAEModelConfig(key); + raiseIfBaseIncompatible(vaeModel.base, currentBase, 'VAE incompatible with currently-selected model'); + return zModelIdentifierWithBase.parse(vaeModel); +}; + +const recallVAEAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const vaeMetadataItem = getProperty(metadata, 'vae'); + if (isNil(vaeMetadataItem)) { + getStore().dispatch(vaeSelected(null)); + } else { + const vae = await prepareVAEMetadataItem(vaeMetadataItem, currentBase); + getStore().dispatch(vaeSelected(vae)); + } + }, + t('metadata.vae'), + withToast + ); +}; + +export const prepareLoRAMetadataItem = async ( + loraMetadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const lora = getProperty(loraMetadataItem, 'lora'); + const weight = getProperty(loraMetadataItem, 'weight'); + const key = await getModelKey(lora, 'lora'); + const loraModel = await fetchLoRAModel(key); + raiseIfBaseIncompatible(loraModel.base, currentBase, 'LoRA incompatible with currently-selected model'); + return { + key: loraModel.key, + base: loraModel.base, + weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, + isEnabled: true, + }; +}; + +const recallLoRAAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const lora = await prepareLoRAMetadataItem(metadataItem, currentBase); + getStore().dispatch(loraRecalled(lora)); + }, + t('models.lora'), + withToast + ); +}; + +export const prepareControlNetMetadataItem = async ( + metadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const control_model = getProperty(metadataItem, 'control_model'); + const key = await getModelKey(control_model, 'controlnet'); + const controlNetModel = await fetchControlNetModel(key); + raiseIfBaseIncompatible(controlNetModel.base, currentBase, 'ControlNet incompatible with currently-selected model'); + + const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const control_weight = zControlField.shape.control_weight + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_weight')); + const begin_step_percent = zControlField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zControlField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const control_mode = zControlField.shape.control_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_mode')); + const resize_mode = zControlField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const controlnet: ControlNetConfig = { + type: 'controlnet', + isEnabled: true, + model: zModelIdentifierWithBase.parse(controlNetModel), + weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, + beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct, + endStepPct: end_step_percent ?? initialControlNet.endStepPct, + controlMode: control_mode ?? initialControlNet.controlMode, + resizeMode: resize_mode ?? initialControlNet.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return controlnet; +}; + +const recallControlNetAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const controlNetConfig = await prepareControlNetMetadataItem(metadataItem, currentBase); + getStore().dispatch(controlAdapterRecalled(controlNetConfig)); + }, + t('common.controlNet'), + withToast + ); +}; + +export const prepareT2IAdapterMetadataItem = async ( + metadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const t2i_adapter_model = getProperty(metadataItem, 't2i_adapter_model'); + const key = await getModelKey(t2i_adapter_model, 't2i_adapter'); + const t2iAdapterModel = await fetchT2IAdapterModel(key); + raiseIfBaseIncompatible(t2iAdapterModel.base, currentBase, 'T2I Adapter incompatible with currently-selected model'); + + const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zT2IAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zT2IAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const resize_mode = zT2IAdapterField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const t2iAdapter: T2IAdapterConfig = { + type: 't2i_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(t2iAdapterModel), + weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, + beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, + resizeMode: resize_mode ?? initialT2IAdapter.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return t2iAdapter; +}; + +const recallT2IAdapterAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(metadataItem, currentBase); + getStore().dispatch(controlAdapterRecalled(t2iAdapterConfig)); + }, + t('common.t2iAdapter'), + withToast + ); +}; + +export const prepareIPAdapterMetadataItem = async ( + metadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const ip_adapter_model = getProperty(metadataItem, 'ip_adapter_model'); + const key = await getModelKey(ip_adapter_model, 'ip_adapter'); + const ipAdapterModel = await fetchIPAdapterModel(key); + raiseIfBaseIncompatible(ipAdapterModel.base, currentBase, 'T2I Adapter incompatible with currently-selected model'); + + const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zIPAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zIPAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + + const ipAdapter: IPAdapterConfig = { + id: uuidv4(), + type: 'ip_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(ipAdapterModel), + controlImage: image?.image_name ?? null, + weight: weight ?? initialIPAdapter.weight, + beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, + }; + + return ipAdapter; +}; + +const recallIPAdapterAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const ipAdapterConfig = await prepareIPAdapterMetadataItem(metadataItem, currentBase); + getStore().dispatch(controlAdapterRecalled(ipAdapterConfig)); + }, + t('common.ipAdapter'), + withToast + ); +}; + export const useRecallParameters = () => { const dispatch = useAppDispatch(); - const toaster = useAppToaster(); - const { t } = useTranslation(); - const model = useAppSelector(selectModel); - const parameterSetToast = useCallback(() => { - toaster({ - title: t('toast.parameterSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); - - const parameterNotSetToast = useCallback( - (description?: string) => { - toaster({ - title: t('toast.parameterNotSet'), - description, - status: 'warning', - duration: 2500, - isClosable: true, - }); - }, - [t, toaster] - ); - - const allParameterSetToast = useCallback(() => { - toaster({ - title: t('toast.parametersSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); - - const allParameterNotSetToast = useCallback( - (description?: string) => { - toaster({ - title: t('toast.parametersNotSet'), - status: 'warning', - description, - duration: 2500, - isClosable: true, - }); - }, - [t, toaster] - ); - - const recallBothPrompts = useCallback( - (positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => { + const recallBothPrompts = useCallback( + (metadata: unknown) => { + const positive_prompt = getProperty(metadata, 'positive_prompt'); + const negative_prompt = getProperty(metadata, 'negative_prompt'); + const positive_style_prompt = getProperty(metadata, 'positive_style_prompt'); + const negative_style_prompt = getProperty(metadata, 'negative_style_prompt'); if ( - isParameterPositivePrompt(positivePrompt) || - isParameterNegativePrompt(negativePrompt) || - isParameterPositiveStylePromptSDXL(positiveStylePrompt) || - isParameterNegativeStylePromptSDXL(negativeStylePrompt) + isParameterPositivePrompt(positive_prompt) || + isParameterNegativePrompt(negative_prompt) || + isParameterPositiveStylePromptSDXL(positive_style_prompt) || + isParameterNegativeStylePromptSDXL(negative_style_prompt) ) { - if (isParameterPositivePrompt(positivePrompt)) { - dispatch(setPositivePrompt(positivePrompt)); + if (isParameterPositivePrompt(positive_prompt)) { + dispatch(setPositivePrompt(positive_prompt)); } - if (isParameterNegativePrompt(negativePrompt)) { - dispatch(setNegativePrompt(negativePrompt)); + if (isParameterNegativePrompt(negative_prompt)) { + dispatch(setNegativePrompt(negative_prompt)); } - if (isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { - dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); + if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) { + dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); } - if (isParameterPositiveStylePromptSDXL(negativeStylePrompt)) { - dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); + if (isParameterPositiveStylePromptSDXL(negative_style_prompt)) { + dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); } - parameterSetToast(); + parameterSetToast(t('metadata.allPrompts')); return; } - parameterNotSetToast(); + parameterNotSetToast(t('metadata.allPrompts')); }, - [dispatch, parameterSetToast, parameterNotSetToast] + [dispatch] ); - const recallPositivePrompt = useCallback( - (positivePrompt: unknown) => { - if (!isParameterPositivePrompt(positivePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setPositivePrompt(positivePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); + const recallWidthAndHeight = useCallback( + (metadata: unknown) => { + const width = getProperty(metadata, 'width'); + const height = getProperty(metadata, 'height'); - const recallNegativePrompt = useCallback( - (negativePrompt: unknown) => { - if (!isParameterNegativePrompt(negativePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setNegativePrompt(negativePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSDXLPositiveStylePrompt = useCallback( - (positiveStylePrompt: unknown) => { - if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSDXLNegativeStylePrompt = useCallback( - (negativeStylePrompt: unknown) => { - if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSeed = useCallback( - (seed: unknown) => { - if (!isParameterSeed(seed)) { - parameterNotSetToast(); - return; - } - dispatch(setSeed(seed)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallCfgScale = useCallback( - (cfgScale: unknown) => { - if (!isParameterCFGScale(cfgScale)) { - parameterNotSetToast(); - return; - } - dispatch(setCfgScale(cfgScale)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallCfgRescaleMultiplier = useCallback( - (cfgRescaleMultiplier: unknown) => { - if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) { - parameterNotSetToast(); - return; - } - dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallScheduler = useCallback( - (scheduler: unknown) => { - if (!isParameterScheduler(scheduler)) { - parameterNotSetToast(); - return; - } - dispatch(setScheduler(scheduler)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSteps = useCallback( - (steps: unknown) => { - if (!isParameterSteps(steps)) { - parameterNotSetToast(); - return; - } - dispatch(setSteps(steps)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallWidth = useCallback( - (width: unknown) => { - if (!isParameterWidth(width)) { - parameterNotSetToast(); - return; - } - dispatch(widthRecalled(width)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHeight = useCallback( - (height: unknown) => { - if (!isParameterHeight(height)) { - parameterNotSetToast(); - return; - } - dispatch(heightRecalled(height)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallWidthAndHeight = useCallback( - (width: unknown, height: unknown) => { if (!isParameterWidth(width)) { allParameterNotSetToast(); return; @@ -302,144 +746,7 @@ export const useRecallParameters = () => { dispatch(widthRecalled(width)); allParameterSetToast(); }, - [dispatch, allParameterSetToast, allParameterNotSetToast] - ); - - const recallStrength = useCallback( - (strength: unknown) => { - if (!isParameterStrength(strength)) { - parameterNotSetToast(); - return; - } - dispatch(setImg2imgStrength(strength)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHrfEnabled = useCallback( - (hrfEnabled: unknown) => { - if (!isParameterHRFEnabled(hrfEnabled)) { - parameterNotSetToast(); - return; - } - dispatch(setHrfEnabled(hrfEnabled)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHrfStrength = useCallback( - (hrfStrength: unknown) => { - if (!isParameterStrength(hrfStrength)) { - parameterNotSetToast(); - return; - } - dispatch(setHrfStrength(hrfStrength)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHrfMethod = useCallback( - (hrfMethod: unknown) => { - if (!isParameterHRFMethod(hrfMethod)) { - parameterNotSetToast(); - return; - } - dispatch(setHrfMethod(hrfMethod)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallModel = useCallback( - async (modelMetadataItem: unknown) => { - try { - const model = await prepareMainModelMetadataItem(modelMetadataItem); - dispatch(modelSelected(model)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallVaeModel = useCallback( - async (vaeMetadataItem: unknown) => { - if (isNil(vaeMetadataItem)) { - dispatch(vaeSelected(null)); - parameterSetToast(); - return; - } - try { - const vae = await prepareVAEMetadataItem(vaeMetadataItem); - dispatch(vaeSelected(vae)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallLoRA = useCallback( - async (loraMetadataItem: LoRAMetadataItem) => { - try { - const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base); - dispatch(loraRecalled(lora)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallControlNet = useCallback( - async (controlnetMetadataItem: ControlNetMetadataItem) => { - try { - const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base); - dispatch(controlAdapterRecalled(controlNetConfig)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallT2IAdapter = useCallback( - async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { - try { - const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base); - dispatch(controlAdapterRecalled(t2iAdapterConfig)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallIPAdapter = useCallback( - async (ipAdapterMetadataItem: IPAdapterMetadataItem) => { - try { - const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base); - dispatch(controlAdapterRecalled(ipAdapterConfig)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] + [dispatch] ); const sendToImageToImage = useCallback( @@ -449,197 +756,75 @@ export const useRecallParameters = () => { [dispatch] ); - const recallAllParameters = useCallback( - async (metadata: CoreMetadata | undefined) => { + const recallAllParameters = useCallback( + async (metadata: unknown) => { if (!metadata) { allParameterNotSetToast(); return; } - const { - cfg_scale, - cfg_rescale_multiplier, - height, - model, - positive_prompt, - negative_prompt, - scheduler, - vae, - seed, - steps, - width, - strength, - hrf_enabled, - hrf_strength, - hrf_method, - positive_style_prompt, - negative_style_prompt, - refiner_model, - refiner_cfg_scale, - refiner_steps, - refiner_scheduler, - refiner_positive_aesthetic_score, - refiner_negative_aesthetic_score, - refiner_start, - loras, - controlnets, - ipAdapters, - t2iAdapters, - } = metadata; + await recallModelAsync(metadata, false); - let newModel: ParameterModel | undefined = undefined; + recallCFGScale(metadata, false); + recallCFGRescaleMultiplier(metadata, false); + recallPositivePrompt(metadata, false); + recallNegativePrompt(metadata, false); + recallScheduler(metadata, false); + recallSeed(metadata, false); + recallSteps(metadata, false); + recallWidth(metadata, false); + recallHeight(metadata, false); + recallStrength(metadata, false); + recallHRFEnabled(metadata, false); + recallHRFMethod(metadata, false); + recallHRFStrength(metadata, false); - if (isModelIdentifier(model)) { - try { - const _model = await prepareMainModelMetadataItem(model); - dispatch(modelSelected(_model)); - newModel = _model; - } catch { - return; - } - } + // SDXL + recallSDXLPositiveStylePrompt(metadata, false); + recallSDXLNegativeStylePrompt(metadata, false); + recallRefinerSteps(metadata, false); + recallRefinerCFGScale(metadata, false); + recallRefinerScheduler(metadata, false); + recallRefinerPositiveAestheticScore(metadata, false); + recallRefinerNegativeAestheticScore(metadata, false); + recallRefinerStart(metadata, false); - if (isParameterCFGScale(cfg_scale)) { - dispatch(setCfgScale(cfg_scale)); - } - - if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) { - dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); - } - - if (isParameterPositivePrompt(positive_prompt)) { - dispatch(setPositivePrompt(positive_prompt)); - } - - if (isParameterNegativePrompt(negative_prompt)) { - dispatch(setNegativePrompt(negative_prompt)); - } - - if (isParameterScheduler(scheduler)) { - dispatch(setScheduler(scheduler)); - } - if (isModelIdentifier(vae) || isNil(vae)) { - if (isNil(vae)) { - dispatch(vaeSelected(null)); - } else { - try { - const _vae = await prepareVAEMetadataItem(vae, newModel?.base); - dispatch(vaeSelected(_vae)); - } catch { - return; - } - } - } - - if (isParameterSeed(seed)) { - dispatch(setSeed(seed)); - } - - if (isParameterSteps(steps)) { - dispatch(setSteps(steps)); - } - - if (isParameterWidth(width)) { - dispatch(widthRecalled(width)); - } - - if (isParameterHeight(height)) { - dispatch(heightRecalled(height)); - } - - if (isParameterStrength(strength)) { - dispatch(setImg2imgStrength(strength)); - } - - if (isParameterHRFEnabled(hrf_enabled)) { - dispatch(setHrfEnabled(hrf_enabled)); - } - - if (isParameterStrength(hrf_strength)) { - dispatch(setHrfStrength(hrf_strength)); - } - - if (isParameterHRFMethod(hrf_method)) { - dispatch(setHrfMethod(hrf_method)); - } - - if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) { - dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); - } - - if (isParameterNegativeStylePromptSDXL(negative_style_prompt)) { - dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); - } - - if (isParameterSDXLRefinerModel(refiner_model)) { - dispatch(refinerModelChanged(refiner_model)); - } - - if (isParameterSteps(refiner_steps)) { - dispatch(setRefinerSteps(refiner_steps)); - } - - if (isParameterCFGScale(refiner_cfg_scale)) { - dispatch(setRefinerCFGScale(refiner_cfg_scale)); - } - - if (isParameterScheduler(refiner_scheduler)) { - dispatch(setRefinerScheduler(refiner_scheduler)); - } - - if (isParameterSDXLRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)) { - dispatch(setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)); - } - - if (isParameterSDXLRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)) { - dispatch(setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)); - } - - if (isParameterSDXLRefinerStart(refiner_start)) { - dispatch(setRefinerStart(refiner_start)); - } + await recallVAEAsync(metadata, false); + await recallRefinerModelAsync(metadata, false); dispatch(lorasCleared()); - loras?.forEach(async (loraMetadataItem) => { - try { - const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base); - dispatch(loraRecalled(lora)); - } catch { - return; - } - }); + const loraMetadataArray = getProperty(metadata, 'loras'); + if (isArray(loraMetadataArray)) { + loraMetadataArray.forEach(async (loraMetadataItem) => { + await recallLoRAAsync(loraMetadataItem, false); + }); + } dispatch(controlAdaptersReset()); - controlnets?.forEach(async (controlNetMetadataItem) => { - try { - const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base); - dispatch(controlAdapterRecalled(controlNet)); - } catch { - return; - } - }); + const controlNetMetadataArray = getProperty(metadata, 'controlnets'); + if (isArray(controlNetMetadataArray)) { + controlNetMetadataArray.forEach(async (controlNetMetadataItem) => { + await recallControlNetAsync(controlNetMetadataItem, false); + }); + } - ipAdapters?.forEach(async (ipAdapterMetadataItem) => { - try { - const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base); - dispatch(controlAdapterRecalled(ipAdapter)); - } catch { - return; - } - }); + const ipAdapterMetadataArray = getProperty(metadata, 'ipAdapters'); + if (isArray(ipAdapterMetadataArray)) { + ipAdapterMetadataArray.forEach(async (ipAdapterMetadataItem) => { + await recallIPAdapterAsync(ipAdapterMetadataItem, false); + }); + } - t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => { - try { - const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base); - dispatch(controlAdapterRecalled(t2iAdapter)); - } catch { - return; - } - }); + const t2iAdapterMetadataArray = getProperty(metadata, 't2iAdapters'); + if (isArray(t2iAdapterMetadataArray)) { + t2iAdapterMetadataArray.forEach(async (t2iAdapterMetadataItem) => { + await recallT2IAdapterAsync(t2iAdapterMetadataItem, false); + }); + } allParameterSetToast(); }, - [dispatch, allParameterSetToast, allParameterNotSetToast] + [dispatch] ); return { @@ -649,24 +834,25 @@ export const useRecallParameters = () => { recallSDXLPositiveStylePrompt, recallSDXLNegativeStylePrompt, recallSeed, - recallCfgScale, - recallCfgRescaleMultiplier, - recallModel, + recallCFGScale, + recallCFGRescaleMultiplier, + recallModel: recallModelAsync, recallScheduler, - recallVaeModel, + recallVaeModel: recallVAEAsync, recallSteps, recallWidth, recallHeight, recallWidthAndHeight, recallStrength, - recallHrfEnabled, - recallHrfStrength, - recallHrfMethod, - recallLoRA, - recallControlNet, - recallIPAdapter, - recallT2IAdapter, + recallHRFEnabled, + recallHRFStrength, + recallHRFMethod, + recallLoRA: recallLoRAAsync, + recallControlNet: recallControlNetAsync, + recallIPAdapter: recallIPAdapterAsync, + recallT2IAdapter: recallT2IAdapterAsync, recallAllParameters, + recallRefinerModel: recallRefinerModelAsync, sendToImageToImage, }; }; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 20953ec266..8d46add8c8 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,14 +1,7 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { - zBaseModel, - zControlNetModelField, - zIPAdapterModelField, - zLoRAModelField, zModelIdentifierWithBase, zSchedulerField, - zSDXLRefinerModelField, - zT2IAdapterModelField, - zVAEModelField, } from 'features/nodes/types/common'; import { z } from 'zod'; @@ -111,42 +104,42 @@ export const isParameterModel = (val: unknown): val is ParameterModel => zParame // #endregion // #region SDXL Refiner Model -export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel }); +export const zParameterSDXLRefinerModel = zModelIdentifierWithBase; export type ParameterSDXLRefinerModel = z.infer; export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel => zParameterSDXLRefinerModel.safeParse(val).success; // #endregion // #region VAE Model -export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel }); +export const zParameterVAEModel = zModelIdentifierWithBase; export type ParameterVAEModel = z.infer; export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => zParameterVAEModel.safeParse(val).success; // #endregion // #region LoRA Model -export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel }); +export const zParameterLoRAModel = zModelIdentifierWithBase; export type ParameterLoRAModel = z.infer; export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => zParameterLoRAModel.safeParse(val).success; // #endregion // #region ControlNet Model -export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel }); +export const zParameterControlNetModel = zModelIdentifierWithBase; export type ParameterControlNetModel = z.infer; export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel => zParameterControlNetModel.safeParse(val).success; // #endregion // #region IP Adapter Model -export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel }); +export const zParameterIPAdapterModel = zModelIdentifierWithBase; export type ParameterIPAdapterModel = z.infer; export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel => zParameterIPAdapterModel.safeParse(val).success; // #endregion // #region T2I Adapter Model -export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel }); +export const zParameterT2IAdapterModel = zModelIdentifierWithBase; export type ParameterT2IAdapterModel = z.infer; export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel => zParameterT2IAdapterModel.safeParse(val).success; @@ -218,3 +211,9 @@ export type ParameterCanvasCoherenceMode = z.infer zParameterCanvasCoherenceMode.safeParse(val).success; // #endregion + +// #region LoRA weight +export const zLoRAWeight = z.number(); +export type ParameterLoRAWeight = z.infer; +export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts index c7d25fed8b..bd3e42a314 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts @@ -1,7 +1,8 @@ import { getStore } from 'app/store/nanostores/store'; -import { isModelIdentifier } from 'features/nodes/types/common'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { modelsApi } from 'services/api/endpoints/models'; -import type { AnyModelConfig, BaseModelType } from 'services/api/types'; +import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; import { isControlNetModelConfig, isIPAdapterModelConfig, @@ -41,6 +42,12 @@ export class InvalidModelConfigError extends Error { } } +/** + * Fetches the model config for a given model key. + * @param key The model key. + * @returns A promise that resolves to the model config. + * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched. + */ export const fetchModelConfig = async (key: string): Promise => { const { dispatch } = getStore(); try { @@ -52,6 +59,37 @@ export const fetchModelConfig = async (key: string): Promise => } }; +/** + * Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility + * for MM1 model identifiers. + * @param name The model name. + * @param base The base model. + * @param type The model type. + * @returns A promise that resolves to the model config. + * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched. + */ +export const fetchModelConfigByAttrs = async ( + name: string, + base: BaseModelType, + type: ModelType +): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type })); + req.unsubscribe(); + return await req.unwrap(); + } catch { + throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`); + } +}; + +/** + * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type. + * @param key The model key. + * @param typeGuard A type guard function that checks if the model config is of the expected type. + * @returns A promise that resolves to the model config. The model config is guaranteed to be of the expected type. + * @throws {InvalidModelConfigError} If the model config is unable to be fetched or is of an unexpected type. + */ export const fetchModelConfigWithTypeGuard = async ( key: string, typeGuard: (config: AnyModelConfig) => config is T @@ -63,15 +101,17 @@ export const fetchModelConfigWithTypeGuard = async ( return modelConfig; }; -export const fetchMainModel = async (key: string) => { +// TODO(psyche): Remove these helpers once `useRecallParameters` is removed + +export const fetchMainModelConfig = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); }; -export const fetchRefinerModel = async (key: string) => { +export const fetchRefinerModelConfig = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); }; -export const fetchVAEModel = async (key: string) => { +export const fetchVAEModelConfig = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isVAEModelConfig); }; @@ -95,19 +135,39 @@ export const fetchTextualInversionModel = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig); }; -export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => { - return sourceBase === targetBase; -}; - +/** + * Raises an error if the source base model is incompatible with the target base model. + * @param sourceBase The source base model. + * @param targetBase The target base model. + * @param message An optional custom message to include in the error. + * @throws {InvalidModelConfigError} If the source base model is incompatible with the target base model. + */ export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => { - if (targetBase && !isBaseCompatible(sourceBase, targetBase)) { + if (targetBase && sourceBase !== targetBase) { throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`); } }; -export const getModelKey = (modelIdentifier: unknown, message?: string): string => { - if (!isModelIdentifier(modelIdentifier)) { - throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); +/** + * Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers. + * @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format + * `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key. + * @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers. + * @param message An optional custom message to include in the error if the model identifier is invalid. + * @returns A promise that resolves to the model key. + * @throws {InvalidModelConfigError} If the model identifier is invalid. + */ +export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise => { + if (isModelIdentifier(modelIdentifier)) { + return modelIdentifier.key; } - return modelIdentifier.key; + if (isModelIdentifierV2(modelIdentifier)) { + return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key; + } + throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); }; + +export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({ + key: modelConfig.key, + base: modelConfig.base, +}); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts deleted file mode 100644 index 722073366f..0000000000 --- a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts +++ /dev/null @@ -1,150 +0,0 @@ -import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; -import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; -import { - initialControlNet, - initialIPAdapter, - initialT2IAdapter, -} from 'features/controlAdapters/util/buildControlAdapter'; -import type { LoRA } from 'features/lora/store/loraSlice'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; -import { zModelIdentifierWithBase } from 'features/nodes/types/common'; -import type { - ControlNetMetadataItem, - IPAdapterMetadataItem, - LoRAMetadataItem, - T2IAdapterMetadataItem, -} from 'features/nodes/types/metadata'; -import { - fetchControlNetModel, - fetchIPAdapterModel, - fetchLoRAModel, - fetchMainModel, - fetchRefinerModel, - fetchT2IAdapterModel, - fetchVAEModel, - getModelKey, - raiseIfBaseIncompatible, -} from 'features/parameters/util/modelFetchingHelpers'; -import type { BaseModelType } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; - -export const prepareMainModelMetadataItem = async (model: unknown): Promise => { - const key = getModelKey(model); - const mainModel = await fetchMainModel(key); - return zModelIdentifierWithBase.parse(mainModel); -}; - -export const prepareRefinerMetadataItem = async (model: unknown): Promise => { - const key = getModelKey(model); - const refinerModel = await fetchRefinerModel(key); - return zModelIdentifierWithBase.parse(refinerModel); -}; - -export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise => { - const key = getModelKey(vae); - const vaeModel = await fetchVAEModel(key); - raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model'); - return zModelIdentifierWithBase.parse(vaeModel); -}; - -export const prepareLoRAMetadataItem = async ( - loraMetadataItem: LoRAMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(loraMetadataItem.lora); - const loraModel = await fetchLoRAModel(key); - raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model'); - return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true }; -}; - -export const prepareControlNetMetadataItem = async ( - controlnetMetadataItem: ControlNetMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(controlnetMetadataItem.control_model); - const controlNetModel = await fetchControlNetModel(key); - raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model'); - - const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } = - controlnetMetadataItem; - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const controlnet: ControlNetConfig = { - type: 'controlnet', - isEnabled: true, - model: zModelIdentifierWithBase.parse(controlNetModel), - weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, - beginStepPct: begin_step_percent || initialControlNet.beginStepPct, - endStepPct: end_step_percent || initialControlNet.endStepPct, - controlMode: control_mode || initialControlNet.controlMode, - resizeMode: resize_mode || initialControlNet.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return controlnet; -}; - -export const prepareT2IAdapterMetadataItem = async ( - t2iAdapterMetadataItem: T2IAdapterMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model); - const t2iAdapterModel = await fetchT2IAdapterModel(key); - raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); - - const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem; - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const t2iAdapter: T2IAdapterConfig = { - type: 't2i_adapter', - isEnabled: true, - model: zModelIdentifierWithBase.parse(t2iAdapterModel), - weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, - beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct, - endStepPct: end_step_percent || initialT2IAdapter.endStepPct, - resizeMode: resize_mode || initialT2IAdapter.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return t2iAdapter; -}; - -export const prepareIPAdapterMetadataItem = async ( - ipAdapterMetadataItem: IPAdapterMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model); - const ipAdapterModel = await fetchIPAdapterModel(key); - raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); - - const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; - - const ipAdapter: IPAdapterConfig = { - id: uuidv4(), - type: 'ip_adapter', - isEnabled: true, - controlImage: image?.image_name ?? null, - model: zModelIdentifierWithBase.parse(ipAdapterModel), - weight: weight ?? initialIPAdapter.weight, - beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, - endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, - }; - - return ipAdapter; -}; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index 8b10d9bddd..05ca73927c 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -25,7 +25,7 @@ const formLabelProps2: FormLabelProps = { export const AdvancedSettingsAccordion = memo(() => { const vaeKey = useAppSelector((state) => state.generation.vae?.key); - const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken); + const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken); const selectBadges = useMemo( () => createMemoizedSelector(selectGenerationSlice, (generation) => { diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 49eb28390f..1849e2fde4 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -1,10 +1,8 @@ import type { EntityState, Update } from '@reduxjs/toolkit'; import type { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks'; -import { logger } from 'app/logging/logger'; +import type { JSONObject } from 'common/types'; import type { BoardId } from 'features/gallery/store/types'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types'; -import type { CoreMetadata } from 'features/nodes/types/metadata'; -import { zCoreMetadata } from 'features/nodes/types/metadata'; import { addToast } from 'features/system/store/systemSlice'; import { t } from 'i18next'; import { keyBy } from 'lodash-es'; @@ -118,22 +116,9 @@ export const imagesApi = api.injectEndpoints({ providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }], keepUnusedDataFor: 86400, // 24 hours }), - getImageMetadata: build.query({ + getImageMetadata: build.query({ query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }), providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }], - transformResponse: ( - response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json'] - ) => { - if (response) { - const result = zCoreMetadata.safeParse(response); - if (result.success) { - return result.data; - } else { - logger('images').warn('Problem parsing metadata'); - } - } - return; - }, keepUnusedDataFor: 86400, // 24 hours }), getImageWorkflow: build.query< diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 6618fda58e..3ccccf62e1 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -70,6 +70,8 @@ export type ScanFolderResponse = paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; type ScanFolderArg = operations['scan_for_models']['parameters']['query']; +export type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query']; + export const mainModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), @@ -223,6 +225,18 @@ export const modelsApi = api.injectEndpoints({ return tags; }, }), + getModelConfigByAttrs: build.query({ + query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`), + providesTags: (result) => { + const tags: ApiTagDescription[] = ['Model']; + + if (result) { + tags.push({ type: 'ModelConfig', id: result.key }); + } + + return tags; + }, + }), syncModels: build.mutation({ query: () => { return { @@ -300,6 +314,7 @@ export const modelsApi = api.injectEndpoints({ }); export const { + useGetModelConfigByAttrsQuery, useGetModelConfigQuery, useGetMainModelsQuery, useGetControlNetModelsQuery,