diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index ae3bdd7112..f0745eae2b 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -98,6 +98,9 @@ export const controlNetSlice = createSlice({ isControlNetEnabledToggled: (state) => { state.isEnabled = !state.isEnabled; }, + controlNetEnabled: (state) => { + state.isEnabled = true; + }, controlNetAdded: ( state, action: PayloadAction<{ @@ -111,6 +114,12 @@ export const controlNetSlice = createSlice({ controlNetId, }; }, + controlNetRecalled: (state, action: PayloadAction) => { + const controlNet = action.payload; + state.controlNets[controlNet.controlNetId] = { + ...controlNet, + }; + }, controlNetDuplicated: ( state, action: PayloadAction<{ @@ -439,7 +448,9 @@ export const controlNetSlice = createSlice({ export const { isControlNetEnabledToggled, + controlNetEnabled, controlNetAdded, + controlNetRecalled, controlNetDuplicated, controlNetAddedFromImage, controlNetRemoved, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 955e8a5a3a..25d8e1e5ac 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,8 +1,15 @@ -import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; +import { + ControlNetMetadataItem, + CoreMetadata, + LoRAMetadataItem, +} from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { memo, useCallback } from 'react'; +import { memo, useMemo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; +import { + isValidControlNetModel, + isValidLoRAModel, +} from '../../../parameters/types/parameterSchemas'; import ImageMetadataItem from './ImageMetadataItem'; type Props = { @@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => { recallHeight, recallStrength, recallLoRA, + recallControlNet, } = useRecallParameters(); const handleRecallPositivePrompt = useCallback(() => { @@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => { [recallLoRA] ); + const handleRecallControlNet = useCallback( + (controlnet: ControlNetMetadataItem) => { + recallControlNet(controlnet); + }, + [recallControlNet] + ); + + const validControlNets: ControlNetMetadataItem[] = useMemo(() => { + return metadata?.controlnets + ? metadata.controlnets.filter((controlnet) => + isValidControlNetModel(controlnet.control_model) + ) + : []; + }, [metadata?.controlnets]); + if (!metadata || Object.keys(metadata).length === 0) { return null; } @@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => { ); } })} + {validControlNets.map((controlnet, index) => ( + handleRecallControlNet(controlnet)} + /> + ))} ); }; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 0033e462cb..eb8baf513e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({ export type LoRAMetadataItem = z.infer; +const zControlNetMetadataItem = zControlField.deepPartial(); + +export type ControlNetMetadataItem = z.infer; + export const zCoreMetadata = z .object({ app_version: z.string().nullish().catch(null), diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 4fb9a0ce2c..d8561ab122 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppToaster } from 'app/components/Toaster'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; +import { + CoreMetadata, + LoRAMetadataItem, + ControlNetMetadataItem, +} from 'features/nodes/types/types'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -18,9 +22,18 @@ import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { ImageDTO } from 'services/api/types'; import { + controlNetModelsAdapter, loraModelsAdapter, + useGetControlNetModelsQuery, useGetLoRAModelsQuery, } from '../../../services/api/endpoints/models'; +import { + ControlNetConfig, + controlNetEnabled, + controlNetRecalled, + controlNetReset, + initialControlNet, +} from '../../controlNet/store/controlNetSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { initialImageSelected, modelSelected } from '../store/actions'; import { @@ -38,6 +51,7 @@ import { isValidCfgScale, isValidHeight, isValidLoRAModel, + isValidControlNetModel, isValidMainModel, isValidNegativePrompt, isValidPositivePrompt, @@ -53,6 +67,11 @@ import { isValidStrength, isValidWidth, } from '../types/parameterSchemas'; +import { v4 as uuidv4 } from 'uuid'; +import { + CONTROLNET_PROCESSORS, + CONTROLNET_MODEL_DEFAULT_PROCESSORS, +} from 'features/controlNet/store/constants'; const selector = createSelector(stateSelector, ({ generation }) => { const { model } = generation; @@ -390,6 +409,121 @@ export const useRecallParameters = () => { [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] ); + /** + * Recall ControlNet with toast + */ + + const { controlnets } = useGetControlNetModelsQuery(undefined, { + selectFromResult: (result) => ({ + controlnets: result.data + ? controlNetModelsAdapter.getSelectors().selectAll(result.data) + : [], + }), + }); + + const prepareControlNetMetadataItem = useCallback( + (controlnetMetadataItem: ControlNetMetadataItem) => { + if (!isValidControlNetModel(controlnetMetadataItem.control_model)) { + return { controlnet: null, error: 'Invalid ControlNet model' }; + } + + const { + image, + control_model, + control_weight, + begin_step_percent, + end_step_percent, + control_mode, + resize_mode, + } = controlnetMetadataItem; + + const matchingControlNetModel = controlnets.find( + (c) => + c.base_model === control_model.base_model && + c.model_name === control_model.model_name + ); + + if (!matchingControlNetModel) { + return { controlnet: null, error: 'ControlNet model is not installed' }; + } + + const isCompatibleBaseModel = + matchingControlNetModel?.base_model === model?.base_model; + + if (!isCompatibleBaseModel) { + return { + controlnet: null, + error: 'ControlNet incompatible with currently-selected model', + }; + } + + const controlNetId = uuidv4(); + + let processorType = initialControlNet.processorType; + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (matchingControlNetModel.model_name.includes(modelSubstring)) { + processorType = + CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] || + initialControlNet.processorType; + break; + } + } + const processorNode = CONTROLNET_PROCESSORS[processorType].default; + + const controlnet: ControlNetConfig = { + isEnabled: true, + model: matchingControlNetModel, + weight: + typeof control_weight === 'number' + ? control_weight + : initialControlNet.weight, + beginStepPct: begin_step_percent || initialControlNet.beginStepPct, + endStepPct: end_step_percent || initialControlNet.endStepPct, + controlMode: control_mode || initialControlNet.controlMode, + resizeMode: resize_mode || initialControlNet.resizeMode, + controlImage: image?.image_name || null, + processedControlImage: image?.image_name || null, + processorType, + processorNode: + processorNode.type !== 'none' + ? processorNode + : initialControlNet.processorNode, + shouldAutoConfig: true, + controlNetId, + }; + + return { controlnet, error: null }; + }, + [controlnets, model?.base_model] + ); + + const recallControlNet = useCallback( + (controlnetMetadataItem: ControlNetMetadataItem) => { + const result = prepareControlNetMetadataItem(controlnetMetadataItem); + + if (!result.controlnet) { + parameterNotSetToast(result.error); + return; + } + + dispatch( + controlNetRecalled({ + ...result.controlnet, + }) + ); + + dispatch(controlNetEnabled()); + + parameterSetToast(); + }, + [ + prepareControlNetMetadataItem, + dispatch, + parameterSetToast, + parameterNotSetToast, + ] + ); + /* * Sets image as initial image with toast */ @@ -428,6 +562,7 @@ export const useRecallParameters = () => { refiner_negative_aesthetic_score, refiner_start, loras, + controlnets, } = metadata; if (isValidCfgScale(cfg_scale)) { @@ -517,6 +652,15 @@ export const useRecallParameters = () => { } }); + dispatch(controlNetReset()); + dispatch(controlNetEnabled()); + controlnets?.forEach((controlnet) => { + const result = prepareControlNetMetadataItem(controlnet); + if (result.controlnet) { + dispatch(controlNetRecalled(result.controlnet)); + } + }); + allParameterSetToast(); }, [ @@ -524,6 +668,7 @@ export const useRecallParameters = () => { allParameterSetToast, dispatch, prepareLoRAMetadataItem, + prepareControlNetMetadataItem, ] ); @@ -542,6 +687,7 @@ export const useRecallParameters = () => { recallHeight, recallStrength, recallLoRA, + recallControlNet, recallAllParameters, sendToImageToImage, };