feat(ui): add limits to enabled control adapters

- only 1 ip adapter at a time
- controlnet and t2i cannot both be active at once
This commit is contained in:
psychedelicious 2023-10-06 18:13:14 +11:00
parent dcfbd49e1b
commit ba4616ff89
5 changed files with 94 additions and 70 deletions

View File

@ -6,20 +6,18 @@ import {
controlAdapterDuplicated,
controlAdapterIsEnabledChanged,
controlAdapterRemoved,
selectControlAdapterById,
} from '../store/controlAdaptersSlice';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next';
import { isControlNetOrT2IAdapter } from '../store/types';
import { useToggle } from 'react-use';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterType } from '../hooks/useControlAdapterType';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
@ -28,8 +26,6 @@ import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
import { useToggle } from 'react-use';
import { useControlAdapterType } from '../hooks/useControlAdapterType';
const ControlNet = (props: { id: string }) => {
const { id } = props;
@ -38,33 +34,7 @@ const ControlNet = (props: { id: string }) => {
const { t } = useTranslation();
const activeTabName = useAppSelector(activeTabNameSelector);
const selector = createSelector(
stateSelector,
({ controlAdapters }) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (!cn) {
return {
isEnabled: false,
shouldAutoConfig: false,
};
}
const isEnabled = cn.isEnabled;
const shouldAutoConfig = isControlNetOrT2IAdapter(cn)
? cn.shouldAutoConfig
: false;
return {
isEnabled,
shouldAutoConfig,
};
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isEnabled = useControlAdapterIsEnabled(id);
const [isExpanded, toggleIsExpanded] = useToggle(false);
const handleDelete = useCallback(() => {
@ -116,8 +86,6 @@ const ControlNet = (props: { id: string }) => {
sx={{
w: 'full',
minW: 0,
// opacity: isEnabled ? 1 : 0.5,
// pointerEvents: isEnabled ? 'auto' : 'none',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}
@ -176,23 +144,6 @@ const ControlNet = (props: { id: string }) => {
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
position: 'absolute',
w: 1.5,
h: 1.5,
borderRadius: 'full',
top: 4,
insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}}
/>
)}
</Flex>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>

View File

@ -5,6 +5,7 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterShouldAutoConfig } from '../hooks/useControlAdapterShouldAutoConfig';
import { isNil } from 'lodash-es';
type Props = {
id: string;
@ -20,6 +21,10 @@ const ParamControlNetShouldAutoConfig = ({ id }: Props) => {
dispatch(controlAdapterAutoConfigToggled({ id }));
}, [id, dispatch]);
if (isNil(shouldAutoConfig)) {
return null;
}
return (
<IAISwitch
label={t('controlnet.autoConfigure')}

View File

@ -106,8 +106,6 @@ const ParamControlNetModel = ({ id }: ParamControlNetModelProps) => {
[dispatch, id]
);
console.log(model, selectedModel);
return (
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}

View File

@ -87,6 +87,64 @@ export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
const disableAllIPAdapters = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableAllControlNets = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllControlNets(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableAllT2IAdapters = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllT2IAdapters(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableIncompatibleControlAdapters = (
state: ControlAdaptersState,
type: ControlAdapterType,
exclude?: string
) => {
if (type === 'ip_adapter') {
// we can only have a single active IP Adapter, if we are enabling this one, disable others
disableAllIPAdapters(state, exclude);
}
if (type === 'controlnet') {
// we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is
disableAllT2IAdapters(state, exclude);
}
if (type === 't2i_adapter') {
// we cannot do controlnet + t2i adapter, if we are enabled a t2i, disable controlnets
disableAllControlNets(state, exclude);
}
};
export const controlAdaptersSlice = createSlice({
name: 'controlAdapters',
initialState: initialControlAdapterState,
@ -102,6 +160,7 @@ export const controlAdaptersSlice = createSlice({
) => {
const { id, type, overrides } = action.payload;
caAdapter.addOne(state, buildControlAdapter(id, type, overrides));
disableIncompatibleControlAdapters(state, type, id);
},
prepare: ({
type,
@ -117,8 +176,9 @@ export const controlAdaptersSlice = createSlice({
state,
action: PayloadAction<ControlAdapterConfig>
) => {
const config = action.payload;
caAdapter.addOne(state, config);
caAdapter.addOne(state, action.payload);
const { type, id } = action.payload;
disableIncompatibleControlAdapters(state, type, id);
},
controlAdapterDuplicated: {
reducer: (
@ -137,6 +197,8 @@ export const controlAdaptersSlice = createSlice({
id: newId,
});
caAdapter.addOne(state, newControlAdapter);
const { type } = newControlAdapter;
disableIncompatibleControlAdapters(state, type, newId);
},
prepare: (id: string) => {
return { payload: { id, newId: uuidv4() } };
@ -156,6 +218,7 @@ export const controlAdaptersSlice = createSlice({
state,
buildControlAdapter(id, type, { controlImage })
);
disableIncompatibleControlAdapters(state, type, id);
},
prepare: (payload: {
type: ControlAdapterType;
@ -173,6 +236,12 @@ export const controlAdaptersSlice = createSlice({
) => {
const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } });
if (isEnabled) {
// we are enabling a control adapter. due to limitations in the current system, we may need to disable other adapters
// TODO: disable when multiple IP adapters are supported
const ca = selectControlAdapterById(state, id);
ca && disableIncompatibleControlAdapters(state, ca.type, id);
}
},
controlAdapterImageChanged: (
state,
@ -182,8 +251,8 @@ export const controlAdaptersSlice = createSlice({
}>
) => {
const { id, controlImage } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
const ca = selectControlAdapterById(state, id);
if (!ca) {
return;
}
@ -194,8 +263,8 @@ export const controlAdaptersSlice = createSlice({
if (
controlImage !== null &&
isControlNetOrT2IAdapter(cn) &&
cn.processorType !== 'none'
isControlNetOrT2IAdapter(ca) &&
ca.processorType !== 'none'
) {
state.pendingControlImages.push(id);
}

View File

@ -13,6 +13,7 @@ import {
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
selectControlAdapterIds,
} from 'features/controlNet/store/controlAdaptersSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react';
@ -23,27 +24,23 @@ const selector = createSelector(
({ controlAdapters }) => {
const activeLabel: string[] = [];
const ipAdapters = selectAllIPAdapters(controlAdapters);
const ipAdapterCount = ipAdapters.length;
const ipAdapterCount = selectAllIPAdapters(controlAdapters).length;
if (ipAdapterCount > 0) {
activeLabel.push(`${ipAdapterCount} IP`);
}
const controlNets = selectAllControlNets(controlAdapters);
const controlNetCount = controlNets.length;
const controlNetCount = selectAllControlNets(controlAdapters).length;
if (controlNetCount > 0) {
activeLabel.push(`${controlNetCount} ControlNet`);
}
const t2iAdapters = selectAllT2IAdapters(controlAdapters);
const t2iAdapterCount = t2iAdapters.length;
const t2iAdapterCount = selectAllT2IAdapters(controlAdapters).length;
if (t2iAdapterCount > 0) {
activeLabel.push(`${t2iAdapterCount} T2I`);
}
const controlAdapterIds = [ipAdapters, controlNets, t2iAdapters]
.flat()
.map((ca) => ca.id);
const controlAdapterIds =
selectControlAdapterIds(controlAdapters).map(String);
return {
controlAdapterIds,
@ -72,6 +69,7 @@ const ParamControlNetCollapse = () => {
leftIcon={<FaPlus />}
onClick={addControlNet}
data-testid="add controlnet"
flexGrow={1}
>
ControlNet
</IAIButton>
@ -79,6 +77,7 @@ const ParamControlNetCollapse = () => {
leftIcon={<FaPlus />}
onClick={addIPAdapter}
data-testid="add ip adapter"
flexGrow={1}
>
IP Adapter
</IAIButton>
@ -86,10 +85,12 @@ const ParamControlNetCollapse = () => {
leftIcon={<FaPlus />}
onClick={addT2IAdapter}
data-testid="add t2i adapter"
flexGrow={1}
>
T2I Adapter
</IAIButton>
</ButtonGroup>
<Divider />
{controlAdapterIds.map((id, i) => (
<Fragment key={id}>
{i > 0 && <Divider />}