feat(ui): add mini/advanced controlnet ui

This commit is contained in:
psychedelicious 2023-06-03 15:05:49 +10:00
parent 69f0ba65f1
commit d6c08ba469
18 changed files with 430 additions and 548 deletions

View File

@ -71,7 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetProcessorParamsChanged'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();

View File

@ -1,3 +1,4 @@
import { AnyAction } from '@reduxjs/toolkit';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions'; import { controlNetImageProcessed } from 'features/controlNet/store/actions';
@ -5,10 +6,37 @@ import {
controlNetImageChanged, controlNetImageChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
const predicate = (action: AnyAction, state: RootState) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action) ||
isControlNetImagePreprocessedToggled.match(action);
if (!isActionMatched) {
return false;
}
const { controlNetId } = action.payload;
const shouldAutoProcess =
!state.controlNet.controlNets[controlNetId].isPreprocessed;
const isBusy = state.system.isProcessing;
const hasControlImage = Boolean(
state.controlNet.controlNets[controlNetId].controlImage
);
return shouldAutoProcess && !isBusy && hasControlImage;
};
/** /**
* Listener that automatically processes a ControlNet image when its processor parameters are changed. * Listener that automatically processes a ControlNet image when its processor parameters are changed.
* *
@ -16,35 +44,13 @@ const moduleLog = log.child({ namespace: 'controlNet' });
*/ */
export const addControlNetAutoProcessListener = () => { export const addControlNetAutoProcessListener = () => {
startAppListening({ startAppListening({
predicate: (action) => predicate,
controlNetProcessorParamsChanged.match(action) ||
controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action),
effect: async ( effect: async (
action, action,
{ dispatch, getState, cancelActiveListeners, delay } { dispatch, getState, cancelActiveListeners, delay }
) => { ) => {
const state = getState();
if (!state.controlNet.shouldAutoProcess) {
// silently skip
return;
}
if (state.system.isProcessing) {
moduleLog.trace('System busy, skipping ControlNet auto-processing');
return;
}
const { controlNetId } = action.payload; const { controlNetId } = action.payload;
if (!state.controlNet.controlNets[controlNetId].controlImage) {
moduleLog.trace(
{ data: { controlNetId } },
'No ControlNet image to auto-process'
);
return;
}
// Cancel any in-progress instances of this listener // Cancel any in-progress instances of this listener
cancelActiveListeners(); cancelActiveListeners();

View File

@ -36,9 +36,11 @@ const IAISwitch = (props: Props) => {
alignItems="center" alignItems="center"
{...formControlProps} {...formControlProps}
> >
<FormLabel my={1} flexGrow={1} {...formLabelProps}> {label && (
{label} <FormLabel my={1} flexGrow={1} {...formLabelProps}>
</FormLabel> {label}
</FormLabel>
)}
<Switch {...rest} /> <Switch {...rest} />
</FormControl> </FormControl>
); );

View File

@ -1,32 +1,42 @@
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { import {
ControlNet, ControlNetConfig,
controlNetProcessedImageChanged, controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { import {
Box, Checkbox,
Flex, Flex,
Tab, FormControl,
FormLabel,
HStack,
TabList, TabList,
TabPanel,
TabPanels, TabPanels,
Tabs, Tabs,
Text, Tab,
TabPanel,
Box,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton'; import { FaCopy, FaTrash } from 'react-icons/fa';
import { FaUndo } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorComponent from './ControlNetProcessorComponent'; import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ControlNetPreprocessButton from './ControlNetPreprocessButton'; import ControlNetPreprocessButton from './ControlNetPreprocessButton';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import IAIButton from 'common/components/IAIButton';
import ControlNetImagePreview from './ControlNetImagePreview'; import IAISwitch from 'common/components/IAISwitch';
type ControlNetProps = { type ControlNetProps = {
controlNet: ControlNet; controlNet: ControlNetConfig;
}; };
const ControlNet = (props: ControlNetProps) => { const ControlNet = (props: ControlNetProps) => {
@ -38,24 +48,160 @@ const ControlNet = (props: ControlNetProps) => {
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlImage, controlImage,
isControlImageProcessed, isPreprocessed: isControlImageProcessed,
processedControlImage, processedControlImage,
processorNode, processorNode,
} = props.controlNet; } = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleReset = useCallback(() => { const [shouldShowAdvanced, onToggleAdvanced] = useToggle(true);
dispatch(
controlNetProcessedImageChanged({ const handleDelete = useCallback(() => {
controlNetId, dispatch(controlNetRemoved({ controlNetId }));
processedControlImage: null,
})
);
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleControlNetRemoved = useCallback(() => { const handleDuplicate = useCallback(() => {
dispatch(controlNetRemoved(controlNetId)); dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
);
}, [dispatch, props.controlNet]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleToggleIsPreprocessed = useCallback(() => {
dispatch(isControlNetImagePreprocessedToggled({ controlNetId }));
}, [controlNetId, dispatch]);
return (
<Flex
sx={{
flexDir: 'column',
gap: 2,
p: 2,
bg: 'base.850',
borderRadius: 'base',
}}
>
<HStack>
<IAISwitch
aria-label="Toggle ControlNet"
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
<Box
w="full"
opacity={isEnabled ? 1 : 0.5}
pointerEvents={isEnabled ? 'auto' : 'none'}
transitionProperty="common"
transitionDuration="0.1s"
>
<ParamControlNetModel controlNetId={controlNetId} model={model} />
</Box>
<IAIIconButton
size="sm"
tooltip="Duplicate ControlNet"
aria-label="Duplicate ControlNet"
onClick={handleDuplicate}
icon={<FaCopy />}
/>
<IAIIconButton
size="sm"
tooltip="Delete ControlNet"
aria-label="Delete ControlNet"
colorScheme="error"
onClick={handleDelete}
icon={<FaTrash />}
/>
</HStack>
{isEnabled && (
<>
<Flex sx={{ gap: 4 }}>
{!shouldShowAdvanced && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 32,
w: 32,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
</Flex>
)}
<Flex
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
paddingInlineEnd: 2,
pb: shouldShowAdvanced ? 0 : 2,
justifyContent: 'space-between',
}}
>
<Flex
sx={{
justifyContent: 'space-between',
w: 'full',
}}
>
<FormControl>
<HStack>
<Checkbox
isChecked={isControlImageProcessed}
onChange={handleToggleIsPreprocessed}
/>
<FormLabel>Preprocessed</FormLabel>
</HStack>
</FormControl>
<FormControl>
<HStack>
<Checkbox
isChecked={shouldShowAdvanced}
onChange={onToggleAdvanced}
/>
<FormLabel>Advanced</FormLabel>
</HStack>
</FormControl>
</Flex>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini
/>
</Flex>
</Flex>
{shouldShowAdvanced && (
<>
<Box pt={2}>
<ControlNetImagePreview controlNet={props.controlNet} />
</Box>
{!isControlImageProcessed && (
<>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
</>
)}
</>
)}
</>
)}
</Flex>
);
return ( return (
<Flex sx={{ flexDir: 'column', gap: 3 }}> <Flex sx={{ flexDir: 'column', gap: 3 }}>
<ControlNetImagePreview controlNet={props.controlNet} /> <ControlNetImagePreview controlNet={props.controlNet} />
@ -101,18 +247,18 @@ const ControlNet = (props: ControlNetProps) => {
processorNode={processorNode} processorNode={processorNode}
/> />
<ControlNetPreprocessButton controlNet={props.controlNet} /> <ControlNetPreprocessButton controlNet={props.controlNet} />
<IAIButton {/* <IAIButton
size="sm" size="sm"
leftIcon={<FaUndo />} leftIcon={<FaUndo />}
onClick={handleReset} onClick={handleReset}
isDisabled={Boolean(!processedControlImage)} isDisabled={Boolean(!processedControlImage)}
> >
Reset Processing Reset Processing
</IAIButton> </IAIButton> */}
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</Tabs> </Tabs>
<IAIButton onClick={handleControlNetRemoved}>Remove ControlNet</IAIButton> <IAIButton onClick={handleDelete}>Remove ControlNet</IAIButton>
</Flex> </Flex>
); );
}; };

View File

@ -1,7 +1,7 @@
import { memo, useCallback, useRef, useState } from 'react'; import { memo, useCallback, useRef, useState } from 'react';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { import {
ControlNet, ControlNetConfig,
controlNetImageChanged, controlNetImageChanged,
controlNetSelector, controlNetSelector,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
@ -24,7 +24,7 @@ const selector = createSelector(
); );
type Props = { type Props = {
controlNet: ControlNet; controlNet: ControlNetConfig;
}; };
const ControlNetImagePreview = (props: Props) => { const ControlNetImagePreview = (props: Props) => {
@ -32,7 +32,7 @@ const ControlNetImagePreview = (props: Props) => {
controlNetId, controlNetId,
controlImage, controlImage,
processedControlImage, processedControlImage,
isControlImageProcessed, isPreprocessed: isControlImageProcessed,
} = props.controlNet; } = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isProcessingControlImage } = useAppSelector(selector); const { isProcessingControlImage } = useAppSelector(selector);
@ -63,63 +63,62 @@ const ControlNetImagePreview = (props: Props) => {
<IAIDndImage <IAIDndImage
image={controlImage} image={controlImage}
onDrop={handleControlImageChanged} onDrop={handleControlImageChanged}
isDropDisabled={Boolean(processedControlImage)} isDropDisabled={Boolean(
processedControlImage && !isControlImageProcessed
)}
/> />
<AnimatePresence> <AnimatePresence>
{controlImage && {shouldShowProcessedImage && (
processedControlImage && <motion.div
shouldShowProcessedImage && initial={{
!isProcessingControlImage && ( opacity: 0,
<motion.div }}
initial={{ animate={{
opacity: 0, opacity: 1,
}} transition: { duration: 0.1 },
animate={{ }}
opacity: 1, exit={{
transition: { duration: 0.1 }, opacity: 0,
}} transition: { duration: 0.1 },
exit={{ }}
opacity: 0, >
transition: { duration: 0.1 }, <Box
sx={{
position: 'absolute',
w: 'full',
h: 'full',
top: 0,
insetInlineStart: 0,
}} }}
> >
{shouldShowProcessedImageBackdrop && (
<Box
sx={{
w: 'full',
h: 'full',
bg: 'base.900',
opacity: 0.7,
}}
/>
)}
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',
w: 'full',
h: 'full',
top: 0, top: 0,
insetInlineStart: 0, insetInlineStart: 0,
w: 'full',
h: 'full',
}} }}
> >
{shouldShowProcessedImageBackdrop && ( <IAIDndImage
<Box image={processedControlImage}
sx={{ onDrop={handleControlImageChanged}
w: 'full', payloadImage={controlImage}
h: 'full', />
bg: 'base.900',
opacity: 0.7,
}}
/>
)}
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
w: 'full',
h: 'full',
}}
>
<IAIDndImage
image={processedControlImage}
onDrop={handleControlImageChanged}
payloadImage={controlImage}
/>
</Box>
</Box> </Box>
</motion.div> </Box>
)} </motion.div>
)}
</AnimatePresence> </AnimatePresence>
{isProcessingControlImage && ( {isProcessingControlImage && (
<Box <Box

View File

@ -1,153 +0,0 @@
import { memo, useCallback } from 'react';
import {
ControlNet,
controlNetAdded,
controlNetRemoved,
controlNetToggled,
isControlNetImageProcessedToggled,
} from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import {
Checkbox,
Flex,
FormControl,
FormLabel,
HStack,
} from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
type ControlNetProps = {
controlNet: ControlNet;
};
const ControlNet = (props: ControlNetProps) => {
const {
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlImage,
isControlImageProcessed,
processedControlImage,
processorNode,
} = props.controlNet;
const dispatch = useAppDispatch();
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved(controlNetId));
}, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
);
}, [dispatch, props.controlNet]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled(controlNetId));
}, [controlNetId, dispatch]);
const handleToggleIsPreprocessed = useCallback(() => {
dispatch(isControlNetImageProcessedToggled(controlNetId));
}, [controlNetId, dispatch]);
return (
<Flex
sx={{
flexDir: 'column',
gap: 2,
}}
>
<HStack>
<ParamControlNetModel controlNetId={controlNetId} model={model} />
<IAIIconButton
size="sm"
tooltip="Duplicate ControlNet"
aria-label="Duplicate ControlNet"
onClick={handleDuplicate}
icon={<FaCopy />}
/>
<IAIIconButton
size="sm"
tooltip="Delete ControlNet"
aria-label="Delete ControlNet"
colorScheme="error"
onClick={handleDelete}
icon={<FaTrash />}
/>
</HStack>
<Flex
sx={{
gap: 4,
paddingInlineEnd: 2,
}}
>
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 32,
w: 32,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
</Flex>
<Flex
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini
/>
<Flex
sx={{
justifyContent: 'space-between',
}}
>
<FormControl>
<HStack>
<Checkbox
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
<FormLabel>Enabled</FormLabel>
</HStack>
</FormControl>
<FormControl>
<HStack>
<Checkbox
isChecked={isControlImageProcessed}
onChange={handleToggleIsPreprocessed}
/>
<FormLabel>Preprocessed</FormLabel>
</HStack>
</FormControl>
</Flex>
</Flex>
</Flex>
</Flex>
);
};
export default memo(ControlNet);

View File

@ -1,12 +1,12 @@
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { ControlNet } from '../store/controlNetSlice'; import { ControlNetConfig } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from '../store/actions'; import { controlNetImageProcessed } from '../store/actions';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
type Props = { type Props = {
controlNet: ControlNet; controlNet: ControlNetConfig;
}; };
const ControlNetPreprocessButton = (props: Props) => { const ControlNetPreprocessButton = (props: Props) => {

View File

@ -1,58 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetBeginStepPctProps = {
controlNetId: string;
beginStepPct: number;
};
const ParamControlNetBeginStepPct = (
props: ParamControlNetBeginStepPctProps
) => {
const { controlNetId, beginStepPct } = props;
const dispatch = useAppDispatch();
const handleBeginStepPctChanged = useCallback(
(beginStepPct: number) => {
dispatch(controlNetBeginStepPctChanged({ controlNetId, beginStepPct }));
},
[controlNetId, dispatch]
);
const handleBeginStepPctReset = useCallback(() => {
dispatch(controlNetBeginStepPctChanged({ controlNetId, beginStepPct: 0 }));
}, [controlNetId, dispatch]);
const handleEndStepPctChanged = useCallback(
(endStepPct: number) => {
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct }));
},
[controlNetId, dispatch]
);
const handleEndStepPctReset = useCallback(() => {
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: 0 }));
}, [controlNetId, dispatch]);
return (
<IAISlider
label="Begin Step %"
value={beginStepPct}
onChange={handleBeginStepPctChanged}
withInput
withReset
handleReset={handleBeginStepPctReset}
withSliderMarks
min={0}
max={1}
step={0.01}
/>
);
};
export default memo(ParamControlNetBeginStepPct);

View File

@ -1,42 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { controlNetEndStepPctChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetEndStepPctProps = {
controlNetId: string;
endStepPct: number;
};
const ParamControlNetEndStepPct = (props: ParamControlNetEndStepPctProps) => {
const { controlNetId, endStepPct } = props;
const dispatch = useAppDispatch();
const handleEndStepPctChanged = useCallback(
(endStepPct: number) => {
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct }));
},
[controlNetId, dispatch]
);
const handleEndStepPctReset = () => {
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: 0 }));
};
return (
<IAISlider
label="End Step %"
value={endStepPct}
onChange={handleEndStepPctChanged}
withInput
withReset
handleReset={handleEndStepPctReset}
withSliderMarks
min={0}
max={1}
step={0.01}
/>
);
};
export default memo(ParamControlNetEndStepPct);

