From e174ce038fd9cd547f095a07cdfbebf161b21647 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 24 Feb 2024 01:25:58 +1100 Subject: [PATCH] feat(ui): refactor metadata handling (again) Add concepts for metadata handlers. Handlers include parsers, recallers and validators for different metadata types: - Parsers parse a raw metadata object of any shape to a structured object. - Recallers load the parsed metadata into state. Recallers are optional, as some metadata types don't need to be loaded into state. - Validators provide an additional layer of validation before recalling the metadata. This is needed because a metadata object may be valid, but not able to be recalled due to some other requirement, like base model compatibility. Validators are optional. Sometimes metadata is not a single object but a list of items - like LoRAs. Metadata handlers may implement an optional set of "item" handlers which operate on individual items in the list. Parsers and validators are async to allow fetching additional data, like a model config. Recallers are synchronous. The these handlers are composed into a public API, exported as a `handlers` object. Besides the handlers functions, a metadata handler set includes: - A function to get the label of the metadata type. - An optional function to render the value of the metadata type. - An optional function to render the _item_ value of the metadata type. --- invokeai/frontend/web/public/locales/en.json | 8 +- .../listeners/batchEnqueued.ts | 7 +- .../listeners/bulkDownload.tsx | 8 +- .../frontend/web/src/common/util/toast.ts | 6 + .../ImageMetadataActions.tsx | 322 +---- .../ImageMetadataViewer/ImageMetadataItem.tsx | 51 +- .../src/features/lora/components/LoRACard.tsx | 16 +- .../web/src/features/lora/store/loraSlice.ts | 12 +- .../components/MetadataControlNets.tsx | 66 + .../components/MetadataIPAdapters.tsx | 66 + .../metadata/components/MetadataItem.tsx | 33 + .../metadata/components/MetadataItemView.tsx | 29 + .../metadata/components/MetadataLoRAs.tsx | 61 + .../components/MetadataT2IAdapters.tsx | 66 + .../metadata/components/RecallButton.tsx | 27 + .../web/src/features/metadata/exceptions.ts | 27 + .../metadata/hooks/useMetadataItem.tsx | 51 + .../web/src/features/metadata/types.ts | 136 ++ .../src/features/metadata/util/handlers.tsx | 316 ++++ .../web/src/features/metadata/util/parsers.ts | 396 +++++ .../src/features/metadata/util/recallers.ts | 295 ++++ .../src/features/metadata/util/validators.ts | 117 ++ .../web/src/features/nodes/types/common.ts | 5 + .../web/src/features/nodes/types/metadata.ts | 27 +- .../parameters/hooks/useRecallParameters.ts | 1272 ++++++++++------- .../parameters/types/parameterSchemas.ts | 25 +- .../parameters/util/modelFetchingHelpers.ts | 88 +- .../parameters/util/modelMetadataHelpers.ts | 150 -- .../AdvancedSettingsAccordion.tsx | 2 +- .../web/src/services/api/endpoints/images.ts | 19 +- .../web/src/services/api/endpoints/models.ts | 15 + 31 files changed, 2627 insertions(+), 1092 deletions(-) create mode 100644 invokeai/frontend/web/src/common/util/toast.ts create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/exceptions.ts create mode 100644 invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/types.ts create mode 100644 invokeai/frontend/web/src/features/metadata/util/handlers.tsx create mode 100644 invokeai/frontend/web/src/features/metadata/util/parsers.ts create mode 100644 invokeai/frontend/web/src/features/metadata/util/recallers.ts create mode 100644 invokeai/frontend/web/src/features/metadata/util/validators.ts delete mode 100644 invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 21a2f6a9fd..e951a294b5 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -656,6 +656,7 @@ } }, "metadata": { + "allPrompts": "All Prompts", "cfgScale": "CFG scale", "cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)", "createdBy": "Created By", @@ -664,6 +665,7 @@ "height": "Height", "hiresFix": "High Resolution Optimization", "imageDetails": "Image Details", + "imageDimensions": "Image Dimensions", "initImage": "Initial image", "metadata": "Metadata", "model": "Model", @@ -671,9 +673,11 @@ "noImageDetails": "No image details found", "noMetaData": "No metadata found", "noRecallParameters": "No parameters to recall found", + "parameterSet": "Parameter {{parameter}} set", "perlin": "Perlin Noise", "positivePrompt": "Positive Prompt", "recallParameters": "Recall Parameters", + "recallParameter": "Recall {{label}}", "scheduler": "Scheduler", "seamless": "Seamless", "seed": "Seed", @@ -1381,8 +1385,8 @@ "nodesNotValidJSON": "Not a valid JSON", "nodesSaved": "Nodes Saved", "nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types", - "parameterNotSet": "Parameter not set", - "parameterSet": "Parameter set", + "parameterNotSet": "{{parameter}} not set", + "parameterSet": "{{parameter}} set", "parametersFailed": "Problem loading parameters", "parametersFailedDesc": "Unable to load init image.", "parametersNotSet": "Parameters Not Set", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts index 6419f840ec..07a1039bef 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts @@ -1,6 +1,6 @@ -import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; +import { toast } from 'common/util/toast'; import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { t } from 'i18next'; import { truncate, upperFirst } from 'lodash-es'; @@ -8,11 +8,6 @@ import { queueApi } from 'services/api/endpoints/queue'; import { startAppListening } from '..'; -const { toast } = createStandaloneToast({ - theme: theme, - defaultOptions: TOAST_OPTIONS.defaultOptions, -}); - export const addBatchEnqueuedListener = () => { // success startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx index 69adfb5b67..fa5c962ff5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -1,7 +1,8 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; -import { createStandaloneToast, ExternalLink, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; +import { ExternalLink } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { startAppListening } from 'app/store/middleware/listenerMiddleware'; +import { toast } from 'common/util/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { @@ -12,11 +13,6 @@ import { const log = logger('images'); -const { toast } = createStandaloneToast({ - theme: theme, - defaultOptions: TOAST_OPTIONS.defaultOptions, -}); - export const addBulkDownloadListeners = () => { startAppListening({ matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled, diff --git a/invokeai/frontend/web/src/common/util/toast.ts b/invokeai/frontend/web/src/common/util/toast.ts new file mode 100644 index 0000000000..ac61a4a12d --- /dev/null +++ b/invokeai/frontend/web/src/common/util/toast.ts @@ -0,0 +1,6 @@ +import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; + +export const { toast } = createStandaloneToast({ + theme: theme, + defaultOptions: TOAST_OPTIONS.defaultOptions, +}); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 7eec7e1875..cfa679c805 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,300 +1,54 @@ -import { isModelIdentifier } from 'features/nodes/types/common'; -import type { - ControlNetMetadataItem, - CoreMetadata, - IPAdapterMetadataItem, - LoRAMetadataItem, - T2IAdapterMetadataItem, -} from 'features/nodes/types/metadata'; -import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { memo, useCallback, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; - -import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem'; +import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets'; +import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters'; +import { MetadataItem } from 'features/metadata/components/MetadataItem'; +import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs'; +import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters'; +import { handlers } from 'features/metadata/util/handlers'; +import { memo } from 'react'; type Props = { - metadata?: CoreMetadata; + metadata?: unknown; }; const ImageMetadataActions = (props: Props) => { const { metadata } = props; - const { t } = useTranslation(); - - const { - recallPositivePrompt, - recallNegativePrompt, - recallSeed, - recallCfgScale, - recallCfgRescaleMultiplier, - recallModel, - recallScheduler, - recallVaeModel, - recallSteps, - recallWidth, - recallHeight, - recallStrength, - recallHrfEnabled, - recallHrfStrength, - recallHrfMethod, - recallLoRA, - recallControlNet, - recallIPAdapter, - recallT2IAdapter, - recallSDXLPositiveStylePrompt, - recallSDXLNegativeStylePrompt, - } = useRecallParameters(); - - const handleRecallPositivePrompt = useCallback(() => { - recallPositivePrompt(metadata?.positive_prompt); - }, [metadata?.positive_prompt, recallPositivePrompt]); - - const handleRecallNegativePrompt = useCallback(() => { - recallNegativePrompt(metadata?.negative_prompt); - }, [metadata?.negative_prompt, recallNegativePrompt]); - - const handleRecallSDXLPositiveStylePrompt = useCallback(() => { - recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt); - }, [metadata?.positive_style_prompt, recallSDXLPositiveStylePrompt]); - - const handleRecallSDXLNegativeStylePrompt = useCallback(() => { - recallSDXLNegativeStylePrompt(metadata?.negative__style_prompt); - }, [metadata?.negative__style_prompt, recallSDXLNegativeStylePrompt]); - - const handleRecallSeed = useCallback(() => { - recallSeed(metadata?.seed); - }, [metadata?.seed, recallSeed]); - - const handleRecallModel = useCallback(() => { - recallModel(metadata?.model); - }, [metadata?.model, recallModel]); - - const handleRecallWidth = useCallback(() => { - recallWidth(metadata?.width); - }, [metadata?.width, recallWidth]); - - const handleRecallHeight = useCallback(() => { - recallHeight(metadata?.height); - }, [metadata?.height, recallHeight]); - - const handleRecallScheduler = useCallback(() => { - recallScheduler(metadata?.scheduler); - }, [metadata?.scheduler, recallScheduler]); - - const handleRecallVaeModel = useCallback(() => { - recallVaeModel(metadata?.vae); - }, [metadata?.vae, recallVaeModel]); - - const handleRecallSteps = useCallback(() => { - recallSteps(metadata?.steps); - }, [metadata?.steps, recallSteps]); - - const handleRecallCfgScale = useCallback(() => { - recallCfgScale(metadata?.cfg_scale); - }, [metadata?.cfg_scale, recallCfgScale]); - - const handleRecallCfgRescaleMultiplier = useCallback(() => { - recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier); - }, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]); - - const handleRecallStrength = useCallback(() => { - recallStrength(metadata?.strength); - }, [metadata?.strength, recallStrength]); - - const handleRecallHrfEnabled = useCallback(() => { - recallHrfEnabled(metadata?.hrf_enabled); - }, [metadata?.hrf_enabled, recallHrfEnabled]); - - const handleRecallHrfStrength = useCallback(() => { - recallHrfStrength(metadata?.hrf_strength); - }, [metadata?.hrf_strength, recallHrfStrength]); - - const handleRecallHrfMethod = useCallback(() => { - recallHrfMethod(metadata?.hrf_method); - }, [metadata?.hrf_method, recallHrfMethod]); - - const handleRecallLoRA = useCallback( - (lora: LoRAMetadataItem) => { - recallLoRA(lora); - }, - [recallLoRA] - ); - - const handleRecallControlNet = useCallback( - (controlnet: ControlNetMetadataItem) => { - recallControlNet(controlnet); - }, - [recallControlNet] - ); - - const handleRecallIPAdapter = useCallback( - (ipAdapter: IPAdapterMetadataItem) => { - recallIPAdapter(ipAdapter); - }, - [recallIPAdapter] - ); - - const handleRecallT2IAdapter = useCallback( - (ipAdapter: T2IAdapterMetadataItem) => { - recallT2IAdapter(ipAdapter); - }, - [recallT2IAdapter] - ); - - const validControlNets: ControlNetMetadataItem[] = useMemo(() => { - return metadata?.controlnets - ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model)) - : []; - }, [metadata?.controlnets]); - - const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { - return metadata?.ipAdapters - ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model)) - : []; - }, [metadata?.ipAdapters]); - - const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { - return metadata?.t2iAdapters - ? metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model)) - : []; - }, [metadata?.t2iAdapters]); - if (!metadata || Object.keys(metadata).length === 0) { return null; } return ( <> - {metadata.created_by && } - {metadata.generation_mode && ( - - )} - {metadata.positive_prompt && ( - - )} - {metadata.negative_prompt && ( - - )} - {metadata.positive_style_prompt && ( - - )} - {metadata.negative_style_prompt && ( - - )} - {metadata.seed !== undefined && metadata.seed !== null && ( - - )} - {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( - - )} - {metadata.width && ( - - )} - {metadata.height && ( - - )} - {metadata.scheduler && ( - - )} - - {metadata.steps && ( - - )} - {metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && ( - - )} - {metadata.cfg_rescale_multiplier !== undefined && metadata.cfg_rescale_multiplier !== null && ( - - )} - {metadata.strength && ( - - )} - {metadata.hrf_enabled && ( - - )} - {metadata.hrf_enabled && metadata.hrf_strength && ( - - )} - {metadata.hrf_enabled && metadata.hrf_method && ( - - )} - {metadata.loras && - metadata.loras.map((lora, index) => { - if (isModelIdentifier(lora.lora)) { - return ( - - ); - } - })} - {validControlNets.map((controlnet, index) => ( - - ))} - {validIPAdapters.map((ipAdapter, index) => ( - - ))} - {validT2IAdapters.map((t2iAdapter, index) => ( - - ))} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx index ce25b54e59..146c1437d1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx @@ -1,28 +1,39 @@ import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; +import { get } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { IoArrowUndoCircleOutline } from 'react-icons/io5'; -import { PiCopyBold } from 'react-icons/pi'; +import { PiArrowBendUpLeftBold } from 'react-icons/pi'; import { useGetModelConfigQuery } from 'services/api/endpoints/models'; type MetadataItemProps = { isLink?: boolean; label: string; - onClick?: () => void; - value: number | string | boolean; + metadata: unknown; + propertyName: string; + onRecall?: (value: unknown) => void; labelPosition?: string; - withCopy?: boolean; }; /** * Component to display an individual metadata item or parameter. */ -const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => { +const ImageMetadataItem = ({ + label, + metadata, + propertyName, + onRecall: _onRecall, + isLink, + labelPosition, +}: MetadataItemProps) => { const { t } = useTranslation(); - const handleCopy = useCallback(() => { - navigator.clipboard.writeText(value?.toString()); - }, [value]); + const value = useMemo(() => get(metadata, propertyName), [metadata, propertyName]); + const onRecall = useCallback(() => { + if (!_onRecall) { + return; + } + _onRecall(value); + }, [_onRecall, value]); if (!value) { return null; @@ -30,27 +41,15 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC return ( - {onClick && ( - + {_onRecall && ( + } + aria-label={t('metadata.recallParameter', { parameter: label })} + icon={} size="xs" variant="ghost" fontSize={20} - onClick={onClick} - /> - - )} - {withCopy && ( - - } - size="xs" - variant="ghost" - fontSize={14} - onClick={handleCopy} + onClick={onRecall} /> )} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index b508c9424d..fd71ce3e19 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -23,29 +23,29 @@ type LoRACardProps = { export const LoRACard = memo((props: LoRACardProps) => { const { lora } = props; const dispatch = useAppDispatch(); - const { data: loraConfig } = useGetModelConfigQuery(lora.key); + const { data: loraConfig } = useGetModelConfigQuery(lora.model.key); const handleChange = useCallback( (v: number) => { - dispatch(loraWeightChanged({ key: lora.key, weight: v })); + dispatch(loraWeightChanged({ key: lora.model.key, weight: v })); }, - [dispatch, lora.key] + [dispatch, lora.model.key] ); const handleSetLoraToggle = useCallback(() => { - dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled })); - }, [dispatch, lora.key, lora.isEnabled]); + dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.model.key, lora.isEnabled]); const handleRemoveLora = useCallback(() => { - dispatch(loraRemoved(lora.key)); - }, [dispatch, lora.key]); + dispatch(loraRemoved(lora.model.key)); + }, [dispatch, lora.model.key]); return ( - {loraConfig?.name ?? lora.key.substring(0, 8)} + {loraConfig?.name ?? lora.model.key.substring(0, 8)} diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 377406b3e5..32bef5fd9b 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,9 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; +import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers'; import type { LoRAModelConfig } from 'services/api/types'; -export type LoRA = ParameterLoRAModel & { +export type LoRA = { + model: ParameterLoRAModel; weight: number; isEnabled?: boolean; }; @@ -29,11 +31,11 @@ export const loraSlice = createSlice({ initialState: initialLoraState, reducers: { loraAdded: (state, action: PayloadAction) => { - const { key, base } = action.payload; - state.loras[key] = { key, base, ...defaultLoRAConfig }; + const model = getModelKeyAndBase(action.payload); + state.loras[model.key] = { ...defaultLoRAConfig, model }; }, loraRecalled: (state, action: PayloadAction) => { - state.loras[action.payload.key] = action.payload; + state.loras[action.payload.model.key] = action.payload; }, loraRemoved: (state, action: PayloadAction) => { const key = action.payload; @@ -58,7 +60,7 @@ export const loraSlice = createSlice({ } lora.weight = defaultLoRAConfig.weight; }, - loraIsEnabledChanged: (state, action: PayloadAction>) => { + loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => { const { key, isEnabled } = action.payload; const lora = state.loras[key]; if (!lora) { diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx new file mode 100644 index 0000000000..2e8e3b6f9a --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataControlNets.tsx @@ -0,0 +1,66 @@ +import type { ControlNetConfig } from 'features/controlAdapters/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataControlNets = ({ metadata }: Props) => { + const [controlNets, setControlNets] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.controlNets.parse(metadata); + setControlNets(parsed); + } catch (e) { + setControlNets([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.controlNets.getLabel(), []); + + return ( + <> + {controlNets.map((controlNet) => ( + + ))} + + ); +}; + +const MetadataViewControlNet = ({ + label, + controlNet, + handlers, +}: { + label: string; + controlNet: ControlNetConfig; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(controlNet, true); + }, [handlers, controlNet]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(controlNet); + }, [handlers, controlNet]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx new file mode 100644 index 0000000000..fef281cd09 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataIPAdapters.tsx @@ -0,0 +1,66 @@ +import type { IPAdapterConfig } from 'features/controlAdapters/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataIPAdapters = ({ metadata }: Props) => { + const [ipAdapters, setIPAdapters] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.ipAdapters.parse(metadata); + setIPAdapters(parsed); + } catch (e) { + setIPAdapters([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.ipAdapters.getLabel(), []); + + return ( + <> + {ipAdapters.map((ipAdapter) => ( + + ))} + + ); +}; + +const MetadataViewIPAdapter = ({ + label, + ipAdapter, + handlers, +}: { + label: string; + ipAdapter: IPAdapterConfig; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(ipAdapter, true); + }, [handlers, ipAdapter]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(ipAdapter); + }, [handlers, ipAdapter]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx new file mode 100644 index 0000000000..66d101f458 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataItem.tsx @@ -0,0 +1,33 @@ +import { typedMemo } from '@invoke-ai/ui-library'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { MetadataParseFailedToken } from 'features/metadata/util/parsers'; + +type MetadataItemProps = { + metadata: unknown; + handlers: MetadataHandlers; + direction?: 'row' | 'column'; +}; + +const _MetadataItem = typedMemo(({ metadata, handlers, direction = 'row' }: MetadataItemProps) => { + const { label, isDisabled, value, renderedValue, onRecall } = useMetadataItem(metadata, handlers); + + if (value === MetadataParseFailedToken) { + return null; + } + + return ( + + ); +}); + +export const MetadataItem = typedMemo(_MetadataItem); + +MetadataItem.displayName = 'MetadataItem'; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx new file mode 100644 index 0000000000..14ccc4ae2b --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataItemView.tsx @@ -0,0 +1,29 @@ +import { Flex, Text } from '@invoke-ai/ui-library'; +import { RecallButton } from 'features/metadata/components/RecallButton'; +import { memo } from 'react'; + +type MetadataItemViewProps = { + onRecall: () => void; + label: string; + renderedValue: React.ReactNode; + isDisabled: boolean; + direction?: 'row' | 'column'; +}; + +export const MetadataItemView = memo( + ({ label, onRecall, isDisabled, renderedValue, direction = 'row' }: MetadataItemViewProps) => { + return ( + + {onRecall && } + + + {label}: + + {renderedValue} + + + ); + } +); + +MetadataItemView.displayName = 'MetadataItemView'; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx new file mode 100644 index 0000000000..3225a048d4 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx @@ -0,0 +1,61 @@ +import type { LoRA } from 'features/lora/store/loraSlice'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataLoRAs = ({ metadata }: Props) => { + const [loras, setLoRAs] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.loras.parse(metadata); + setLoRAs(parsed); + } catch (e) { + setLoRAs([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.loras.getLabel(), []); + + return ( + <> + {loras.map((lora) => ( + + ))} + + ); +}; + +const MetadataViewLoRA = ({ + label, + lora, + handlers, +}: { + label: string; + lora: LoRA; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(lora, true); + }, [handlers, lora]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(lora); + }, [handlers, lora]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx new file mode 100644 index 0000000000..4a73fb3ccb --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataT2IAdapters.tsx @@ -0,0 +1,66 @@ +import type { T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataT2IAdapters = ({ metadata }: Props) => { + const [t2iAdapters, setT2IAdapters] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.t2iAdapters.parse(metadata); + setT2IAdapters(parsed); + } catch (e) { + setT2IAdapters([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.t2iAdapters.getLabel(), []); + + return ( + <> + {t2iAdapters.map((t2iAdapter) => ( + + ))} + + ); +}; + +const MetadataViewT2IAdapter = ({ + label, + t2iAdapter, + handlers, +}: { + label: string; + t2iAdapter: T2IAdapterConfig; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(t2iAdapter, true); + }, [handlers, t2iAdapter]); + + const renderedValue = useMemo(() => { + if (!handlers.renderItemValue) { + return null; + } + return handlers.renderItemValue(t2iAdapter); + }, [handlers, t2iAdapter]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx b/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx new file mode 100644 index 0000000000..8803945b06 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/RecallButton.tsx @@ -0,0 +1,27 @@ +import type { IconButtonProps} from '@invoke-ai/ui-library'; +import { IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiArrowBendUpLeftBold } from 'react-icons/pi'; + +type MetadataItemProps = Omit & { + label: string; +}; + +export const RecallButton = memo(({ label, ...rest }: MetadataItemProps) => { + const { t } = useTranslation(); + + return ( + + } + size="xs" + variant="ghost" + {...rest} + /> + + ); +}); + +RecallButton.displayName = 'RecallButton'; diff --git a/invokeai/frontend/web/src/features/metadata/exceptions.ts b/invokeai/frontend/web/src/features/metadata/exceptions.ts new file mode 100644 index 0000000000..ffe4e98c79 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/exceptions.ts @@ -0,0 +1,27 @@ +/** + * Raised when metadata parsing fails. + */ +export class MetadataParseError extends Error { + /** + * Create MetadataParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Raised when metadata recall fails. + */ +export class MetadataRecallError extends Error { + /** + * Create MetadataRecallError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx new file mode 100644 index 0000000000..178ac5155a --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/hooks/useMetadataItem.tsx @@ -0,0 +1,51 @@ +import { Text } from '@invoke-ai/ui-library'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +export const useMetadataItem = (metadata: unknown, handlers: MetadataHandlers) => { + const [value, setValue] = useState( + MetadataParsePendingToken + ); + + useEffect(() => { + const _parse = async () => { + try { + const parsed = await handlers.parse(metadata); + setValue(parsed); + } catch (e) { + setValue(MetadataParseFailedToken); + } + }; + _parse(); + }, [handlers, metadata]); + + const isDisabled = useMemo(() => value === MetadataParsePendingToken || value === MetadataParseFailedToken, [value]); + + const label = useMemo(() => handlers.getLabel(), [handlers]); + + const renderedValue = useMemo(() => { + if (value === MetadataParsePendingToken) { + return Loading; + } + if (value === MetadataParseFailedToken) { + return Parsing Failed; + } + + const rendered = handlers.renderValue(value); + + if (typeof rendered === 'string') { + return {rendered}; + } + return rendered; + }, [handlers, value]); + + const onRecall = useCallback(() => { + if (!handlers.recall || value === MetadataParsePendingToken || value === MetadataParseFailedToken) { + return null; + } + handlers.recall(value, true); + }, [handlers, value]); + + return { label, isDisabled, value, renderedValue, onRecall }; +}; diff --git a/invokeai/frontend/web/src/features/metadata/types.ts b/invokeai/frontend/web/src/features/metadata/types.ts new file mode 100644 index 0000000000..366f11701e --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/types.ts @@ -0,0 +1,136 @@ +/** + * Renders a value of type T as a React node. + */ +export type MetadataRenderValueFunc = (value: T) => React.ReactNode; + +/** + * Gets the label of the current metadata item as a string. + */ +export type MetadataGetLabelFunc = () => string; + +export type MetadataParseOptions = { + toastOnFailure?: boolean; + toastOnSuccess?: boolean; +}; + +export type MetadataRecallOptions = MetadataParseOptions; + +/** + * A function that recalls a parsed and validated metadata value. + * + * @param value The value to recall. + * @throws MetadataRecallError if the value cannot be recalled. + */ +export type MetadataRecallFunc = (value: T) => void; + +/** + * An async function that receives metadata and returns a parsed value, throwing if the value is invalid or missing. + * + * The function receives an object of unknown type. It is responsible for extracting the relevant data from the metadata + * and returning a value of type T. + * + * The function should throw a MetadataParseError if the metadata is invalid or missing. + * + * @param metadata The metadata to parse. + * @returns A promise that resolves to the parsed value. + * @throws MetadataParseError if the metadata is invalid or missing. + */ +export type MetadataParseFunc = (metadata: unknown) => Promise; + +/** + * A function that performs additional validation logic before recalling a metadata value. It is called with a parsed + * value and should throw if the validation logic fails. + * + * This function is used in cases where some additional logic is required before recalling. For example, when recalling + * a LoRA, we need to check if it is compatible with the current base model. + * + * @param value The value to validate. + * @returns A promise that resolves to the validated value. + * @throws MetadataRecallError if the value is invalid. + */ +export type MetadataValidateFunc = (value: T) => Promise; + +export type MetadataHandlers = { + /** + * Gets the label of the current metadata item as a string. + * + * @returns The label of the current metadata item. + */ + getLabel: MetadataGetLabelFunc; + /** + * An async function that receives metadata and returns a parsed metadata value. + * + * @param metadata The metadata to parse. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves to the parsed value. + * @throws MetadataParseError if the metadata is invalid or missing. + */ + parse: (metadata: unknown, withToast?: boolean) => Promise; + /** + * An async function that receives a metadata item and returns a parsed metadata item value. + * + * This is only provided if the metadata value is an array. + * + * @param item The item to parse. It should be an item from the array. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves to the parsed value. + * @throws MetadataParseError if the metadata is invalid or missing. + */ + parseItem?: (item: unknown, withToast?: boolean) => Promise; + /** + * An async function that recalls a parsed metadata value. + * + * This function is only provided if the metadata value can be recalled. + * + * @param value The value to recall. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves when the recall operation is complete. + * @throws MetadataRecallError if the value cannot be recalled. + */ + recall?: (value: TValue, withToast?: boolean) => Promise; + /** + * An async function that recalls a parsed metadata item value. + * + * This function is only provided if the metadata value is an array and the items can be recalled. + * + * @param item The item to recall. It should be an item from the array. + * @param withToast Whether to show a toast on success or failure. + * @returns A promise that resolves when the recall operation is complete. + * @throws MetadataRecallError if the value cannot be recalled. + */ + recallItem?: (item: TItem, withToast?: boolean) => Promise; + /** + * Renders a parsed metadata value as a React node. + * + * @param value The value to render. + * @returns The rendered value. + */ + renderValue: MetadataRenderValueFunc; + /** + * Renders a parsed metadata item value as a React node. + * + * @param item The item to render. + * @returns The rendered item. + */ + renderItemValue?: MetadataRenderValueFunc; +}; + +// TODO(psyche): The types for item handlers should be able to be inferred from the type of the value: +// type MetadataHandlersInferItem = TValue extends Array ? MetadataParseFunc : never +// While this works for the types as expected, I couldn't satisfy TS in the implementations of the handlers. + +export type BuildMetadataHandlersArg = { + parser: MetadataParseFunc; + itemParser?: MetadataParseFunc; + recaller?: MetadataRecallFunc; + itemRecaller?: MetadataRecallFunc; + validator?: MetadataValidateFunc; + itemValidator?: MetadataValidateFunc; + getLabel: MetadataGetLabelFunc; + renderValue?: MetadataRenderValueFunc; + renderItemValue?: MetadataRenderValueFunc; +}; + +export type BuildMetadataHandlers = ( + arg: BuildMetadataHandlersArg +) => MetadataHandlers; diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.tsx b/invokeai/frontend/web/src/features/metadata/util/handlers.tsx new file mode 100644 index 0000000000..1f86257d42 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.tsx @@ -0,0 +1,316 @@ +import { Text } from '@invoke-ai/ui-library'; +import { toast } from 'common/util/toast'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import type { + BuildMetadataHandlers, + MetadataGetLabelFunc, + MetadataHandlers, + MetadataParseFunc, + MetadataRecallFunc, + MetadataRenderValueFunc, + MetadataValidateFunc, +} from 'features/metadata/types'; +import { validators } from 'features/metadata/util/validators'; +import { t } from 'i18next'; +import type { AnyModelConfig } from 'services/api/types'; + +import { parsers } from './parsers'; +import { recallers } from './recallers'; + +const renderModelConfigValue: MetadataRenderValueFunc = (value) => + `${value.name} (${value.base.toUpperCase()}, ${value.key})`; +const renderLoRAValue: MetadataRenderValueFunc = (value) => {`${value.model.key} (${value.weight})`}; +const renderControlAdapterValue: MetadataRenderValueFunc = ( + value +) => {`${value.model?.key} (${value.weight})`}; + +const parameterSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterSet', { parameter }), + description, + status: 'info', + duration: 2500, + isClosable: true, + }); +}; + +const parameterNotSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterNotSet', { parameter }), + description, + status: 'warning', + duration: 2500, + isClosable: true, + }); +}; + +// const allParameterSetToast = (description?: string) => { +// toast({ +// title: t('toast.parametersSet'), +// status: 'info', +// description, +// duration: 2500, +// isClosable: true, +// }); +// }; + +// const allParameterNotSetToast = (description?: string) => { +// toast({ +// title: t('toast.parametersNotSet'), +// status: 'warning', +// description, +// duration: 2500, +// isClosable: true, +// }); +// }; + +const buildParse = + (arg: { + parser: MetadataParseFunc; + getLabel: MetadataGetLabelFunc; + }): MetadataHandlers['parse'] => + async (metadata, withToast = false) => { + try { + const parsed = await arg.parser(metadata); + withToast && parameterSetToast(arg.getLabel()); + return parsed; + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildParseItem = + (arg: { + itemParser: MetadataParseFunc; + getLabel: MetadataGetLabelFunc; + }): MetadataHandlers['parseItem'] => + async (item, withToast = false) => { + try { + const parsed = await arg.itemParser(item); + withToast && parameterSetToast(arg.getLabel()); + return parsed; + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildRecall = + (arg: { + recaller: MetadataRecallFunc; + validator?: MetadataValidateFunc; + getLabel: MetadataGetLabelFunc; + }): NonNullable['recall']> => + async (value, withToast = false) => { + try { + arg.validator && (await arg.validator(value)); + await arg.recaller(value); + withToast && parameterSetToast(arg.getLabel()); + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildRecallItem = + (arg: { + itemRecaller: MetadataRecallFunc; + itemValidator?: MetadataValidateFunc; + getLabel: MetadataGetLabelFunc; + }): NonNullable['recallItem']> => + async (item, withToast = false) => { + try { + arg.itemValidator && (await arg.itemValidator(item)); + await arg.itemRecaller(item); + withToast && parameterSetToast(arg.getLabel()); + } catch (e) { + withToast && parameterNotSetToast(arg.getLabel(), (e as Error).message); + throw e; + } + }; + +const buildHandlers: BuildMetadataHandlers = ({ + getLabel, + parser, + itemParser, + recaller, + itemRecaller, + validator, + itemValidator, + renderValue, + renderItemValue, +}) => ({ + parse: buildParse({ parser, getLabel }), + parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined, + recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined, + recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined, + getLabel, + renderValue: renderValue ?? String, + renderItemValue: renderItemValue ?? String, +}); + +export const handlers = { + // Misc + createdBy: buildHandlers({ getLabel: () => t('metadata.createdBy'), parser: parsers.createdBy }), + generationMode: buildHandlers({ getLabel: () => t('metadata.generationMode'), parser: parsers.generationMode }), + + // Core parameters + cfgRescaleMultiplier: buildHandlers({ + getLabel: () => t('metadata.cfgRescaleMultiplier'), + parser: parsers.cfgRescaleMultiplier, + recaller: recallers.cfgRescaleMultiplier, + }), + cfgScale: buildHandlers({ + getLabel: () => t('metadata.cfgScale'), + parser: parsers.cfgScale, + recaller: recallers.cfgScale, + }), + height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }), + negativePrompt: buildHandlers({ + getLabel: () => t('metadata.negativePrompt'), + parser: parsers.negativePrompt, + recaller: recallers.negativePrompt, + }), + positivePrompt: buildHandlers({ + getLabel: () => t('metadata.positivePrompt'), + parser: parsers.positivePrompt, + recaller: recallers.positivePrompt, + }), + scheduler: buildHandlers({ + getLabel: () => t('metadata.scheduler'), + parser: parsers.scheduler, + recaller: recallers.scheduler, + }), + sdxlNegativeStylePrompt: buildHandlers({ + getLabel: () => t('sdxl.negStylePrompt'), + parser: parsers.sdxlNegativeStylePrompt, + recaller: recallers.sdxlNegativeStylePrompt, + }), + sdxlPositiveStylePrompt: buildHandlers({ + getLabel: () => t('sdxl.posStylePrompt'), + parser: parsers.sdxlPositiveStylePrompt, + recaller: recallers.sdxlPositiveStylePrompt, + }), + seed: buildHandlers({ getLabel: () => t('metadata.seed'), parser: parsers.seed, recaller: recallers.seed }), + steps: buildHandlers({ getLabel: () => t('metadata.steps'), parser: parsers.steps, recaller: recallers.steps }), + strength: buildHandlers({ + getLabel: () => t('metadata.strength'), + parser: parsers.strength, + recaller: recallers.strength, + }), + width: buildHandlers({ getLabel: () => t('metadata.width'), parser: parsers.width, recaller: recallers.width }), + + // HRF + hrfEnabled: buildHandlers({ + getLabel: () => t('hrf.metadata.enabled'), + parser: parsers.hrfEnabled, + recaller: recallers.hrfEnabled, + }), + hrfMethod: buildHandlers({ + getLabel: () => t('hrf.metadata.method'), + parser: parsers.hrfMethod, + recaller: recallers.hrfMethod, + }), + hrfStrength: buildHandlers({ + getLabel: () => t('hrf.metadata.strength'), + parser: parsers.hrfStrength, + recaller: recallers.hrfStrength, + }), + + // Refiner + refinerCFGScale: buildHandlers({ + getLabel: () => t('sdxl.cfgScale'), + parser: parsers.refinerCFGScale, + recaller: recallers.refinerCFGScale, + }), + refinerModel: buildHandlers({ + getLabel: () => t('sdxl.refinerModel'), + parser: parsers.refinerModel, + recaller: recallers.refinerModel, + validator: validators.refinerModel, + }), + refinerNegativeAestheticScore: buildHandlers({ + getLabel: () => t('sdxl.posAestheticScore'), + parser: parsers.refinerNegativeAestheticScore, + recaller: recallers.refinerNegativeAestheticScore, + }), + refinerPositiveAestheticScore: buildHandlers({ + getLabel: () => t('sdxl.negAestheticScore'), + parser: parsers.refinerPositiveAestheticScore, + recaller: recallers.refinerPositiveAestheticScore, + }), + refinerScheduler: buildHandlers({ + getLabel: () => t('sdxl.scheduler'), + parser: parsers.refinerScheduler, + recaller: recallers.refinerScheduler, + }), + refinerStart: buildHandlers({ + getLabel: () => t('sdxl.refiner_start'), + parser: parsers.refinerStart, + recaller: recallers.refinerStart, + }), + refinerSteps: buildHandlers({ + getLabel: () => t('sdxl.refiner_steps'), + parser: parsers.refinerSteps, + recaller: recallers.refinerSteps, + }), + + // Models + model: buildHandlers({ + getLabel: () => t('metadata.model'), + parser: parsers.mainModel, + recaller: recallers.model, + renderValue: renderModelConfigValue, + }), + vae: buildHandlers({ + getLabel: () => t('metadata.vae'), + parser: parsers.vaeModel, + recaller: recallers.vae, + renderValue: renderModelConfigValue, + validator: validators.vaeModel, + }), + + // Arrays of models + controlNets: buildHandlers({ + getLabel: () => t('common.controlNet'), + parser: parsers.controlNets, + itemParser: parsers.controlNet, + recaller: recallers.controlNets, + itemRecaller: recallers.controlNet, + validator: validators.controlNets, + itemValidator: validators.controlNet, + renderItemValue: renderControlAdapterValue, + }), + ipAdapters: buildHandlers({ + getLabel: () => t('common.ipAdapter'), + parser: parsers.ipAdapters, + itemParser: parsers.ipAdapter, + recaller: recallers.ipAdapters, + itemRecaller: recallers.ipAdapter, + validator: validators.ipAdapters, + itemValidator: validators.ipAdapter, + renderItemValue: renderControlAdapterValue, + }), + loras: buildHandlers({ + getLabel: () => t('models.lora'), + parser: parsers.loras, + itemParser: parsers.lora, + recaller: recallers.loras, + itemRecaller: recallers.lora, + validator: validators.loras, + itemValidator: validators.lora, + renderItemValue: renderLoRAValue, + }), + t2iAdapters: buildHandlers({ + getLabel: () => t('common.t2iAdapter'), + parser: parsers.t2iAdapters, + itemParser: parsers.t2iAdapter, + recaller: recallers.t2iAdapters, + itemRecaller: recallers.t2iAdapter, + validator: validators.t2iAdapters, + itemValidator: validators.t2iAdapter, + renderItemValue: renderControlAdapterValue, + }), +} as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts new file mode 100644 index 0000000000..c7f2a6d09f --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -0,0 +1,396 @@ +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { + initialControlNet, + initialIPAdapter, + initialT2IAdapter, +} from 'features/controlAdapters/util/buildControlAdapter'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import { defaultLoRAConfig } from 'features/lora/store/loraSlice'; +import { MetadataParseError } from 'features/metadata/exceptions'; +import type { MetadataParseFunc } from 'features/metadata/types'; +import { + zControlField, + zIPAdapterField, + zModelIdentifierWithBase, + zT2IAdapterField, +} from 'features/nodes/types/common'; +import type { + ParameterCFGRescaleMultiplier, + ParameterCFGScale, + ParameterHeight, + ParameterHRFEnabled, + ParameterHRFMethod, + ParameterNegativePrompt, + ParameterNegativeStylePromptSDXL, + ParameterPositivePrompt, + ParameterPositiveStylePromptSDXL, + ParameterScheduler, + ParameterSDXLRefinerNegativeAestheticScore, + ParameterSDXLRefinerPositiveAestheticScore, + ParameterSDXLRefinerStart, + ParameterSeed, + ParameterSteps, + ParameterStrength, + ParameterWidth, +} from 'features/parameters/types/parameterSchemas'; +import { + isParameterCFGRescaleMultiplier, + isParameterCFGScale, + isParameterHeight, + isParameterHRFEnabled, + isParameterHRFMethod, + isParameterLoRAWeight, + isParameterNegativePrompt, + isParameterNegativeStylePromptSDXL, + isParameterPositivePrompt, + isParameterPositiveStylePromptSDXL, + isParameterScheduler, + isParameterSDXLRefinerNegativeAestheticScore, + isParameterSDXLRefinerPositiveAestheticScore, + isParameterSDXLRefinerStart, + isParameterSeed, + isParameterSteps, + isParameterStrength, + isParameterWidth, +} from 'features/parameters/types/parameterSchemas'; +import { + fetchModelConfigWithTypeGuard, + getModelKey, + getModelKeyAndBase, +} from 'features/parameters/util/modelFetchingHelpers'; +import { get, isArray, isString } from 'lodash-es'; +import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; +import { + isControlNetModelConfig, + isIPAdapterModelConfig, + isLoRAModelConfig, + isNonRefinerMainModelConfig, + isRefinerMainModelModelConfig, + isT2IAdapterModelConfig, + isVAEModelConfig, +} from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; + +export const MetadataParsePendingToken = Symbol('pending'); +export const MetadataParseFailedToken = Symbol('failed'); + +/** + * An async function that a property from an object and validates its type using a type guard. If the property is missing + * or invalid, the function should throw a MetadataParseError. + * @param obj The object to get the property from. + * @param property The property to get. + * @param typeGuard A type guard function to check the type of the property. Provide `undefined` to opt out of type + * validation and always return the property value. + * @returns A promise that resolves to the property value if it exists and is of the expected type. + * @throws MetadataParseError if a type guard is provided and the property is not of the expected type. + */ +const getProperty = ( + obj: unknown, + property: string, + typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true +): Promise => { + return new Promise((resolve, reject) => { + const val = get(obj, property) as unknown; + if (typeGuard(val)) { + resolve(val); + } + reject(new MetadataParseError(`Property ${property} is not of expected type`)); + }); +}; + +const parseCreatedBy: MetadataParseFunc = (metadata) => getProperty(metadata, 'created_by', isString); + +const parseGenerationMode: MetadataParseFunc = (metadata) => getProperty(metadata, 'generation_mode', isString); + +const parsePositivePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'positive_prompt', isParameterPositivePrompt); + +const parseNegativePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'negative_prompt', isParameterNegativePrompt); + +const parseSDXLPositiveStylePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL); + +const parseSDXLNegativeStylePrompt: MetadataParseFunc = (metadata) => + getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL); + +const parseSeed: MetadataParseFunc = (metadata) => getProperty(metadata, 'seed', isParameterSeed); + +const parseCFGScale: MetadataParseFunc = (metadata) => + getProperty(metadata, 'cfg_scale', isParameterCFGScale); + +const parseCFGRescaleMultiplier: MetadataParseFunc = (metadata) => + getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier); + +const parseScheduler: MetadataParseFunc = (metadata) => + getProperty(metadata, 'scheduler', isParameterScheduler); + +const parseWidth: MetadataParseFunc = (metadata) => getProperty(metadata, 'width', isParameterWidth); + +const parseHeight: MetadataParseFunc = (metadata) => + getProperty(metadata, 'height', isParameterHeight); + +const parseSteps: MetadataParseFunc = (metadata) => getProperty(metadata, 'steps', isParameterSteps); + +const parseStrength: MetadataParseFunc = (metadata) => + getProperty(metadata, 'strength', isParameterStrength); + +const parseHRFEnabled: MetadataParseFunc = (metadata) => + getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled); + +const parseHRFStrength: MetadataParseFunc = (metadata) => + getProperty(metadata, 'hrf_strength', isParameterStrength); + +const parseHRFMethod: MetadataParseFunc = (metadata) => + getProperty(metadata, 'hrf_method', isParameterHRFMethod); + +const parseRefinerSteps: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_steps', isParameterSteps); + +const parseRefinerCFGScale: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale); + +const parseRefinerScheduler: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_scheduler', isParameterScheduler); + +const parseRefinerPositiveAestheticScore: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_positive_aesthetic_score', isParameterSDXLRefinerPositiveAestheticScore); + +const parseRefinerNegativeAestheticScore: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_negative_aesthetic_score', isParameterSDXLRefinerNegativeAestheticScore); + +const parseRefinerStart: MetadataParseFunc = (metadata) => + getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart); + +const parseMainModel: MetadataParseFunc = async (metadata) => { + const model = await getProperty(metadata, 'model', undefined); + const key = await getModelKey(model, 'main'); + const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); + return mainModelConfig; +}; + +const parseRefinerModel: MetadataParseFunc = async (metadata) => { + const refiner_model = await getProperty(metadata, 'refiner_model', undefined); + const key = await getModelKey(refiner_model, 'main'); + const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); + return refinerModelConfig; +}; + +const parseVAEModel: MetadataParseFunc = async (metadata) => { + const vae = await getProperty(metadata, 'vae', undefined); + const key = await getModelKey(vae, 'vae'); + const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig); + return vaeModelConfig; +}; + +const parseLoRA: MetadataParseFunc = async (metadataItem) => { + // Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}` + const modelV1 = await getProperty(metadataItem, 'lora', undefined); + // Now, the LoRA model is stored in a `model` property of the LoRA metadata: `{model: {key: ...}, weight: 0.75}` + const modelV2 = await getProperty(metadataItem, 'model', undefined); + const weight = await getProperty(metadataItem, 'weight', undefined); + const key = await getModelKey(modelV2 ?? modelV1, 'lora'); + const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); + + return { + model: getModelKeyAndBase(loraModelConfig), + weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, + isEnabled: true, + }; +}; + +const parseAllLoRAs: MetadataParseFunc = async (metadata) => { + const lorasRaw = await getProperty(metadata, 'loras', isArray); + const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora))); + const loras = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return loras; +}; + +const parseControlNet: MetadataParseFunc = async (metadataItem) => { + const control_model = await getProperty(metadataItem, 'control_model'); + const key = await getModelKey(control_model, 'controlnet'); + const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig); + + const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const control_weight = zControlField.shape.control_weight + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_weight')); + const begin_step_percent = zControlField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zControlField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const control_mode = zControlField.shape.control_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_mode')); + const resize_mode = zControlField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const controlNet: ControlNetConfig = { + type: 'controlnet', + isEnabled: true, + model: zModelIdentifierWithBase.parse(controlNetModel), + weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, + beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct, + endStepPct: end_step_percent ?? initialControlNet.endStepPct, + controlMode: control_mode ?? initialControlNet.controlMode, + resizeMode: resize_mode ?? initialControlNet.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return controlNet; +}; + +const parseAllControlNets: MetadataParseFunc = async (metadata) => { + const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray); + const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn))); + const controlNets = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return controlNets; +}; + +const parseT2IAdapter: MetadataParseFunc = async (metadataItem) => { + const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model'); + const key = await getModelKey(t2i_adapter_model, 't2i_adapter'); + const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig); + + const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zT2IAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zT2IAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const resize_mode = zT2IAdapterField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const t2iAdapter: T2IAdapterConfig = { + type: 't2i_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(t2iAdapterModel), + weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, + beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, + resizeMode: resize_mode ?? initialT2IAdapter.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return t2iAdapter; +}; + +const parseAllT2IAdapters: MetadataParseFunc = async (metadata) => { + const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray); + const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter))); + const t2iAdapters = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return t2iAdapters; +}; + +const parseIPAdapter: MetadataParseFunc = async (metadataItem) => { + const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model'); + const key = await getModelKey(ip_adapter_model, 'ip_adapter'); + const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig); + + const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zIPAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zIPAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + + const ipAdapter: IPAdapterConfig = { + id: uuidv4(), + type: 'ip_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(ipAdapterModel), + controlImage: image?.image_name ?? null, + weight: weight ?? initialIPAdapter.weight, + beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, + }; + + return ipAdapter; +}; + +const parseAllIPAdapters: MetadataParseFunc = async (metadata) => { + const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray); + const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter))); + const ipAdapters = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return ipAdapters; +}; + +export const parsers = { + createdBy: parseCreatedBy, + generationMode: parseGenerationMode, + positivePrompt: parsePositivePrompt, + negativePrompt: parseNegativePrompt, + sdxlPositiveStylePrompt: parseSDXLPositiveStylePrompt, + sdxlNegativeStylePrompt: parseSDXLNegativeStylePrompt, + seed: parseSeed, + cfgScale: parseCFGScale, + cfgRescaleMultiplier: parseCFGRescaleMultiplier, + scheduler: parseScheduler, + width: parseWidth, + height: parseHeight, + steps: parseSteps, + strength: parseStrength, + hrfEnabled: parseHRFEnabled, + hrfStrength: parseHRFStrength, + hrfMethod: parseHRFMethod, + refinerSteps: parseRefinerSteps, + refinerCFGScale: parseRefinerCFGScale, + refinerScheduler: parseRefinerScheduler, + refinerPositiveAestheticScore: parseRefinerPositiveAestheticScore, + refinerNegativeAestheticScore: parseRefinerNegativeAestheticScore, + refinerStart: parseRefinerStart, + mainModel: parseMainModel, + refinerModel: parseRefinerModel, + vaeModel: parseVAEModel, + lora: parseLoRA, + loras: parseAllLoRAs, + controlNet: parseControlNet, + controlNets: parseAllControlNets, + t2iAdapter: parseT2IAdapter, + t2iAdapters: parseAllT2IAdapters, + ipAdapter: parseIPAdapter, + ipAdapters: parseAllIPAdapters, +} as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts new file mode 100644 index 0000000000..e33589d5ce --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -0,0 +1,295 @@ +import { getStore } from 'app/store/nanostores/store'; +import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import { loraRecalled } from 'features/lora/store/loraSlice'; +import type { MetadataRecallFunc } from 'features/metadata/types'; +import { zModelIdentifierWithBase } from 'features/nodes/types/common'; +import { modelSelected } from 'features/parameters/store/actions'; +import { + heightRecalled, + setCfgRescaleMultiplier, + setCfgScale, + setImg2imgStrength, + setNegativePrompt, + setPositivePrompt, + setScheduler, + setSeed, + setSteps, + vaeSelected, + widthRecalled, +} from 'features/parameters/store/generationSlice'; +import type { + ParameterCFGRescaleMultiplier, + ParameterCFGScale, + ParameterHeight, + ParameterHRFEnabled, + ParameterHRFMethod, + ParameterNegativePrompt, + ParameterNegativeStylePromptSDXL, + ParameterPositivePrompt, + ParameterPositiveStylePromptSDXL, + ParameterScheduler, + ParameterSDXLRefinerNegativeAestheticScore, + ParameterSDXLRefinerPositiveAestheticScore, + ParameterSDXLRefinerStart, + ParameterSeed, + ParameterSteps, + ParameterStrength, + ParameterWidth, +} from 'features/parameters/types/parameterSchemas'; +import { + refinerModelChanged, + setNegativeStylePromptSDXL, + setPositiveStylePromptSDXL, + setRefinerCFGScale, + setRefinerNegativeAestheticScore, + setRefinerPositiveAestheticScore, + setRefinerScheduler, + setRefinerStart, + setRefinerSteps, +} from 'features/sdxl/store/sdxlSlice'; +import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; + +const recallPositivePrompt: MetadataRecallFunc = (positivePrompt) => { + getStore().dispatch(setPositivePrompt(positivePrompt)); +}; + +const recallNegativePrompt: MetadataRecallFunc = (negativePrompt) => { + getStore().dispatch(setNegativePrompt(negativePrompt)); +}; + +const recallSDXLPositiveStylePrompt: MetadataRecallFunc = (positiveStylePrompt) => { + getStore().dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); +}; + +const recallSDXLNegativeStylePrompt: MetadataRecallFunc = (negativeStylePrompt) => { + getStore().dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); +}; + +const recallSeed: MetadataRecallFunc = (seed) => { + getStore().dispatch(setSeed(seed)); +}; + +const recallCFGScale: MetadataRecallFunc = (cfgScale) => { + getStore().dispatch(setCfgScale(cfgScale)); +}; + +const recallCFGRescaleMultiplier: MetadataRecallFunc = (cfgRescaleMultiplier) => { + getStore().dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier)); +}; + +const recallScheduler: MetadataRecallFunc = (scheduler) => { + getStore().dispatch(setScheduler(scheduler)); +}; + +const recallWidth: MetadataRecallFunc = (width) => { + getStore().dispatch(widthRecalled(width)); +}; + +const recallHeight: MetadataRecallFunc = (height) => { + getStore().dispatch(heightRecalled(height)); +}; + +const recallSteps: MetadataRecallFunc = (steps) => { + getStore().dispatch(setSteps(steps)); +}; + +const recallStrength: MetadataRecallFunc = (strength) => { + getStore().dispatch(setImg2imgStrength(strength)); +}; + +const recallHRFEnabled: MetadataRecallFunc = (hrfEnabled) => { + getStore().dispatch(setHrfEnabled(hrfEnabled)); +}; + +const recallHRFStrength: MetadataRecallFunc = (hrfStrength) => { + getStore().dispatch(setHrfStrength(hrfStrength)); +}; + +const recallHRFMethod: MetadataRecallFunc = (hrfMethod) => { + getStore().dispatch(setHrfMethod(hrfMethod)); +}; + +const recallRefinerSteps: MetadataRecallFunc = (refinerSteps) => { + getStore().dispatch(setRefinerSteps(refinerSteps)); +}; + +const recallRefinerCFGScale: MetadataRecallFunc = (refinerCFGScale) => { + getStore().dispatch(setRefinerCFGScale(refinerCFGScale)); +}; + +const recallRefinerScheduler: MetadataRecallFunc = (refinerScheduler) => { + getStore().dispatch(setRefinerScheduler(refinerScheduler)); +}; + +const recallRefinerPositiveAestheticScore: MetadataRecallFunc = ( + refinerPositiveAestheticScore +) => { + getStore().dispatch(setRefinerPositiveAestheticScore(refinerPositiveAestheticScore)); +}; + +const recallRefinerNegativeAestheticScore: MetadataRecallFunc = ( + refinerNegativeAestheticScore +) => { + getStore().dispatch(setRefinerNegativeAestheticScore(refinerNegativeAestheticScore)); +}; + +const recallRefinerStart: MetadataRecallFunc = (refinerStart) => { + getStore().dispatch(setRefinerStart(refinerStart)); +}; + +const recallModel: MetadataRecallFunc = (model) => { + const modelIdentifier = zModelIdentifierWithBase.parse(model); + getStore().dispatch(modelSelected(modelIdentifier)); +}; + +const recallRefinerModel: MetadataRecallFunc = (refinerModel) => { + const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel); + getStore().dispatch(refinerModelChanged(modelIdentifier)); +}; + +const recallVAE: MetadataRecallFunc = (vaeModel) => { + if (!vaeModel) { + getStore().dispatch(vaeSelected(null)); + return; + } + const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel); + getStore().dispatch(vaeSelected(modelIdentifier)); +}; + +const recallLoRA: MetadataRecallFunc = (lora) => { + getStore().dispatch(loraRecalled(lora)); +}; + +const recallAllLoRAs: MetadataRecallFunc = (loras) => { + const { dispatch } = getStore(); + loras.forEach((lora) => { + dispatch(loraRecalled(lora)); + }); +}; + +const recallControlNet: MetadataRecallFunc = (controlNet) => { + getStore().dispatch(controlAdapterRecalled(controlNet)); +}; + +const recallControlNets: MetadataRecallFunc = (controlNets) => { + const { dispatch } = getStore(); + controlNets.forEach((controlNet) => { + dispatch(controlAdapterRecalled(controlNet)); + }); +}; + +const recallT2IAdapter: MetadataRecallFunc = (t2iAdapter) => { + getStore().dispatch(controlAdapterRecalled(t2iAdapter)); +}; + +const recallT2IAdapters: MetadataRecallFunc = (t2iAdapters) => { + const { dispatch } = getStore(); + t2iAdapters.forEach((t2iAdapter) => { + dispatch(controlAdapterRecalled(t2iAdapter)); + }); +}; + +const recallIPAdapter: MetadataRecallFunc = (ipAdapter) => { + getStore().dispatch(controlAdapterRecalled(ipAdapter)); +}; + +const recallIPAdapters: MetadataRecallFunc = (ipAdapters) => { + const { dispatch } = getStore(); + ipAdapters.forEach((ipAdapter) => { + dispatch(controlAdapterRecalled(ipAdapter)); + }); +}; + +export const recallPrompts = (_metadata: unknown) => { + // recallPositivePrompt(metadata); + // recallNegativePrompt(metadata); + // recallSDXLPositiveStylePrompt(metadata); + // recallSDXLNegativeStylePrompt(metadata); + // parameterNotSetToast(t('metadata.allPrompts')); + // parameterSetToast(t('metadata.allPrompts')); +}; + +export const recallWidthAndHeight = (_metadata: unknown) => { + // recallWidth(metadata); + // recallHeight(metadata); +}; + +export const recallAll = async (_metadata: unknown) => { + // if (!metadata) { + // allParameterNotSetToast(); + // return; + // } + // // Update the main model first, as other parameters may depend on it. + // await recallModel(metadata); + // await Promise.allSettled([ + // // Core parameters + // recallCFGScale(metadata), + // recallCFGRescaleMultiplier(metadata), + // recallPositivePrompt(metadata), + // recallNegativePrompt(metadata), + // recallScheduler(metadata), + // recallSeed(metadata), + // recallSteps(metadata), + // recallWidth(metadata), + // recallHeight(metadata), + // recallStrength(metadata), + // recallHRFEnabled(metadata), + // recallHRFMethod(metadata), + // recallHRFStrength(metadata), + // // SDXL parameters + // recallSDXLPositiveStylePrompt(metadata), + // recallSDXLNegativeStylePrompt(metadata), + // recallRefinerSteps(metadata), + // recallRefinerCFGScale(metadata), + // recallRefinerScheduler(metadata), + // recallRefinerPositiveAestheticScore(metadata), + // recallRefinerNegativeAestheticScore(metadata), + // recallRefinerStart(metadata), + // // Models + // recallVAE(metadata), + // recallRefinerModel(metadata), + // recallAllLoRAs(metadata), + // recallControlNets(metadata), + // recallT2IAdapters(metadata), + // recallIPAdapters(metadata), + // ]); + // allParameterSetToast(); +}; + +export const recallers = { + positivePrompt: recallPositivePrompt, + negativePrompt: recallNegativePrompt, + sdxlPositiveStylePrompt: recallSDXLPositiveStylePrompt, + sdxlNegativeStylePrompt: recallSDXLNegativeStylePrompt, + seed: recallSeed, + cfgScale: recallCFGScale, + cfgRescaleMultiplier: recallCFGRescaleMultiplier, + scheduler: recallScheduler, + width: recallWidth, + height: recallHeight, + steps: recallSteps, + strength: recallStrength, + hrfEnabled: recallHRFEnabled, + hrfStrength: recallHRFStrength, + hrfMethod: recallHRFMethod, + refinerSteps: recallRefinerSteps, + refinerCFGScale: recallRefinerCFGScale, + refinerScheduler: recallRefinerScheduler, + refinerPositiveAestheticScore: recallRefinerPositiveAestheticScore, + refinerNegativeAestheticScore: recallRefinerNegativeAestheticScore, + refinerStart: recallRefinerStart, + model: recallModel, + refinerModel: recallRefinerModel, + vae: recallVAE, + lora: recallLoRA, + loras: recallAllLoRAs, + controlNets: recallControlNets, + controlNet: recallControlNet, + t2iAdapters: recallT2IAdapters, + t2iAdapter: recallT2IAdapter, + ipAdapters: recallIPAdapters, + ipAdapter: recallIPAdapter, +} as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/validators.ts b/invokeai/frontend/web/src/features/metadata/util/validators.ts new file mode 100644 index 0000000000..e627afeda3 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/util/validators.ts @@ -0,0 +1,117 @@ +import { getStore } from 'app/store/nanostores/store'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import type { MetadataValidateFunc } from 'features/metadata/types'; +import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers'; +import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types'; + +/** + * Checks the given base model type against the currently-selected model's base type and throws an error if they are + * incompatible. + * @param base The base model type to validate. + * @param message An optional message to use in the error if the base model is incompatible. + */ +const validateBaseCompatibility = (base?: BaseModelType, message?: string) => { + if (!base) { + throw new InvalidModelConfigError(message || 'Missing base'); + } + const currentBase = getStore().getState().generation.model?.base; + if (currentBase && base !== currentBase) { + throw new InvalidModelConfigError(message || `Incompatible base models: ${base} and ${currentBase}`); + } +}; + +const validateRefinerModel: MetadataValidateFunc = (refinerModel) => { + validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model'); + return new Promise((resolve) => resolve(refinerModel)); +}; + +const validateVAEModel: MetadataValidateFunc = (vaeModel) => { + validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model'); + return new Promise((resolve) => resolve(vaeModel)); +}; + +const validateLoRA: MetadataValidateFunc = (lora) => { + validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model'); + return new Promise((resolve) => resolve(lora)); +}; + +const validateLoRAs: MetadataValidateFunc = (loras) => { + const validatedLoRAs: LoRA[] = []; + loras.forEach((lora) => { + try { + validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model'); + validatedLoRAs.push(lora); + } catch { + // This is a no-op - we want to continue validating the rest of the LoRAs, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedLoRAs)); +}; + +const validateControlNet: MetadataValidateFunc = (controlNet) => { + validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model'); + return new Promise((resolve) => resolve(controlNet)); +}; + +const validateControlNets: MetadataValidateFunc = (controlNets) => { + const validatedControlNets: ControlNetConfig[] = []; + controlNets.forEach((controlNet) => { + try { + validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model'); + validatedControlNets.push(controlNet); + } catch { + // This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedControlNets)); +}; + +const validateT2IAdapter: MetadataValidateFunc = (t2iAdapter) => { + validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model'); + return new Promise((resolve) => resolve(t2iAdapter)); +}; + +const validateT2IAdapters: MetadataValidateFunc = (t2iAdapters) => { + const validatedT2IAdapters: T2IAdapterConfig[] = []; + t2iAdapters.forEach((t2iAdapter) => { + try { + validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model'); + validatedT2IAdapters.push(t2iAdapter); + } catch { + // This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedT2IAdapters)); +}; + +const validateIPAdapter: MetadataValidateFunc = (ipAdapter) => { + validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model'); + return new Promise((resolve) => resolve(ipAdapter)); +}; + +const validateIPAdapters: MetadataValidateFunc = (ipAdapters) => { + const validatedIPAdapters: IPAdapterConfig[] = []; + ipAdapters.forEach((ipAdapter) => { + try { + validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model'); + validatedIPAdapters.push(ipAdapter); + } catch { + // This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid. + } + }); + return new Promise((resolve) => resolve(validatedIPAdapters)); +}; + +export const validators = { + refinerModel: validateRefinerModel, + vaeModel: validateVAEModel, + lora: validateLoRA, + loras: validateLoRAs, + controlNet: validateControlNet, + controlNets: validateControlNets, + t2iAdapter: validateT2IAdapter, + t2iAdapters: validateT2IAdapters, + ipAdapter: validateIPAdapter, + ipAdapters: validateIPAdapters, +} as const; diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index b195ce4434..501286c785 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -1,5 +1,8 @@ import { z } from 'zod'; +import type { ModelIdentifier as ModelIdentifierV2 } from './v2/common'; +import { zModelIdentifier as zModelIdentifierV2 } from './v2/common'; + // #region Field data schemas export const zImageField = z.object({ image_name: z.string().trim().min(1), @@ -69,6 +72,8 @@ export const zModelIdentifier = z.object({ }); export const isModelIdentifier = (field: unknown): field is ModelIdentifier => zModelIdentifier.safeParse(field).success; +export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 => + zModelIdentifierV2.safeParse(field).success; export const zModelFieldBase = zModelIdentifier; export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts index 493a0464b3..2789abaaca 100644 --- a/invokeai/frontend/web/src/features/nodes/types/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts @@ -1,29 +1,22 @@ import { z } from 'zod'; -import { - zControlField, - zIPAdapterField, - zMainModelField, - zModelFieldBase, - zSDXLRefinerModelField, - zT2IAdapterField, - zVAEModelField, -} from './common'; +import { zControlField, zIPAdapterField, zT2IAdapterField } from './common'; +export const zLoRAWeight = z.number().nullish(); // #region Metadata-optimized versions of schemas // TODO: It's possible that `deepPartial` will be deprecated: // - https://github.com/colinhacks/zod/issues/2106 // - https://github.com/colinhacks/zod/issues/2854 export const zLoRAMetadataItem = z.object({ - lora: zModelFieldBase.deepPartial(), - weight: z.number(), + lora: z.unknown(), + weight: zLoRAWeight, }); -const zControlNetMetadataItem = zControlField.deepPartial(); -const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); -const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); -const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); -const zModelMetadataItem = zMainModelField.deepPartial(); -const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +const zControlNetMetadataItem = zControlField.merge(z.object({ control_model: z.unknown() })).deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.merge(z.object({ ip_adapter_model: z.unknown() })).deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.merge(z.object({ t2i_adapter_model: z.unknown() })).deepPartial(); +const zSDXLRefinerModelMetadataItem = z.unknown(); +const zModelMetadataItem = z.unknown(); +const zVAEModelMetadataItem = z.unknown(); export type LoRAMetadataItem = z.infer; export type ControlNetMetadataItem = z.infer; export type IPAdapterMetadataItem = z.infer; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 0929fc1dc3..8f9a76feca 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,17 +1,25 @@ -import { useAppToaster } from 'app/components/Toaster'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { getStore } from 'app/store/nanostores/store'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { toast } from 'common/util/toast'; +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { + initialControlNet, + initialIPAdapter, + initialT2IAdapter, +} from 'features/controlAdapters/util/buildControlAdapter'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; -import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; -import { isModelIdentifier } from 'features/nodes/types/common'; -import type { - ControlNetMetadataItem, - CoreMetadata, - IPAdapterMetadataItem, - LoRAMetadataItem, - T2IAdapterMetadataItem, -} from 'features/nodes/types/metadata'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import { defaultLoRAConfig, loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { + zControlField, + zIPAdapterField, + zModelIdentifierWithBase, + zT2IAdapterField, +} from 'features/nodes/types/common'; import { initialImageSelected, modelSelected } from 'features/parameters/store/actions'; import { heightRecalled, @@ -27,19 +35,18 @@ import { vaeSelected, widthRecalled, } from 'features/parameters/store/generationSlice'; -import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { isParameterCFGRescaleMultiplier, isParameterCFGScale, isParameterHeight, isParameterHRFEnabled, isParameterHRFMethod, + isParameterLoRAWeight, isParameterNegativePrompt, isParameterNegativeStylePromptSDXL, isParameterPositivePrompt, isParameterPositiveStylePromptSDXL, isParameterScheduler, - isParameterSDXLRefinerModel, isParameterSDXLRefinerNegativeAestheticScore, isParameterSDXLRefinerPositiveAestheticScore, isParameterSDXLRefinerStart, @@ -49,13 +56,16 @@ import { isParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { - prepareControlNetMetadataItem, - prepareIPAdapterMetadataItem, - prepareLoRAMetadataItem, - prepareMainModelMetadataItem, - prepareT2IAdapterMetadataItem, - prepareVAEMetadataItem, -} from 'features/parameters/util/modelMetadataHelpers'; + fetchControlNetModel, + fetchIPAdapterModel, + fetchLoRAModel, + fetchMainModelConfig, + fetchRefinerModelConfig, + fetchT2IAdapterModel, + fetchVAEModelConfig, + getModelKey, + raiseIfBaseIncompatible, +} from 'features/parameters/util/modelFetchingHelpers'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -67,229 +77,663 @@ import { setRefinerStart, setRefinerSteps, } from 'features/sdxl/store/sdxlSlice'; -import { isNil } from 'lodash-es'; +import { t } from 'i18next'; +import { get, isArray, isNil } from 'lodash-es'; import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import type { ImageDTO } from 'services/api/types'; +import type { BaseModelType, ImageDTO } from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); +/** + * A function that recalls from metadata from the full metadata object. + */ +type MetadataRecallFunc = (metadata: unknown, withToast?: boolean) => void; + +/** + * A function that recalls metadata from a specific metadata item. + */ +type MetadataItemRecallFunc = (metadataItem: unknown, withToast?: boolean) => void; + +/** + * Raised when metadata recall fails. + */ +export class MetadataRecallError extends Error { + /** + * Create MetadataRecallError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +export class InvalidMetadataPropertyType extends MetadataRecallError {} + +const getProperty = ( + obj: unknown, + property: string, + typeGuard: (val: unknown) => val is T = (val: unknown): val is T => true +): T => { + const val = get(obj, property) as unknown; + if (typeGuard(val)) { + return val; + } + throw new InvalidMetadataPropertyType(`Property ${property} is not of expected type`); +}; + +const getCurrentBase = () => selectModel(getStore().getState())?.base; + +const parameterSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterSet', { parameter }), + description, + status: 'info', + duration: 2500, + isClosable: true, + }); +}; + +const parameterNotSetToast = (parameter: string, description?: string) => { + toast({ + title: t('toast.parameterNotSet', { parameter }), + description, + status: 'warning', + duration: 2500, + isClosable: true, + }); +}; + +const allParameterSetToast = (description?: string) => { + toast({ + title: t('toast.parametersSet'), + status: 'info', + description, + duration: 2500, + isClosable: true, + }); +}; + +const allParameterNotSetToast = (description?: string) => { + toast({ + title: t('toast.parametersNotSet'), + status: 'warning', + description, + duration: 2500, + isClosable: true, + }); +}; + +const recall = (callback: () => void, parameter: string, withToast = true) => { + try { + callback(); + withToast && parameterSetToast(parameter); + } catch (e) { + withToast && parameterNotSetToast(parameter, (e as Error).message); + } +}; + +const recallAsync = async (callback: () => Promise, parameter: string, withToast = true) => { + try { + await callback(); + withToast && parameterSetToast(parameter); + } catch (e) { + withToast && parameterNotSetToast(parameter, (e as Error).message); + } +}; + +export const recallPositivePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const positive_prompt = getProperty(metadata, 'positive_prompt', isParameterPositivePrompt); + getStore().dispatch(setPositivePrompt(positive_prompt)); + }, + t('metadata.positivePrompt'), + withToast + ); +}; + +export const recallNegativePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const negative_prompt = getProperty(metadata, 'negative_prompt', isParameterNegativePrompt); + getStore().dispatch(setNegativePrompt(negative_prompt)); + }, + t('metadata.negativePrompt'), + withToast + ); +}; + +export const recallSDXLPositiveStylePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const positive_style_prompt = getProperty(metadata, 'positive_style_prompt', isParameterPositiveStylePromptSDXL); + getStore().dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); + }, + t('sdxl.posStylePrompt'), + withToast + ); +}; + +export const recallSDXLNegativeStylePrompt: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const negative_style_prompt = getProperty(metadata, 'negative_style_prompt', isParameterNegativeStylePromptSDXL); + getStore().dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); + }, + t('sdxl.negStylePrompt'), + withToast + ); +}; + +export const recallSeed: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const seed = getProperty(metadata, 'seed', isParameterSeed); + getStore().dispatch(setSeed(seed)); + }, + t('metadata.seed'), + withToast + ); +}; + +export const recallCFGScale: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const cfg_scale = getProperty(metadata, 'cfg_scale', isParameterCFGScale); + getStore().dispatch(setCfgScale(cfg_scale)); + }, + t('metadata.cfgScale'), + withToast + ); +}; + +export const recallCFGRescaleMultiplier: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const cfg_rescale_multiplier = getProperty(metadata, 'cfg_rescale_multiplier', isParameterCFGRescaleMultiplier); + getStore().dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); + }, + t('metadata.cfgRescaleMultiplier'), + withToast + ); +}; + +export const recallScheduler: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const scheduler = getProperty(metadata, 'scheduler', isParameterScheduler); + getStore().dispatch(setScheduler(scheduler)); + }, + t('metadata.scheduler'), + withToast + ); +}; + +export const recallWidth: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const width = getProperty(metadata, 'width', isParameterWidth); + getStore().dispatch(widthRecalled(width)); + }, + t('metadata.width'), + withToast + ); +}; + +export const recallHeight: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const height = getProperty(metadata, 'height', isParameterHeight); + getStore().dispatch(heightRecalled(height)); + }, + t('metadata.height'), + withToast + ); +}; + +export const recallSteps: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const steps = getProperty(metadata, 'steps', isParameterSteps); + getStore().dispatch(setSteps(steps)); + }, + t('metadata.steps'), + withToast + ); +}; + +export const recallStrength: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const strength = getProperty(metadata, 'strength', isParameterStrength); + getStore().dispatch(setImg2imgStrength(strength)); + }, + t('metadata.strength'), + withToast + ); +}; + +export const recallHRFEnabled: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const hrf_enabled = getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled); + getStore().dispatch(setHrfEnabled(hrf_enabled)); + }, + t('hrf.metadata.enabled'), + withToast + ); +}; + +export const recallHRFStrength: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const hrf_strength = getProperty(metadata, 'hrf_strength', isParameterStrength); + getStore().dispatch(setHrfStrength(hrf_strength)); + }, + t('hrf.metadata.strength'), + withToast + ); +}; + +export const recallHRFMethod: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const hrf_method = getProperty(metadata, 'hrf_method', isParameterHRFMethod); + getStore().dispatch(setHrfMethod(hrf_method)); + }, + t('hrf.metadata.method'), + withToast + ); +}; + +export const recallRefinerSteps: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_steps = getProperty(metadata, 'refiner_steps', isParameterSteps); + getStore().dispatch(setRefinerSteps(refiner_steps)); + }, + t('sdxl.steps'), + withToast + ); +}; + +export const recallRefinerCFGScale: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_cfg_scale = getProperty(metadata, 'refiner_cfg_scale', isParameterCFGScale); + getStore().dispatch(setRefinerCFGScale(refiner_cfg_scale)); + }, + t('sdxl.cfgScale'), + withToast + ); +}; + +export const recallRefinerScheduler: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_scheduler = getProperty(metadata, 'refiner_scheduler', isParameterScheduler); + getStore().dispatch(setRefinerScheduler(refiner_scheduler)); + }, + t('sdxl.cfgScale'), + withToast + ); +}; + +export const recallRefinerPositiveAestheticScore: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_positive_aesthetic_score = getProperty( + metadata, + 'refiner_positive_aesthetic_score', + isParameterSDXLRefinerPositiveAestheticScore + ); + getStore().dispatch(setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)); + }, + t('sdxl.posAestheticScore'), + withToast + ); +}; + +export const recallRefinerNegativeAestheticScore: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_negative_aesthetic_score = getProperty( + metadata, + 'refiner_negative_aesthetic_score', + isParameterSDXLRefinerNegativeAestheticScore + ); + getStore().dispatch(setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)); + }, + t('sdxl.negAestheticScore'), + withToast + ); +}; + +export const recallRefinerStart: MetadataRecallFunc = (metadata: unknown, withToast = true) => { + recall( + () => { + const refiner_start = getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart); + getStore().dispatch(setRefinerStart(refiner_start)); + }, + t('sdxl.refinerStart'), + withToast + ); +}; + +export const prepareMainModelMetadataItem = async (model: unknown): Promise => { + const key = await getModelKey(model, 'main'); + const mainModel = await fetchMainModelConfig(key); + return zModelIdentifierWithBase.parse(mainModel); +}; + +const recallModelAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => { + await recallAsync( + async () => { + const modelMetadataItem = getProperty(metadata, 'model'); + const model = await prepareMainModelMetadataItem(modelMetadataItem); + getStore().dispatch(modelSelected(model)); + }, + t('metadata.model'), + withToast + ); +}; + +export const prepareRefinerMetadataItem = async ( + model: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const key = await getModelKey(model, 'main'); + const refinerModel = await fetchRefinerModelConfig(key); + raiseIfBaseIncompatible('sdxl-refiner', currentBase, 'Refiner incompatible with currently-selected model'); + return zModelIdentifierWithBase.parse(refinerModel); +}; + +const recallRefinerModelAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => { + await recallAsync( + async () => { + const refinerMetadataItem = getProperty(metadata, 'refiner_model'); + const refiner = await prepareRefinerMetadataItem(refinerMetadataItem, getCurrentBase()); + getStore().dispatch(refinerModelChanged(refiner)); + }, + t('sdxl.refinerModel'), + withToast + ); +}; + +export const prepareVAEMetadataItem = async ( + vae: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const key = await getModelKey(vae, 'vae'); + const vaeModel = await fetchVAEModelConfig(key); + raiseIfBaseIncompatible(vaeModel.base, currentBase, 'VAE incompatible with currently-selected model'); + return zModelIdentifierWithBase.parse(vaeModel); +}; + +const recallVAEAsync: MetadataRecallFunc = async (metadata: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const vaeMetadataItem = getProperty(metadata, 'vae'); + if (isNil(vaeMetadataItem)) { + getStore().dispatch(vaeSelected(null)); + } else { + const vae = await prepareVAEMetadataItem(vaeMetadataItem, currentBase); + getStore().dispatch(vaeSelected(vae)); + } + }, + t('metadata.vae'), + withToast + ); +}; + +export const prepareLoRAMetadataItem = async ( + loraMetadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const lora = getProperty(loraMetadataItem, 'lora'); + const weight = getProperty(loraMetadataItem, 'weight'); + const key = await getModelKey(lora, 'lora'); + const loraModel = await fetchLoRAModel(key); + raiseIfBaseIncompatible(loraModel.base, currentBase, 'LoRA incompatible with currently-selected model'); + return { + key: loraModel.key, + base: loraModel.base, + weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, + isEnabled: true, + }; +}; + +const recallLoRAAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const lora = await prepareLoRAMetadataItem(metadataItem, currentBase); + getStore().dispatch(loraRecalled(lora)); + }, + t('models.lora'), + withToast + ); +}; + +export const prepareControlNetMetadataItem = async ( + metadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const control_model = getProperty(metadataItem, 'control_model'); + const key = await getModelKey(control_model, 'controlnet'); + const controlNetModel = await fetchControlNetModel(key); + raiseIfBaseIncompatible(controlNetModel.base, currentBase, 'ControlNet incompatible with currently-selected model'); + + const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const control_weight = zControlField.shape.control_weight + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_weight')); + const begin_step_percent = zControlField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zControlField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const control_mode = zControlField.shape.control_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'control_mode')); + const resize_mode = zControlField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const controlnet: ControlNetConfig = { + type: 'controlnet', + isEnabled: true, + model: zModelIdentifierWithBase.parse(controlNetModel), + weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, + beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct, + endStepPct: end_step_percent ?? initialControlNet.endStepPct, + controlMode: control_mode ?? initialControlNet.controlMode, + resizeMode: resize_mode ?? initialControlNet.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return controlnet; +}; + +const recallControlNetAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const controlNetConfig = await prepareControlNetMetadataItem(metadataItem, currentBase); + getStore().dispatch(controlAdapterRecalled(controlNetConfig)); + }, + t('common.controlNet'), + withToast + ); +}; + +export const prepareT2IAdapterMetadataItem = async ( + metadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const t2i_adapter_model = getProperty(metadataItem, 't2i_adapter_model'); + const key = await getModelKey(t2i_adapter_model, 't2i_adapter'); + const t2iAdapterModel = await fetchT2IAdapterModel(key); + raiseIfBaseIncompatible(t2iAdapterModel.base, currentBase, 'T2I Adapter incompatible with currently-selected model'); + + const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zT2IAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zT2IAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + const resize_mode = zT2IAdapterField.shape.resize_mode + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'resize_mode')); + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const t2iAdapter: T2IAdapterConfig = { + type: 't2i_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(t2iAdapterModel), + weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, + beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, + resizeMode: resize_mode ?? initialT2IAdapter.resizeMode, + controlImage: image?.image_name ?? null, + processedControlImage: image?.image_name ?? null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return t2iAdapter; +}; + +const recallT2IAdapterAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast = true) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(metadataItem, currentBase); + getStore().dispatch(controlAdapterRecalled(t2iAdapterConfig)); + }, + t('common.t2iAdapter'), + withToast + ); +}; + +export const prepareIPAdapterMetadataItem = async ( + metadataItem: unknown, + currentBase: BaseModelType | undefined +): Promise => { + const ip_adapter_model = getProperty(metadataItem, 'ip_adapter_model'); + const key = await getModelKey(ip_adapter_model, 'ip_adapter'); + const ipAdapterModel = await fetchIPAdapterModel(key); + raiseIfBaseIncompatible(ipAdapterModel.base, currentBase, 'T2I Adapter incompatible with currently-selected model'); + + const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image')); + const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight')); + const begin_step_percent = zIPAdapterField.shape.begin_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'begin_step_percent')); + const end_step_percent = zIPAdapterField.shape.end_step_percent + .nullish() + .catch(null) + .parse(getProperty(metadataItem, 'end_step_percent')); + + const ipAdapter: IPAdapterConfig = { + id: uuidv4(), + type: 'ip_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(ipAdapterModel), + controlImage: image?.image_name ?? null, + weight: weight ?? initialIPAdapter.weight, + beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, + endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, + }; + + return ipAdapter; +}; + +const recallIPAdapterAsync: MetadataItemRecallFunc = async (metadataItem: unknown, withToast) => { + await recallAsync( + async () => { + const currentBase = getCurrentBase(); + const ipAdapterConfig = await prepareIPAdapterMetadataItem(metadataItem, currentBase); + getStore().dispatch(controlAdapterRecalled(ipAdapterConfig)); + }, + t('common.ipAdapter'), + withToast + ); +}; + export const useRecallParameters = () => { const dispatch = useAppDispatch(); - const toaster = useAppToaster(); - const { t } = useTranslation(); - const model = useAppSelector(selectModel); - const parameterSetToast = useCallback(() => { - toaster({ - title: t('toast.parameterSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); - - const parameterNotSetToast = useCallback( - (description?: string) => { - toaster({ - title: t('toast.parameterNotSet'), - description, - status: 'warning', - duration: 2500, - isClosable: true, - }); - }, - [t, toaster] - ); - - const allParameterSetToast = useCallback(() => { - toaster({ - title: t('toast.parametersSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); - - const allParameterNotSetToast = useCallback( - (description?: string) => { - toaster({ - title: t('toast.parametersNotSet'), - status: 'warning', - description, - duration: 2500, - isClosable: true, - }); - }, - [t, toaster] - ); - - const recallBothPrompts = useCallback( - (positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => { + const recallBothPrompts = useCallback( + (metadata: unknown) => { + const positive_prompt = getProperty(metadata, 'positive_prompt'); + const negative_prompt = getProperty(metadata, 'negative_prompt'); + const positive_style_prompt = getProperty(metadata, 'positive_style_prompt'); + const negative_style_prompt = getProperty(metadata, 'negative_style_prompt'); if ( - isParameterPositivePrompt(positivePrompt) || - isParameterNegativePrompt(negativePrompt) || - isParameterPositiveStylePromptSDXL(positiveStylePrompt) || - isParameterNegativeStylePromptSDXL(negativeStylePrompt) + isParameterPositivePrompt(positive_prompt) || + isParameterNegativePrompt(negative_prompt) || + isParameterPositiveStylePromptSDXL(positive_style_prompt) || + isParameterNegativeStylePromptSDXL(negative_style_prompt) ) { - if (isParameterPositivePrompt(positivePrompt)) { - dispatch(setPositivePrompt(positivePrompt)); + if (isParameterPositivePrompt(positive_prompt)) { + dispatch(setPositivePrompt(positive_prompt)); } - if (isParameterNegativePrompt(negativePrompt)) { - dispatch(setNegativePrompt(negativePrompt)); + if (isParameterNegativePrompt(negative_prompt)) { + dispatch(setNegativePrompt(negative_prompt)); } - if (isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { - dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); + if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) { + dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); } - if (isParameterPositiveStylePromptSDXL(negativeStylePrompt)) { - dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); + if (isParameterPositiveStylePromptSDXL(negative_style_prompt)) { + dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); } - parameterSetToast(); + parameterSetToast(t('metadata.allPrompts')); return; } - parameterNotSetToast(); + parameterNotSetToast(t('metadata.allPrompts')); }, - [dispatch, parameterSetToast, parameterNotSetToast] + [dispatch] ); - const recallPositivePrompt = useCallback( - (positivePrompt: unknown) => { - if (!isParameterPositivePrompt(positivePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setPositivePrompt(positivePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); + const recallWidthAndHeight = useCallback( + (metadata: unknown) => { + const width = getProperty(metadata, 'width'); + const height = getProperty(metadata, 'height'); - const recallNegativePrompt = useCallback( - (negativePrompt: unknown) => { - if (!isParameterNegativePrompt(negativePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setNegativePrompt(negativePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSDXLPositiveStylePrompt = useCallback( - (positiveStylePrompt: unknown) => { - if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSDXLNegativeStylePrompt = useCallback( - (negativeStylePrompt: unknown) => { - if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { - parameterNotSetToast(); - return; - } - dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSeed = useCallback( - (seed: unknown) => { - if (!isParameterSeed(seed)) { - parameterNotSetToast(); - return; - } - dispatch(setSeed(seed)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallCfgScale = useCallback( - (cfgScale: unknown) => { - if (!isParameterCFGScale(cfgScale)) { - parameterNotSetToast(); - return; - } - dispatch(setCfgScale(cfgScale)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallCfgRescaleMultiplier = useCallback( - (cfgRescaleMultiplier: unknown) => { - if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) { - parameterNotSetToast(); - return; - } - dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallScheduler = useCallback( - (scheduler: unknown) => { - if (!isParameterScheduler(scheduler)) { - parameterNotSetToast(); - return; - } - dispatch(setScheduler(scheduler)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallSteps = useCallback( - (steps: unknown) => { - if (!isParameterSteps(steps)) { - parameterNotSetToast(); - return; - } - dispatch(setSteps(steps)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallWidth = useCallback( - (width: unknown) => { - if (!isParameterWidth(width)) { - parameterNotSetToast(); - return; - } - dispatch(widthRecalled(width)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHeight = useCallback( - (height: unknown) => { - if (!isParameterHeight(height)) { - parameterNotSetToast(); - return; - } - dispatch(heightRecalled(height)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallWidthAndHeight = useCallback( - (width: unknown, height: unknown) => { if (!isParameterWidth(width)) { allParameterNotSetToast(); return; @@ -302,144 +746,7 @@ export const useRecallParameters = () => { dispatch(widthRecalled(width)); allParameterSetToast(); }, - [dispatch, allParameterSetToast, allParameterNotSetToast] - ); - - const recallStrength = useCallback( - (strength: unknown) => { - if (!isParameterStrength(strength)) { - parameterNotSetToast(); - return; - } - dispatch(setImg2imgStrength(strength)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHrfEnabled = useCallback( - (hrfEnabled: unknown) => { - if (!isParameterHRFEnabled(hrfEnabled)) { - parameterNotSetToast(); - return; - } - dispatch(setHrfEnabled(hrfEnabled)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHrfStrength = useCallback( - (hrfStrength: unknown) => { - if (!isParameterStrength(hrfStrength)) { - parameterNotSetToast(); - return; - } - dispatch(setHrfStrength(hrfStrength)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallHrfMethod = useCallback( - (hrfMethod: unknown) => { - if (!isParameterHRFMethod(hrfMethod)) { - parameterNotSetToast(); - return; - } - dispatch(setHrfMethod(hrfMethod)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallModel = useCallback( - async (modelMetadataItem: unknown) => { - try { - const model = await prepareMainModelMetadataItem(modelMetadataItem); - dispatch(modelSelected(model)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallVaeModel = useCallback( - async (vaeMetadataItem: unknown) => { - if (isNil(vaeMetadataItem)) { - dispatch(vaeSelected(null)); - parameterSetToast(); - return; - } - try { - const vae = await prepareVAEMetadataItem(vaeMetadataItem); - dispatch(vaeSelected(vae)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallLoRA = useCallback( - async (loraMetadataItem: LoRAMetadataItem) => { - try { - const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base); - dispatch(loraRecalled(lora)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallControlNet = useCallback( - async (controlnetMetadataItem: ControlNetMetadataItem) => { - try { - const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base); - dispatch(controlAdapterRecalled(controlNetConfig)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallT2IAdapter = useCallback( - async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { - try { - const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base); - dispatch(controlAdapterRecalled(t2iAdapterConfig)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const recallIPAdapter = useCallback( - async (ipAdapterMetadataItem: IPAdapterMetadataItem) => { - try { - const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base); - dispatch(controlAdapterRecalled(ipAdapterConfig)); - parameterSetToast(); - } catch (e) { - parameterNotSetToast((e as unknown as Error).message); - return; - } - }, - [model?.base, dispatch, parameterSetToast, parameterNotSetToast] + [dispatch] ); const sendToImageToImage = useCallback( @@ -449,197 +756,75 @@ export const useRecallParameters = () => { [dispatch] ); - const recallAllParameters = useCallback( - async (metadata: CoreMetadata | undefined) => { + const recallAllParameters = useCallback( + async (metadata: unknown) => { if (!metadata) { allParameterNotSetToast(); return; } - const { - cfg_scale, - cfg_rescale_multiplier, - height, - model, - positive_prompt, - negative_prompt, - scheduler, - vae, - seed, - steps, - width, - strength, - hrf_enabled, - hrf_strength, - hrf_method, - positive_style_prompt, - negative_style_prompt, - refiner_model, - refiner_cfg_scale, - refiner_steps, - refiner_scheduler, - refiner_positive_aesthetic_score, - refiner_negative_aesthetic_score, - refiner_start, - loras, - controlnets, - ipAdapters, - t2iAdapters, - } = metadata; + await recallModelAsync(metadata, false); - let newModel: ParameterModel | undefined = undefined; + recallCFGScale(metadata, false); + recallCFGRescaleMultiplier(metadata, false); + recallPositivePrompt(metadata, false); + recallNegativePrompt(metadata, false); + recallScheduler(metadata, false); + recallSeed(metadata, false); + recallSteps(metadata, false); + recallWidth(metadata, false); + recallHeight(metadata, false); + recallStrength(metadata, false); + recallHRFEnabled(metadata, false); + recallHRFMethod(metadata, false); + recallHRFStrength(metadata, false); - if (isModelIdentifier(model)) { - try { - const _model = await prepareMainModelMetadataItem(model); - dispatch(modelSelected(_model)); - newModel = _model; - } catch { - return; - } - } + // SDXL + recallSDXLPositiveStylePrompt(metadata, false); + recallSDXLNegativeStylePrompt(metadata, false); + recallRefinerSteps(metadata, false); + recallRefinerCFGScale(metadata, false); + recallRefinerScheduler(metadata, false); + recallRefinerPositiveAestheticScore(metadata, false); + recallRefinerNegativeAestheticScore(metadata, false); + recallRefinerStart(metadata, false); - if (isParameterCFGScale(cfg_scale)) { - dispatch(setCfgScale(cfg_scale)); - } - - if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) { - dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); - } - - if (isParameterPositivePrompt(positive_prompt)) { - dispatch(setPositivePrompt(positive_prompt)); - } - - if (isParameterNegativePrompt(negative_prompt)) { - dispatch(setNegativePrompt(negative_prompt)); - } - - if (isParameterScheduler(scheduler)) { - dispatch(setScheduler(scheduler)); - } - if (isModelIdentifier(vae) || isNil(vae)) { - if (isNil(vae)) { - dispatch(vaeSelected(null)); - } else { - try { - const _vae = await prepareVAEMetadataItem(vae, newModel?.base); - dispatch(vaeSelected(_vae)); - } catch { - return; - } - } - } - - if (isParameterSeed(seed)) { - dispatch(setSeed(seed)); - } - - if (isParameterSteps(steps)) { - dispatch(setSteps(steps)); - } - - if (isParameterWidth(width)) { - dispatch(widthRecalled(width)); - } - - if (isParameterHeight(height)) { - dispatch(heightRecalled(height)); - } - - if (isParameterStrength(strength)) { - dispatch(setImg2imgStrength(strength)); - } - - if (isParameterHRFEnabled(hrf_enabled)) { - dispatch(setHrfEnabled(hrf_enabled)); - } - - if (isParameterStrength(hrf_strength)) { - dispatch(setHrfStrength(hrf_strength)); - } - - if (isParameterHRFMethod(hrf_method)) { - dispatch(setHrfMethod(hrf_method)); - } - - if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) { - dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); - } - - if (isParameterNegativeStylePromptSDXL(negative_style_prompt)) { - dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); - } - - if (isParameterSDXLRefinerModel(refiner_model)) { - dispatch(refinerModelChanged(refiner_model)); - } - - if (isParameterSteps(refiner_steps)) { - dispatch(setRefinerSteps(refiner_steps)); - } - - if (isParameterCFGScale(refiner_cfg_scale)) { - dispatch(setRefinerCFGScale(refiner_cfg_scale)); - } - - if (isParameterScheduler(refiner_scheduler)) { - dispatch(setRefinerScheduler(refiner_scheduler)); - } - - if (isParameterSDXLRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)) { - dispatch(setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)); - } - - if (isParameterSDXLRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)) { - dispatch(setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)); - } - - if (isParameterSDXLRefinerStart(refiner_start)) { - dispatch(setRefinerStart(refiner_start)); - } + await recallVAEAsync(metadata, false); + await recallRefinerModelAsync(metadata, false); dispatch(lorasCleared()); - loras?.forEach(async (loraMetadataItem) => { - try { - const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base); - dispatch(loraRecalled(lora)); - } catch { - return; - } - }); + const loraMetadataArray = getProperty(metadata, 'loras'); + if (isArray(loraMetadataArray)) { + loraMetadataArray.forEach(async (loraMetadataItem) => { + await recallLoRAAsync(loraMetadataItem, false); + }); + } dispatch(controlAdaptersReset()); - controlnets?.forEach(async (controlNetMetadataItem) => { - try { - const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base); - dispatch(controlAdapterRecalled(controlNet)); - } catch { - return; - } - }); + const controlNetMetadataArray = getProperty(metadata, 'controlnets'); + if (isArray(controlNetMetadataArray)) { + controlNetMetadataArray.forEach(async (controlNetMetadataItem) => { + await recallControlNetAsync(controlNetMetadataItem, false); + }); + } - ipAdapters?.forEach(async (ipAdapterMetadataItem) => { - try { - const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base); - dispatch(controlAdapterRecalled(ipAdapter)); - } catch { - return; - } - }); + const ipAdapterMetadataArray = getProperty(metadata, 'ipAdapters'); + if (isArray(ipAdapterMetadataArray)) { + ipAdapterMetadataArray.forEach(async (ipAdapterMetadataItem) => { + await recallIPAdapterAsync(ipAdapterMetadataItem, false); + }); + } - t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => { - try { - const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base); - dispatch(controlAdapterRecalled(t2iAdapter)); - } catch { - return; - } - }); + const t2iAdapterMetadataArray = getProperty(metadata, 't2iAdapters'); + if (isArray(t2iAdapterMetadataArray)) { + t2iAdapterMetadataArray.forEach(async (t2iAdapterMetadataItem) => { + await recallT2IAdapterAsync(t2iAdapterMetadataItem, false); + }); + } allParameterSetToast(); }, - [dispatch, allParameterSetToast, allParameterNotSetToast] + [dispatch] ); return { @@ -649,24 +834,25 @@ export const useRecallParameters = () => { recallSDXLPositiveStylePrompt, recallSDXLNegativeStylePrompt, recallSeed, - recallCfgScale, - recallCfgRescaleMultiplier, - recallModel, + recallCFGScale, + recallCFGRescaleMultiplier, + recallModel: recallModelAsync, recallScheduler, - recallVaeModel, + recallVaeModel: recallVAEAsync, recallSteps, recallWidth, recallHeight, recallWidthAndHeight, recallStrength, - recallHrfEnabled, - recallHrfStrength, - recallHrfMethod, - recallLoRA, - recallControlNet, - recallIPAdapter, - recallT2IAdapter, + recallHRFEnabled, + recallHRFStrength, + recallHRFMethod, + recallLoRA: recallLoRAAsync, + recallControlNet: recallControlNetAsync, + recallIPAdapter: recallIPAdapterAsync, + recallT2IAdapter: recallT2IAdapterAsync, recallAllParameters, + recallRefinerModel: recallRefinerModelAsync, sendToImageToImage, }; }; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 20953ec266..8d46add8c8 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,14 +1,7 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { - zBaseModel, - zControlNetModelField, - zIPAdapterModelField, - zLoRAModelField, zModelIdentifierWithBase, zSchedulerField, - zSDXLRefinerModelField, - zT2IAdapterModelField, - zVAEModelField, } from 'features/nodes/types/common'; import { z } from 'zod'; @@ -111,42 +104,42 @@ export const isParameterModel = (val: unknown): val is ParameterModel => zParame // #endregion // #region SDXL Refiner Model -export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel }); +export const zParameterSDXLRefinerModel = zModelIdentifierWithBase; export type ParameterSDXLRefinerModel = z.infer; export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel => zParameterSDXLRefinerModel.safeParse(val).success; // #endregion // #region VAE Model -export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel }); +export const zParameterVAEModel = zModelIdentifierWithBase; export type ParameterVAEModel = z.infer; export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => zParameterVAEModel.safeParse(val).success; // #endregion // #region LoRA Model -export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel }); +export const zParameterLoRAModel = zModelIdentifierWithBase; export type ParameterLoRAModel = z.infer; export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => zParameterLoRAModel.safeParse(val).success; // #endregion // #region ControlNet Model -export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel }); +export const zParameterControlNetModel = zModelIdentifierWithBase; export type ParameterControlNetModel = z.infer; export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel => zParameterControlNetModel.safeParse(val).success; // #endregion // #region IP Adapter Model -export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel }); +export const zParameterIPAdapterModel = zModelIdentifierWithBase; export type ParameterIPAdapterModel = z.infer; export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel => zParameterIPAdapterModel.safeParse(val).success; // #endregion // #region T2I Adapter Model -export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel }); +export const zParameterT2IAdapterModel = zModelIdentifierWithBase; export type ParameterT2IAdapterModel = z.infer; export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel => zParameterT2IAdapterModel.safeParse(val).success; @@ -218,3 +211,9 @@ export type ParameterCanvasCoherenceMode = z.infer zParameterCanvasCoherenceMode.safeParse(val).success; // #endregion + +// #region LoRA weight +export const zLoRAWeight = z.number(); +export type ParameterLoRAWeight = z.infer; +export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts index c7d25fed8b..bd3e42a314 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts @@ -1,7 +1,8 @@ import { getStore } from 'app/store/nanostores/store'; -import { isModelIdentifier } from 'features/nodes/types/common'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { modelsApi } from 'services/api/endpoints/models'; -import type { AnyModelConfig, BaseModelType } from 'services/api/types'; +import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; import { isControlNetModelConfig, isIPAdapterModelConfig, @@ -41,6 +42,12 @@ export class InvalidModelConfigError extends Error { } } +/** + * Fetches the model config for a given model key. + * @param key The model key. + * @returns A promise that resolves to the model config. + * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched. + */ export const fetchModelConfig = async (key: string): Promise => { const { dispatch } = getStore(); try { @@ -52,6 +59,37 @@ export const fetchModelConfig = async (key: string): Promise => } }; +/** + * Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility + * for MM1 model identifiers. + * @param name The model name. + * @param base The base model. + * @param type The model type. + * @returns A promise that resolves to the model config. + * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched. + */ +export const fetchModelConfigByAttrs = async ( + name: string, + base: BaseModelType, + type: ModelType +): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type })); + req.unsubscribe(); + return await req.unwrap(); + } catch { + throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`); + } +}; + +/** + * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type. + * @param key The model key. + * @param typeGuard A type guard function that checks if the model config is of the expected type. + * @returns A promise that resolves to the model config. The model config is guaranteed to be of the expected type. + * @throws {InvalidModelConfigError} If the model config is unable to be fetched or is of an unexpected type. + */ export const fetchModelConfigWithTypeGuard = async ( key: string, typeGuard: (config: AnyModelConfig) => config is T @@ -63,15 +101,17 @@ export const fetchModelConfigWithTypeGuard = async ( return modelConfig; }; -export const fetchMainModel = async (key: string) => { +// TODO(psyche): Remove these helpers once `useRecallParameters` is removed + +export const fetchMainModelConfig = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); }; -export const fetchRefinerModel = async (key: string) => { +export const fetchRefinerModelConfig = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); }; -export const fetchVAEModel = async (key: string) => { +export const fetchVAEModelConfig = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isVAEModelConfig); }; @@ -95,19 +135,39 @@ export const fetchTextualInversionModel = async (key: string) => { return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig); }; -export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => { - return sourceBase === targetBase; -}; - +/** + * Raises an error if the source base model is incompatible with the target base model. + * @param sourceBase The source base model. + * @param targetBase The target base model. + * @param message An optional custom message to include in the error. + * @throws {InvalidModelConfigError} If the source base model is incompatible with the target base model. + */ export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => { - if (targetBase && !isBaseCompatible(sourceBase, targetBase)) { + if (targetBase && sourceBase !== targetBase) { throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`); } }; -export const getModelKey = (modelIdentifier: unknown, message?: string): string => { - if (!isModelIdentifier(modelIdentifier)) { - throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); +/** + * Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers. + * @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format + * `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key. + * @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers. + * @param message An optional custom message to include in the error if the model identifier is invalid. + * @returns A promise that resolves to the model key. + * @throws {InvalidModelConfigError} If the model identifier is invalid. + */ +export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise => { + if (isModelIdentifier(modelIdentifier)) { + return modelIdentifier.key; } - return modelIdentifier.key; + if (isModelIdentifierV2(modelIdentifier)) { + return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key; + } + throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); }; + +export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({ + key: modelConfig.key, + base: modelConfig.base, +}); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts deleted file mode 100644 index 722073366f..0000000000 --- a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts +++ /dev/null @@ -1,150 +0,0 @@ -import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; -import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; -import { - initialControlNet, - initialIPAdapter, - initialT2IAdapter, -} from 'features/controlAdapters/util/buildControlAdapter'; -import type { LoRA } from 'features/lora/store/loraSlice'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; -import { zModelIdentifierWithBase } from 'features/nodes/types/common'; -import type { - ControlNetMetadataItem, - IPAdapterMetadataItem, - LoRAMetadataItem, - T2IAdapterMetadataItem, -} from 'features/nodes/types/metadata'; -import { - fetchControlNetModel, - fetchIPAdapterModel, - fetchLoRAModel, - fetchMainModel, - fetchRefinerModel, - fetchT2IAdapterModel, - fetchVAEModel, - getModelKey, - raiseIfBaseIncompatible, -} from 'features/parameters/util/modelFetchingHelpers'; -import type { BaseModelType } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; - -export const prepareMainModelMetadataItem = async (model: unknown): Promise => { - const key = getModelKey(model); - const mainModel = await fetchMainModel(key); - return zModelIdentifierWithBase.parse(mainModel); -}; - -export const prepareRefinerMetadataItem = async (model: unknown): Promise => { - const key = getModelKey(model); - const refinerModel = await fetchRefinerModel(key); - return zModelIdentifierWithBase.parse(refinerModel); -}; - -export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise => { - const key = getModelKey(vae); - const vaeModel = await fetchVAEModel(key); - raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model'); - return zModelIdentifierWithBase.parse(vaeModel); -}; - -export const prepareLoRAMetadataItem = async ( - loraMetadataItem: LoRAMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(loraMetadataItem.lora); - const loraModel = await fetchLoRAModel(key); - raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model'); - return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true }; -}; - -export const prepareControlNetMetadataItem = async ( - controlnetMetadataItem: ControlNetMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(controlnetMetadataItem.control_model); - const controlNetModel = await fetchControlNetModel(key); - raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model'); - - const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } = - controlnetMetadataItem; - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const controlnet: ControlNetConfig = { - type: 'controlnet', - isEnabled: true, - model: zModelIdentifierWithBase.parse(controlNetModel), - weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, - beginStepPct: begin_step_percent || initialControlNet.beginStepPct, - endStepPct: end_step_percent || initialControlNet.endStepPct, - controlMode: control_mode || initialControlNet.controlMode, - resizeMode: resize_mode || initialControlNet.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return controlnet; -}; - -export const prepareT2IAdapterMetadataItem = async ( - t2iAdapterMetadataItem: T2IAdapterMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model); - const t2iAdapterModel = await fetchT2IAdapterModel(key); - raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); - - const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem; - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const t2iAdapter: T2IAdapterConfig = { - type: 't2i_adapter', - isEnabled: true, - model: zModelIdentifierWithBase.parse(t2iAdapterModel), - weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, - beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct, - endStepPct: end_step_percent || initialT2IAdapter.endStepPct, - resizeMode: resize_mode || initialT2IAdapter.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return t2iAdapter; -}; - -export const prepareIPAdapterMetadataItem = async ( - ipAdapterMetadataItem: IPAdapterMetadataItem, - base?: BaseModelType -): Promise => { - const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model); - const ipAdapterModel = await fetchIPAdapterModel(key); - raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); - - const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; - - const ipAdapter: IPAdapterConfig = { - id: uuidv4(), - type: 'ip_adapter', - isEnabled: true, - controlImage: image?.image_name ?? null, - model: zModelIdentifierWithBase.parse(ipAdapterModel), - weight: weight ?? initialIPAdapter.weight, - beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, - endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, - }; - - return ipAdapter; -}; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index 8b10d9bddd..05ca73927c 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -25,7 +25,7 @@ const formLabelProps2: FormLabelProps = { export const AdvancedSettingsAccordion = memo(() => { const vaeKey = useAppSelector((state) => state.generation.vae?.key); - const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken); + const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken); const selectBadges = useMemo( () => createMemoizedSelector(selectGenerationSlice, (generation) => { diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 49eb28390f..1849e2fde4 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -1,10 +1,8 @@ import type { EntityState, Update } from '@reduxjs/toolkit'; import type { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks'; -import { logger } from 'app/logging/logger'; +import type { JSONObject } from 'common/types'; import type { BoardId } from 'features/gallery/store/types'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types'; -import type { CoreMetadata } from 'features/nodes/types/metadata'; -import { zCoreMetadata } from 'features/nodes/types/metadata'; import { addToast } from 'features/system/store/systemSlice'; import { t } from 'i18next'; import { keyBy } from 'lodash-es'; @@ -118,22 +116,9 @@ export const imagesApi = api.injectEndpoints({ providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }], keepUnusedDataFor: 86400, // 24 hours }), - getImageMetadata: build.query({ + getImageMetadata: build.query({ query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }), providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }], - transformResponse: ( - response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json'] - ) => { - if (response) { - const result = zCoreMetadata.safeParse(response); - if (result.success) { - return result.data; - } else { - logger('images').warn('Problem parsing metadata'); - } - } - return; - }, keepUnusedDataFor: 86400, // 24 hours }), getImageWorkflow: build.query< diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 6618fda58e..3ccccf62e1 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -70,6 +70,8 @@ export type ScanFolderResponse = paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; type ScanFolderArg = operations['scan_for_models']['parameters']['query']; +export type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query']; + export const mainModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), @@ -223,6 +225,18 @@ export const modelsApi = api.injectEndpoints({ return tags; }, }), + getModelConfigByAttrs: build.query({ + query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`), + providesTags: (result) => { + const tags: ApiTagDescription[] = ['Model']; + + if (result) { + tags.push({ type: 'ModelConfig', id: result.key }); + } + + return tags; + }, + }), syncModels: build.mutation({ query: () => { return { @@ -300,6 +314,7 @@ export const modelsApi = api.injectEndpoints({ }); export const { + useGetModelConfigByAttrsQuery, useGetModelConfigQuery, useGetMainModelsQuery, useGetControlNetModelsQuery,