From 4bbe6f3548edeb46643d833c78e5ce5b34a00468 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 7 Mar 2024 11:40:07 -0500 Subject: [PATCH] fix image processors to work with new cnet/t2i model format --- .../middleware/listenerMiddleware/index.ts | 2 + ...ntrolAdapterAutoProcessorUpdateListener.ts | 77 ++++++++++++++++ .../listeners/controlNetAutoProcess.ts | 3 +- .../parameters/ParamControlAdapterModel.tsx | 3 +- .../hooks/useAddControlAdapter.ts | 35 ++++++- .../features/controlAdapters/store/actions.ts | 10 ++ .../store/controlAdaptersSlice.ts | 91 ++----------------- 7 files changed, 132 insertions(+), 89 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterAutoProcessorUpdateListener.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index cd0c1290e9..48115199e4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import type { AppDispatch, RootState } from 'app/store/store'; +import { addControlAdapterAutoProcessorUpdateListener } from './listeners/controlAdapterAutoProcessorUpdateListener'; import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings'; export const listenerMiddleware = createListenerMiddleware(); @@ -125,6 +126,7 @@ addBulkDownloadListeners(startAppListening); // ControlNet addControlNetImageProcessedListener(startAppListening); addControlNetAutoProcessListener(startAppListening); +addControlAdapterAutoProcessorUpdateListener(startAppListening); // Boards addImageAddedToBoardFulfilledListener(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterAutoProcessorUpdateListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterAutoProcessorUpdateListener.ts new file mode 100644 index 0000000000..3c8d87e4f7 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterAutoProcessorUpdateListener.ts @@ -0,0 +1,77 @@ +import { isAnyOf } from '@reduxjs/toolkit'; +import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { controlAdapterModelChanged } from 'features/controlAdapters/store/actions'; +import { CONTROLNET_MODEL_DEFAULT_PROCESSORS } from 'features/controlAdapters/store/constants'; +import { + controlAdapterAutoConfigToggled, + controlAdapterProcessortTypeChanged, + selectControlAdapterById, +} from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlAdapterProcessorType } from 'features/controlAdapters/store/types'; +import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; +import { modelsApi } from 'services/api/endpoints/models'; + +export const addControlAdapterAutoProcessorUpdateListener = (startAppListening: AppStartListening) => { + startAppListening({ + matcher: isAnyOf(controlAdapterModelChanged, controlAdapterAutoConfigToggled), + effect: async (action, { getState, dispatch }) => { + let id; + let model; + + const state = getState(); + const { controlAdapters: controlAdaptersState } = state; + + if (controlAdapterModelChanged.match(action)) { + model = action.payload.model; + id = action.payload.id; + + const cn = selectControlAdapterById(controlAdaptersState, id); + if (!cn) { + return; + } + + if (!isControlNetOrT2IAdapter(cn)) { + return; + } + } + + if (controlAdapterAutoConfigToggled.match(action)) { + id = action.payload.id; + + const cn = selectControlAdapterById(controlAdaptersState, id); + if (!cn) { + return; + } + + if (!isControlNetOrT2IAdapter(cn)) { + return; + } + + // if they turned off autoconfig, return + if (!cn.shouldAutoConfig) { + return; + } + + model = cn.model; + } + + if (!model || !id) { + return; + } + + let processorType: ControlAdapterProcessorType | undefined = undefined; + const { data: modelConfig } = modelsApi.endpoints.getModelConfig.select(model.key)(state); + + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (modelConfig?.name.includes(modelSubstring)) { + processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; + break; + } + } + + dispatch( + controlAdapterProcessortTypeChanged({ id, processorType: processorType || 'none', shouldAutoConfig: true }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts index e52df30681..88cb72dd86 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts @@ -2,11 +2,10 @@ import type { AnyListenerPredicate } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { RootState } from 'app/store/store'; -import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions'; +import { controlAdapterImageProcessed, controlAdapterModelChanged } from 'features/controlAdapters/store/actions'; import { controlAdapterAutoConfigToggled, controlAdapterImageChanged, - controlAdapterModelChanged, controlAdapterProcessorParamsChanged, controlAdapterProcessortTypeChanged, selectControlAdapterById, diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index d7cf2e8452..b0ebaf1985 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -5,7 +5,7 @@ import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useCo import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; -import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; +import { controlAdapterModelChanged } from 'features/controlAdapters/store/actions'; import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; import { memo, useCallback, useMemo } from 'react'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; @@ -25,6 +25,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const _onChange = useCallback( (model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { + console.log('on change'); if (!model) { return; } diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts index 7fd1088767..743a79e281 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts @@ -1,6 +1,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { CONTROLNET_MODEL_DEFAULT_PROCESSORS, CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; -import type { ControlAdapterType } from 'features/controlAdapters/store/types'; +import type { + ControlAdapterProcessorType, + ControlAdapterType, + RequiredControlAdapterProcessorNode, +} from 'features/controlAdapters/store/types'; +import { cloneDeep } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useControlAdapterModels } from './useControlAdapterModels'; @@ -22,6 +28,29 @@ export const useAddControlAdapter = (type: ControlAdapterType) => { return models[0]; }, [baseModel, models]); + const processor = useMemo(() => { + let processorType; + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (firstModel?.name.includes(modelSubstring)) { + processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; + break; + } + } + + if (!processorType) { + processorType = 'none'; + } + + const processorNode = + processorType === 'none' + ? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode) + : (cloneDeep( + CONTROLNET_PROCESSORS[processorType as ControlAdapterProcessorType].default + ) as RequiredControlAdapterProcessorNode); + + return { processorType, processorNode }; + }, [firstModel]); + const isDisabled = useMemo(() => !firstModel, [firstModel]); const addControlAdapter = useCallback(() => { @@ -31,10 +60,10 @@ export const useAddControlAdapter = (type: ControlAdapterType) => { dispatch( controlAdapterAdded({ type, - overrides: { model: firstModel }, + overrides: { model: firstModel, ...processor }, }) ); - }, [dispatch, firstModel, isDisabled, type]); + }, [dispatch, firstModel, isDisabled, type, processor]); return [addControlAdapter, isDisabled] as const; }; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/actions.ts b/invokeai/frontend/web/src/features/controlAdapters/store/actions.ts index 99ea84ed13..979a980572 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/actions.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/actions.ts @@ -1,5 +1,15 @@ import { createAction } from '@reduxjs/toolkit'; +import type { + ParameterControlNetModel, + ParameterIPAdapterModel, + ParameterT2IAdapterModel, +} from 'features/parameters/types/parameterSchemas'; export const controlAdapterImageProcessed = createAction<{ id: string; }>('controlAdapters/imageProcessed'); + +export const controlAdapterModelChanged = createAction<{ + id: string; + model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel; +}>('controlAdapters/controlAdapterModelChanged'); diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index a20e287011..edf3beba32 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -3,20 +3,12 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import type { PersistConfig, RootState } from 'app/store/store'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; -import type { - ParameterControlNetModel, - ParameterIPAdapterModel, - ParameterT2IAdapterModel, -} from 'features/parameters/types/parameterSchemas'; import { cloneDeep, merge, uniq } from 'lodash-es'; import { socketInvocationError } from 'services/events/actions'; import { v4 as uuidv4 } from 'uuid'; import { controlAdapterImageProcessed } from './actions'; -import { - CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS, - CONTROLNET_PROCESSORS, -} from './constants'; +import { CONTROLNET_PROCESSORS } from './constants'; import type { ControlAdapterConfig, ControlAdapterProcessorType, @@ -190,52 +182,6 @@ export const controlAdaptersSlice = createSlice({ changes: { model: null }, }); }, - controlAdapterModelChanged: ( - state, - action: PayloadAction<{ - id: string; - model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel; - }> - ) => { - const { id, model } = action.payload; - const cn = selectControlAdapterById(state, id); - if (!cn) { - return; - } - - if (!isControlNetOrT2IAdapter(cn)) { - caAdapter.updateOne(state, { id, changes: { model } }); - return; - } - - const update: Update = { - id, - changes: { model, shouldAutoConfig: true }, - }; - - update.changes.processedControlImage = null; - - let processorType: ControlAdapterProcessorType | undefined = undefined; - - for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType - if (model.key.includes(modelSubstring)) { - processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; - break; - } - } - - if (processorType) { - update.changes.processorType = processorType; - update.changes.processorNode = CONTROLNET_PROCESSORS[processorType] - .default as RequiredControlAdapterProcessorNode; - } else { - update.changes.processorType = 'none'; - update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode; - } - - caAdapter.updateOne(state, update); - }, controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { const { id, weight } = action.payload; caAdapter.updateOne(state, { id, changes: { weight } }); @@ -298,17 +244,19 @@ export const controlAdaptersSlice = createSlice({ action: PayloadAction<{ id: string; processorType: ControlAdapterProcessorType; + shouldAutoConfig?: boolean; }> ) => { - const { id, processorType } = action.payload; + const { id, processorType, shouldAutoConfig = false } = action.payload; const cn = selectControlAdapterById(state, id); if (!cn || !isControlNetOrT2IAdapter(cn)) { return; } - const processorNode = cloneDeep( - CONTROLNET_PROCESSORS[processorType].default - ) as RequiredControlAdapterProcessorNode; + const processorNode = + processorType === 'none' + ? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode) + : (cloneDeep(CONTROLNET_PROCESSORS[processorType].default) as RequiredControlAdapterProcessorNode); caAdapter.updateOne(state, { id, @@ -316,7 +264,7 @@ export const controlAdaptersSlice = createSlice({ processorType, processedControlImage: null, processorNode, - shouldAutoConfig: false, + shouldAutoConfig, }, }); }, @@ -337,28 +285,6 @@ export const controlAdaptersSlice = createSlice({ changes: { shouldAutoConfig: !cn.shouldAutoConfig }, }; - if (update.changes.shouldAutoConfig) { - // manage the processor for the user - let processorType: ControlAdapterProcessorType | undefined = undefined; - - for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType - if (cn.model?.key.includes(modelSubstring)) { - processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; - break; - } - } - - if (processorType) { - update.changes.processorType = processorType; - update.changes.processorNode = CONTROLNET_PROCESSORS[processorType] - .default as RequiredControlAdapterProcessorNode; - } else { - update.changes.processorType = 'none'; - update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode; - } - } - caAdapter.updateOne(state, update); }, controlAdaptersReset: () => { @@ -393,7 +319,6 @@ export const { controlAdapterImageChanged, controlAdapterProcessedImageChanged, controlAdapterIsEnabledChanged, - controlAdapterModelChanged, controlAdapterWeightChanged, controlAdapterBeginStepPctChanged, controlAdapterEndStepPctChanged,