diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index 89d0488b35..cedfebac6e 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -6,20 +6,18 @@ import { controlAdapterDuplicated, controlAdapterIsEnabledChanged, controlAdapterRemoved, - selectControlAdapterById, } from '../store/controlAdaptersSlice'; import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import { ChevronUpIcon } from '@chakra-ui/icons'; -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIIconButton from 'common/components/IAIIconButton'; import IAISwitch from 'common/components/IAISwitch'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { useTranslation } from 'react-i18next'; -import { isControlNetOrT2IAdapter } from '../store/types'; +import { useToggle } from 'react-use'; +import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled'; +import { useControlAdapterType } from '../hooks/useControlAdapterType'; import ControlNetImagePreview from './ControlNetImagePreview'; import ControlNetProcessorComponent from './ControlNetProcessorComponent'; import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; @@ -28,8 +26,6 @@ import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode'; -import { useToggle } from 'react-use'; -import { useControlAdapterType } from '../hooks/useControlAdapterType'; const ControlNet = (props: { id: string }) => { const { id } = props; @@ -38,33 +34,7 @@ const ControlNet = (props: { id: string }) => { const { t } = useTranslation(); const activeTabName = useAppSelector(activeTabNameSelector); - - const selector = createSelector( - stateSelector, - ({ controlAdapters }) => { - const cn = selectControlAdapterById(controlAdapters, id); - - if (!cn) { - return { - isEnabled: false, - shouldAutoConfig: false, - }; - } - - const isEnabled = cn.isEnabled; - const shouldAutoConfig = isControlNetOrT2IAdapter(cn) - ? cn.shouldAutoConfig - : false; - - return { - isEnabled, - shouldAutoConfig, - }; - }, - defaultSelectorOptions - ); - - const { isEnabled, shouldAutoConfig } = useAppSelector(selector); + const isEnabled = useControlAdapterIsEnabled(id); const [isExpanded, toggleIsExpanded] = useToggle(false); const handleDelete = useCallback(() => { @@ -116,8 +86,6 @@ const ControlNet = (props: { id: string }) => { sx={{ w: 'full', minW: 0, - // opacity: isEnabled ? 1 : 0.5, - // pointerEvents: isEnabled ? 'auto' : 'none', transitionProperty: 'common', transitionDuration: '0.1s', }} @@ -176,23 +144,6 @@ const ControlNet = (props: { id: string }) => { /> } /> - - {!shouldAutoConfig && ( - - )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx b/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx index 46d665a843..14f496abd0 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx @@ -5,6 +5,7 @@ import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled'; import { useControlAdapterShouldAutoConfig } from '../hooks/useControlAdapterShouldAutoConfig'; +import { isNil } from 'lodash-es'; type Props = { id: string; @@ -20,6 +21,10 @@ const ParamControlNetShouldAutoConfig = ({ id }: Props) => { dispatch(controlAdapterAutoConfigToggled({ id })); }, [id, dispatch]); + if (isNil(shouldAutoConfig)) { + return null; + } + return ( { [dispatch, id] ); - console.log(model, selectedModel); - return ( (ca.processorType === 'none' && Boolean(ca.controlImage))) ); +const disableAllIPAdapters = ( + state: ControlAdaptersState, + exclude?: string +) => { + const updates: Update[] = selectAllIPAdapters(state) + .filter((ca) => ca.id !== exclude) + .map((ca) => ({ + id: ca.id, + changes: { isEnabled: false }, + })); + caAdapter.updateMany(state, updates); +}; + +const disableAllControlNets = ( + state: ControlAdaptersState, + exclude?: string +) => { + const updates: Update[] = selectAllControlNets(state) + .filter((ca) => ca.id !== exclude) + .map((ca) => ({ + id: ca.id, + changes: { isEnabled: false }, + })); + caAdapter.updateMany(state, updates); +}; + +const disableAllT2IAdapters = ( + state: ControlAdaptersState, + exclude?: string +) => { + const updates: Update[] = selectAllT2IAdapters(state) + .filter((ca) => ca.id !== exclude) + .map((ca) => ({ + id: ca.id, + changes: { isEnabled: false }, + })); + caAdapter.updateMany(state, updates); +}; + +const disableIncompatibleControlAdapters = ( + state: ControlAdaptersState, + type: ControlAdapterType, + exclude?: string +) => { + if (type === 'ip_adapter') { + // we can only have a single active IP Adapter, if we are enabling this one, disable others + disableAllIPAdapters(state, exclude); + } + if (type === 'controlnet') { + // we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is + disableAllT2IAdapters(state, exclude); + } + if (type === 't2i_adapter') { + // we cannot do controlnet + t2i adapter, if we are enabled a t2i, disable controlnets + disableAllControlNets(state, exclude); + } +}; + export const controlAdaptersSlice = createSlice({ name: 'controlAdapters', initialState: initialControlAdapterState, @@ -102,6 +160,7 @@ export const controlAdaptersSlice = createSlice({ ) => { const { id, type, overrides } = action.payload; caAdapter.addOne(state, buildControlAdapter(id, type, overrides)); + disableIncompatibleControlAdapters(state, type, id); }, prepare: ({ type, @@ -117,8 +176,9 @@ export const controlAdaptersSlice = createSlice({ state, action: PayloadAction ) => { - const config = action.payload; - caAdapter.addOne(state, config); + caAdapter.addOne(state, action.payload); + const { type, id } = action.payload; + disableIncompatibleControlAdapters(state, type, id); }, controlAdapterDuplicated: { reducer: ( @@ -137,6 +197,8 @@ export const controlAdaptersSlice = createSlice({ id: newId, }); caAdapter.addOne(state, newControlAdapter); + const { type } = newControlAdapter; + disableIncompatibleControlAdapters(state, type, newId); }, prepare: (id: string) => { return { payload: { id, newId: uuidv4() } }; @@ -156,6 +218,7 @@ export const controlAdaptersSlice = createSlice({ state, buildControlAdapter(id, type, { controlImage }) ); + disableIncompatibleControlAdapters(state, type, id); }, prepare: (payload: { type: ControlAdapterType; @@ -173,6 +236,12 @@ export const controlAdaptersSlice = createSlice({ ) => { const { id, isEnabled } = action.payload; caAdapter.updateOne(state, { id, changes: { isEnabled } }); + if (isEnabled) { + // we are enabling a control adapter. due to limitations in the current system, we may need to disable other adapters + // TODO: disable when multiple IP adapters are supported + const ca = selectControlAdapterById(state, id); + ca && disableIncompatibleControlAdapters(state, ca.type, id); + } }, controlAdapterImageChanged: ( state, @@ -182,8 +251,8 @@ export const controlAdaptersSlice = createSlice({ }> ) => { const { id, controlImage } = action.payload; - const cn = selectControlAdapterById(state, id); - if (!cn) { + const ca = selectControlAdapterById(state, id); + if (!ca) { return; } @@ -194,8 +263,8 @@ export const controlAdaptersSlice = createSlice({ if ( controlImage !== null && - isControlNetOrT2IAdapter(cn) && - cn.processorType !== 'none' + isControlNetOrT2IAdapter(ca) && + ca.processorType !== 'none' ) { state.pendingControlImages.push(id); } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index b0019b0fee..949f8c7708 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -13,6 +13,7 @@ import { selectAllControlNets, selectAllIPAdapters, selectAllT2IAdapters, + selectControlAdapterIds, } from 'features/controlNet/store/controlAdaptersSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { Fragment, memo } from 'react'; @@ -23,27 +24,23 @@ const selector = createSelector( ({ controlAdapters }) => { const activeLabel: string[] = []; - const ipAdapters = selectAllIPAdapters(controlAdapters); - const ipAdapterCount = ipAdapters.length; + const ipAdapterCount = selectAllIPAdapters(controlAdapters).length; if (ipAdapterCount > 0) { activeLabel.push(`${ipAdapterCount} IP`); } - const controlNets = selectAllControlNets(controlAdapters); - const controlNetCount = controlNets.length; + const controlNetCount = selectAllControlNets(controlAdapters).length; if (controlNetCount > 0) { activeLabel.push(`${controlNetCount} ControlNet`); } - const t2iAdapters = selectAllT2IAdapters(controlAdapters); - const t2iAdapterCount = t2iAdapters.length; + const t2iAdapterCount = selectAllT2IAdapters(controlAdapters).length; if (t2iAdapterCount > 0) { activeLabel.push(`${t2iAdapterCount} T2I`); } - const controlAdapterIds = [ipAdapters, controlNets, t2iAdapters] - .flat() - .map((ca) => ca.id); + const controlAdapterIds = + selectControlAdapterIds(controlAdapters).map(String); return { controlAdapterIds, @@ -72,6 +69,7 @@ const ParamControlNetCollapse = () => { leftIcon={} onClick={addControlNet} data-testid="add controlnet" + flexGrow={1} > ControlNet @@ -79,6 +77,7 @@ const ParamControlNetCollapse = () => { leftIcon={} onClick={addIPAdapter} data-testid="add ip adapter" + flexGrow={1} > IP Adapter @@ -86,10 +85,12 @@ const ParamControlNetCollapse = () => { leftIcon={} onClick={addT2IAdapter} data-testid="add t2i adapter" + flexGrow={1} > T2I Adapter + {controlAdapterIds.map((id, i) => ( {i > 0 && }