View File

@ -13,7 +13,7 @@ const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => { const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled(controlNetId)); dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]); }, [dispatch, controlNetId]);
return ( return (

View File

@ -3,7 +3,7 @@ import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { import {
controlNetToggled, controlNetToggled,
isControlNetImageProcessedToggled, isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -18,7 +18,7 @@ const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const handleIsControlImageProcessedToggled = useCallback(() => { const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch( dispatch(
isControlNetImageProcessedToggled({ isControlNetImagePreprocessedToggled({
controlNetId, controlNetId,
}) })
); );

View File

@ -3,8 +3,8 @@ import IAICustomSelect from 'common/components/IAICustomSelect';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
ControlNetModel, ControlNetModel,
controlNetModelChanged, } from 'features/controlNet/store/constants';
} from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
type ParamIsControlNetModelProps = { type ParamIsControlNetModelProps = {

View File

@ -22,7 +22,7 @@ type ControlNetProcessorsDict = Record<
* *
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CONTROLNET_PROCESSORS = { export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
canny_image_processor: { canny_image_processor: {
type: 'canny_image_processor', type: 'canny_image_processor',
label: 'Canny', label: 'Canny',
@ -164,3 +164,27 @@ export const CONTROLNET_PROCESSORS = {
}, },
}, },
}; };
export const CONTROLNET_MODELS = [
'lllyasviel/sd-controlnet-canny',
'lllyasviel/sd-controlnet-depth',
'lllyasviel/sd-controlnet-hed',
'lllyasviel/sd-controlnet-seg',
'lllyasviel/sd-controlnet-openpose',
'lllyasviel/sd-controlnet-scribble',
'lllyasviel/sd-controlnet-normal',
'lllyasviel/sd-controlnet-mlsd',
];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const CONTROLNET_MODEL_MAP: Record<
ControlNetModel,
ControlNetProcessorType
> = {
'lllyasviel/sd-controlnet-canny': 'canny_image_processor',
'lllyasviel/sd-controlnet-depth': 'midas_depth_image_processor',
'lllyasviel/sd-controlnet-hed': 'hed_image_processor',
'lllyasviel/sd-controlnet-openpose': 'openpose_image_processor',
'lllyasviel/sd-controlnet-mlsd': 'mlsd_image_processor',
};

View File

@ -7,36 +7,27 @@ import {
RequiredCannyImageProcessorInvocation, RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
import { CONTROLNET_PROCESSORS } from './constants'; import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModel,
} from './constants';
import { controlNetImageProcessed } from './actions'; import { controlNetImageProcessed } from './actions';
export const CONTROLNET_MODELS = [ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
'lllyasviel/sd-controlnet-canny',
'lllyasviel/sd-controlnet-depth',
'lllyasviel/sd-controlnet-hed',
'lllyasviel/sd-controlnet-seg',
'lllyasviel/sd-controlnet-openpose',
'lllyasviel/sd-controlnet-scribble',
'lllyasviel/sd-controlnet-normal',
'lllyasviel/sd-controlnet-mlsd',
];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS[0], model: CONTROLNET_MODELS[0],
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
controlImage: null, controlImage: null,
isControlImageProcessed: false, isPreprocessed: false,
processedControlImage: null, processedControlImage: null,
processorNode: CONTROLNET_PROCESSORS.canny_image_processor processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation, .default as RequiredCannyImageProcessorInvocation,
}; };
export type ControlNet = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: ControlNetModel; model: ControlNetModel;
@ -44,22 +35,20 @@ export type ControlNet = {
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlImage: ImageDTO | null; controlImage: ImageDTO | null;
isControlImageProcessed: boolean; isPreprocessed: boolean;
processedControlImage: ImageDTO | null; processedControlImage: ImageDTO | null;
processorNode: RequiredControlNetProcessorNode; processorNode: RequiredControlNetProcessorNode;
}; };
export type ControlNetState = { export type ControlNetState = {
controlNets: Record<string, ControlNet>; controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean; isEnabled: boolean;
shouldAutoProcess: boolean;
isProcessingControlImage: boolean; isProcessingControlImage: boolean;
}; };
export const initialControlNetState: ControlNetState = { export const initialControlNetState: ControlNetState = {
controlNets: {}, controlNets: {},
isEnabled: false, isEnabled: false,
shouldAutoProcess: true,
isProcessingControlImage: false, isProcessingControlImage: false,
}; };
@ -72,7 +61,10 @@ export const controlNetSlice = createSlice({
}, },
controlNetAdded: ( controlNetAdded: (
state, state,
action: PayloadAction<{ controlNetId: string; controlNet?: ControlNet }> action: PayloadAction<{
controlNetId: string;
controlNet?: ControlNetConfig;
}>
) => { ) => {
const { controlNetId, controlNet } = action.payload; const { controlNetId, controlNet } = action.payload;
state.controlNets[controlNetId] = { state.controlNets[controlNetId] = {
@ -91,12 +83,18 @@ export const controlNetSlice = createSlice({
controlImage, controlImage,
}; };
}, },
controlNetRemoved: (state, action: PayloadAction<string>) => { controlNetRemoved: (
const controlNetId = action.payload; state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
delete state.controlNets[controlNetId]; delete state.controlNets[controlNetId];
}, },
controlNetToggled: (state, action: PayloadAction<string>) => { controlNetToggled: (
const controlNetId = action.payload; state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
state.controlNets[controlNetId].isEnabled = state.controlNets[controlNetId].isEnabled =
!state.controlNets[controlNetId].isEnabled; !state.controlNets[controlNetId].isEnabled;
}, },
@ -110,17 +108,20 @@ export const controlNetSlice = createSlice({
const { controlNetId, controlImage } = action.payload; const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId].controlImage = controlImage; state.controlNets[controlNetId].controlImage = controlImage;
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if (state.shouldAutoProcess && controlImage !== null) { if (
controlImage !== null &&
!state.controlNets[controlNetId].isPreprocessed
) {
state.isProcessingControlImage = true; state.isProcessingControlImage = true;
} }
}, },
isControlNetImageProcessedToggled: ( isControlNetImagePreprocessedToggled: (
state, state,
action: PayloadAction<string> action: PayloadAction<{ controlNetId: string }>
) => { ) => {
const controlNetId = action.payload; const { controlNetId } = action.payload;
state.controlNets[controlNetId].isControlImageProcessed = state.controlNets[controlNetId].isPreprocessed =
!state.controlNets[controlNetId].isControlImageProcessed; !state.controlNets[controlNetId].isPreprocessed;
}, },
controlNetProcessedImageChanged: ( controlNetProcessedImageChanged: (
state, state,
@ -191,9 +192,6 @@ export const controlNetSlice = createSlice({
processorType processorType
].default as RequiredControlNetProcessorNode; ].default as RequiredControlNetProcessorNode;
}, },
shouldAutoProcessToggled: (state) => {
state.shouldAutoProcess = !state.shouldAutoProcess;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => { builder.addCase(controlNetImageProcessed, (state, action) => {
@ -212,7 +210,7 @@ export const {
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,
controlNetImageChanged, controlNetImageChanged,
isControlNetImageProcessedToggled, isControlNetImagePreprocessedToggled,
controlNetProcessedImageChanged, controlNetProcessedImageChanged,
controlNetToggled, controlNetToggled,
controlNetModelChanged, controlNetModelChanged,
@ -221,7 +219,6 @@ export const {
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
shouldAutoProcessToggled,
} = controlNetSlice.actions; } = controlNetSlice.actions;
export default controlNetSlice.reducer; export default controlNetSlice.reducer;

View File

@ -0,0 +1,100 @@
import { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api';
import { NonNullableGraph } from '../types/types';
const CONTROL_NET_COLLECT = 'control_net_collect';
export const addControlNetToLinearGraph = (
graph: NonNullableGraph,
baseNodeId: string,
state: RootState
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
// Add ControlNet
if (isControlNetEnabled) {
if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
forEach(controlNets, (controlNet, index) => {
const {
controlNetId,
isEnabled,
isPreprocessed: isControlImageProcessed,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
model,
processorNode,
weight,
} = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && !isControlImageProcessed) {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage && isControlImageProcessed) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
});
}
};

View File

@ -14,6 +14,7 @@ import {
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -408,5 +409,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
}); });
} }
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;
}; };

