From 76dc47e88df9c817346b1a4409e2847eddef3dd3 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Tue, 11 Jul 2023 16:18:38 -0400 Subject: [PATCH] remove frontend constants, use backend response for controlnet models. add disabled state if base model is not compatible. clear control net model if main base model changes. add logic to guess processor and move it up in UI --- .../frontend/web/src/app/types/invokeai.ts | 4 +- .../controlNet/components/ControlNet.tsx | 19 ++-- .../parameters/ParamControlNetModel.tsx | 78 ++++++++------ .../features/controlNet/store/constants.ts | 100 +++--------------- .../controlNet/store/controlNetSlice.ts | 33 ++++-- 5 files changed, 97 insertions(+), 137 deletions(-) diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 40b8c1c73a..be642a6435 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,5 +1,5 @@ import { - CONTROLNET_MODELS, + // CONTROLNET_MODELS, CONTROLNET_PROCESSORS, } from 'features/controlNet/store/constants'; import { InvokeTabName } from 'features/ui/store/tabMap'; @@ -128,7 +128,7 @@ export type AppConfig = { canRestoreDeletedImagesFromBin: boolean; sd: { defaultModel?: string; - disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[]; + disabledControlNetModels: string[]; disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[]; iterations: { initial: number; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index bb01416e1d..e25c320cd6 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -124,6 +124,7 @@ const ControlNet = (props: ControlNetProps) => { /> } /> + {!shouldAutoConfig && ( { /> )} + + + + {isEnabled && ( <> @@ -196,18 +207,10 @@ const ControlNet = (props: ControlNetProps) => { height={96} /> - - )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index ddf266ccfc..eda3cde5d2 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -1,55 +1,71 @@ -import { createSelector } from '@reduxjs/toolkit'; +import { SelectItem } from '@mantine/core'; +import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIMantineSearchableSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSearchableSelect'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -import { - CONTROLNET_MODELS, - ControlNetModelName, -} from 'features/controlNet/store/constants'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; -import { configSelector } from 'features/system/store/configSelectors'; -import { map } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; +import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; type ParamControlNetModelProps = { controlNetId: string; - model: ControlNetModelName; + model: string; }; -const selector = createSelector(configSelector, (config) => { - const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({ - label: m.label, - value: m.type, - })).filter( - (d) => - !config.sd.disabledControlNetModels.includes( - d.value as ControlNetModelName - ) - ); - - return controlNetModels; -}); - const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { controlNetId, model } = props; - const controlNetModels = useAppSelector(selector); const dispatch = useAppDispatch(); const isReady = useIsReadyToInvoke(); + const currentMainModel = useAppSelector( + (state: RootState) => state.generation.model + ); + + const { data: controlNetModels } = useGetControlNetModelsQuery(); + const handleModelChanged = useCallback( (val: string | null) => { - // TODO: do not cast - const model = val as ControlNetModelName; - dispatch(controlNetModelChanged({ controlNetId, model })); + if (!val) return; + dispatch(controlNetModelChanged({ controlNetId, model: val })); }, [controlNetId, dispatch] ); + const data = useMemo(() => { + if (!controlNetModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(controlNetModels.entities, (model, id) => { + if (!model) { + return; + } + + const disabled = currentMainModel?.base_model !== model.base_model; + + data.push({ + value: id, + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], + disabled, + tooltip: disabled + ? `Incompatible base model: ${model.base_model}` + : undefined, + }); + }); + + return data; + }, [controlNetModels, currentMainModel?.base_model]); + return ( ; - -type ControlNetModel = { - type: string; - label: string; - description?: string; - defaultProcessor?: ControlNetProcessorType; +export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: { + [key: string]: ControlNetProcessorType; +} = { + canny: 'canny_image_processor', + mlsd: 'mlsd_image_processor', + depth: 'midas_depth_image_processor', + bae: 'normalbae_image_processor', + lineart: 'lineart_image_processor', + lineart_anime: 'lineart_anime_image_processor', + softedge: 'hed_image_processor', + shuffle: 'content_shuffle_image_processor', + openpose: 'openpose_image_processor', + mediapipe: 'mediapipe_face_processor', }; - -export const CONTROLNET_MODELS: ControlNetModelsDict = { - 'lllyasviel/control_v11p_sd15_canny': { - type: 'lllyasviel/control_v11p_sd15_canny', - label: 'Canny', - defaultProcessor: 'canny_image_processor', - }, - 'lllyasviel/control_v11p_sd15_inpaint': { - type: 'lllyasviel/control_v11p_sd15_inpaint', - label: 'Inpaint', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11p_sd15_mlsd': { - type: 'lllyasviel/control_v11p_sd15_mlsd', - label: 'M-LSD', - defaultProcessor: 'mlsd_image_processor', - }, - 'lllyasviel/control_v11f1p_sd15_depth': { - type: 'lllyasviel/control_v11f1p_sd15_depth', - label: 'Depth', - defaultProcessor: 'midas_depth_image_processor', - }, - 'lllyasviel/control_v11p_sd15_normalbae': { - type: 'lllyasviel/control_v11p_sd15_normalbae', - label: 'Normal Map (BAE)', - defaultProcessor: 'normalbae_image_processor', - }, - 'lllyasviel/control_v11p_sd15_seg': { - type: 'lllyasviel/control_v11p_sd15_seg', - label: 'Segmentation', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11p_sd15_lineart': { - type: 'lllyasviel/control_v11p_sd15_lineart', - label: 'Lineart', - defaultProcessor: 'lineart_image_processor', - }, - 'lllyasviel/control_v11p_sd15s2_lineart_anime': { - type: 'lllyasviel/control_v11p_sd15s2_lineart_anime', - label: 'Lineart Anime', - defaultProcessor: 'lineart_anime_image_processor', - }, - 'lllyasviel/control_v11p_sd15_scribble': { - type: 'lllyasviel/control_v11p_sd15_scribble', - label: 'Scribble', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11p_sd15_softedge': { - type: 'lllyasviel/control_v11p_sd15_softedge', - label: 'Soft Edge', - defaultProcessor: 'hed_image_processor', - }, - 'lllyasviel/control_v11e_sd15_shuffle': { - type: 'lllyasviel/control_v11e_sd15_shuffle', - label: 'Content Shuffle', - defaultProcessor: 'content_shuffle_image_processor', - }, - 'lllyasviel/control_v11p_sd15_openpose': { - type: 'lllyasviel/control_v11p_sd15_openpose', - label: 'Openpose', - defaultProcessor: 'openpose_image_processor', - }, - 'lllyasviel/control_v11f1e_sd15_tile': { - type: 'lllyasviel/control_v11f1e_sd15_tile', - label: 'Tile (experimental)', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11e_sd15_ip2p': { - type: 'lllyasviel/control_v11e_sd15_ip2p', - label: 'Pix2Pix (experimental)', - defaultProcessor: 'none', - }, - 'CrucibleAI/ControlNetMediaPipeFace': { - type: 'CrucibleAI/ControlNetMediaPipeFace', - label: 'Mediapipe Face', - defaultProcessor: 'mediapipe_face_processor', - }, -}; - -export type ControlNetModelName = keyof typeof CONTROLNET_MODELS; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index d1c69566e9..39a321b282 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -8,9 +8,10 @@ import { RequiredControlNetProcessorNode, } from './types'; import { - CONTROLNET_MODELS, + CONTROLNET_MODEL_DEFAULT_PROCESSORS, + // CONTROLNET_MODELS, CONTROLNET_PROCESSORS, - ControlNetModelName, + // ControlNetModelName, } from './constants'; import { controlNetImageProcessed } from './actions'; import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image'; @@ -26,7 +27,7 @@ export type ControlModes = export const initialControlNet: Omit = { isEnabled: true, - model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, + model: '', weight: 1, beginStepPct: 0, endStepPct: 1, @@ -42,7 +43,7 @@ export const initialControlNet: Omit = { export type ControlNetConfig = { controlNetId: string; isEnabled: boolean; - model: ControlNetModelName; + model: string; weight: number; beginStepPct: number; endStepPct: number; @@ -147,7 +148,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - model: ControlNetModelName; + model: string; }> ) => { const { controlNetId, model } = action.payload; @@ -155,7 +156,15 @@ export const controlNetSlice = createSlice({ state.controlNets[controlNetId].processedControlImage = null; if (state.controlNets[controlNetId].shouldAutoConfig) { - const processorType = CONTROLNET_MODELS[model].defaultProcessor; + let processorType: ControlNetProcessorType | undefined = undefined; + + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (model.includes(modelSubstring)) { + processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; + break; + } + } + if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ @@ -241,9 +250,15 @@ export const controlNetSlice = createSlice({ if (newShouldAutoConfig) { // manage the processor for the user - const processorType = - CONTROLNET_MODELS[state.controlNets[controlNetId].model] - .defaultProcessor; + let processorType: ControlNetProcessorType | undefined = undefined; + + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (state.controlNets[controlNetId].model.includes(modelSubstring)) { + processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; + break; + } + } + if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[