From c1666a8b5a785c87c7a047f1385b7d4bddc9059b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 30 Apr 2024 10:37:45 +1000 Subject: [PATCH] fix(ui): select default control/ip adapter models in control layers --- .../regionalControlToControlAdapterBridge.ts | 67 ++++++++++++++++--- 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge.ts index 2d80902550..6f53159d4b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge.ts @@ -1,6 +1,9 @@ import { createAction } from '@reduxjs/toolkit'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlNetConfig, IPAdapterConfig } from 'features/controlAdapters/store/types'; +import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types'; import { controlAdapterLayerAdded, ipAdapterLayerAdded, @@ -10,6 +13,8 @@ import { maskLayerIPAdapterDeleted, } from 'features/regionalPrompts/store/regionalPromptsSlice'; import type { Layer } from 'features/regionalPrompts/store/types'; +import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; +import { isControlNetModelConfig, isIPAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; @@ -24,19 +29,52 @@ export const guidanceLayerIPAdapterDeleted = createAction<{ layerId: string; ipA export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: guidanceLayerAdded, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const type = action.payload; const layerId = uuidv4(); + if (type === 'masked_guidance_layer') { + dispatch(maskedGuidanceLayerAdded({ layerId })); + return; + } + + const state = getState(); + const baseModel = state.generation.model?.base; + const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data; + if (type === 'ip_adapter_layer') { const ipAdapterId = uuidv4(); - dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } })); + const overrides: Partial = { + id: ipAdapterId, + }; + + // Find and select the first matching model + if (modelConfigs) { + const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig); + overrides.model = models.find((m) => m.base === baseModel) ?? null; + } + dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides })); dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId })); - } else if (type === 'control_adapter_layer') { + return; + } + + if (type === 'control_adapter_layer') { const controlNetId = uuidv4(); - dispatch(controlAdapterAdded({ type: 'controlnet', overrides: { id: controlNetId } })); + const overrides: Partial = { + id: controlNetId, + }; + + // Find and select the first matching model + if (modelConfigs) { + const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isControlNetModelConfig); + const model = models.find((m) => m.base === baseModel) ?? null; + overrides.model = model; + const defaultPreprocessor = model?.default_settings?.preprocessor; + overrides.processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none'; + overrides.processorNode = CONTROLNET_PROCESSORS[overrides.processorType].buildDefaults(baseModel); + } + dispatch(controlAdapterAdded({ type: 'controlnet', overrides })); dispatch(controlAdapterLayerAdded({ layerId, controlNetId })); - } else if (type === 'masked_guidance_layer') { - dispatch(maskedGuidanceLayerAdded({ layerId })); + return; } }, }); @@ -74,10 +112,23 @@ export const addRegionalControlToControlAdapterBridge = (startAppListening: AppS startAppListening({ actionCreator: guidanceLayerIPAdapterAdded, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const layerId = action.payload; const ipAdapterId = uuidv4(); - dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } })); + const overrides: Partial = { + id: ipAdapterId, + }; + + // Find and select the first matching model + const state = getState(); + const baseModel = state.generation.model?.base; + const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data; + if (modelConfigs) { + const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig); + overrides.model = models.find((m) => m.base === baseModel) ?? null; + } + + dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides })); dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId })); }, });