View File

@ -1,8 +1,6 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
CollectInvocation,
CompelInvocation, CompelInvocation,
ControlNetInvocation,
Graph, Graph,
IterateInvocation, IterateInvocation,
LatentsToImageInvocation, LatentsToImageInvocation,
@ -12,7 +10,7 @@ import {
TextToLatentsInvocation, TextToLatentsInvocation,
} from 'services/api'; } from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
const POSITIVE_CONDITIONING = 'positive_conditioning'; const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning'; const NEGATIVE_CONDITIONING = 'negative_conditioning';
@ -22,7 +20,6 @@ const NOISE = 'noise';
const RANDOM_INT = 'rand_int'; const RANDOM_INT = 'rand_int';
const RANGE_OF_SIZE = 'range_of_size'; const RANGE_OF_SIZE = 'range_of_size';
const ITERATE = 'iterate'; const ITERATE = 'iterate';
const CONTROL_NET_COLLECT = 'control_net_collect';
/** /**
* Builds the Text to Image tab graph. * Builds the Text to Image tab graph.
@ -42,8 +39,6 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
shouldRandomizeSeed, shouldRandomizeSeed,
} = state.generation; } = state.generation;
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
nodes: {}, nodes: {},
edges: [], edges: [],
@ -315,91 +310,7 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
}); });
} }
// Add ControlNet addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
if (isControlNetEnabled) {
if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'control',
},
});
}
forEach(controlNets, (controlNet, index) => {
const {
controlNetId,
isEnabled,
isControlImageProcessed,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
model,
processorNode,
weight,
} = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && !isControlImageProcessed) {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage && isControlImageProcessed) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'control',
},
});
}
});
}
return graph; return graph;
}; };

View File

@ -1,18 +1,7 @@
import { import { Divider, Flex } from '@chakra-ui/react';
Divider,
Flex,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
} from '@chakra-ui/react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { Fragment, memo, useCallback } from 'react'; import { Fragment, memo, useCallback } from 'react';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaPlus } from 'react-icons/fa';
import ControlNet from 'features/controlNet/components/ControlNet';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { import {
@ -23,9 +12,9 @@ import {
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import ControlNetMini from 'features/controlNet/components/ControlNetMini';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import ControlNet from 'features/controlNet/components/ControlNet';
const selector = createSelector( const selector = createSelector(
controlNetSelector, controlNetSelector,
@ -62,61 +51,19 @@ const ParamControlNetCollapse = () => {
onToggle={handleClickControlNetToggle} onToggle={handleClickControlNetToggle}
withSwitch withSwitch
> >
{controlNetsArray.length === 0 && ( <Flex sx={{ flexDir: 'column', gap: 3 }}>
<IAIButton onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
)}
<Flex sx={{ flexDir: 'column', gap: 4 }}>
{controlNetsArray.map((c, i) => ( {controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}> <Fragment key={c.controlNetId}>
{i > 0 && <Divider />} {i > 0 && <Divider />}
<ControlNetMini controlNet={c} /> <ControlNet controlNet={c} />
</Fragment> </Fragment>
))} ))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );
return (
<IAICollapse
label={'ControlNet'}
isOpen={isEnabled}
onToggle={handleClickControlNetToggle}
withSwitch
>
<Tabs
isFitted
orientation="horizontal"
variant="line"
size="sm"
colorScheme="accent"
>
<TabList alignItems="center" borderBottomColor="base.800" pb={4}>
{controlNetsArray.map((c, i) => (
<Tab key={`tab_${c.controlNetId}`} borderTopRadius="base">
{i + 1}
</Tab>
))}
<IAIIconButton
marginInlineStart={2}
size="sm"
aria-label="Add ControlNet"
onClick={handleClickedAddControlNet}
icon={<FaPlus />}
/>
</TabList>
<TabPanels>
{controlNetsArray.map((c) => (
<TabPanel key={`tabPanel_${c.controlNetId}`} sx={{ p: 0 }}>
<ControlNet controlNet={c} />
{/* <ControlNetMini controlNet={c} /> */}
</TabPanel>
))}
</TabPanels>
</Tabs>
</IAICollapse>
);
}; };
export default memo(ParamControlNetCollapse); export default memo(ParamControlNetCollapse);