feat(ui): add rest of controlnet processors

This commit is contained in:
psychedelicious 2023-06-02 17:26:05 +10:00
parent 707ed39300
commit 9cdad95f48
26 changed files with 1458 additions and 291 deletions

View File

@ -71,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetProcessorParamsChangedListener } from './listeners/controlNetProcessorParamsChanged';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -177,3 +178,4 @@ addImageCategoriesChangedListener();
// ControlNet // ControlNet
addControlNetImageProcessedListener(); addControlNetImageProcessedListener();
addControlNetProcessorParamsChangedListener();

View File

@ -8,6 +8,7 @@ import { sessionReadyToInvoke } from 'features/system/store/actions';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards'; import { isImageOutput } from 'services/types/guards';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
import { pick } from 'lodash-es';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
@ -15,11 +16,27 @@ export const addControlNetImageProcessedListener = () => {
startAppListening({ startAppListening({
actionCreator: controlNetImageProcessed, actionCreator: controlNetImageProcessed,
effect: async (action, { dispatch, getState, take }) => { effect: async (action, { dispatch, getState, take }) => {
const { controlNetId, processorNode } = action.payload; const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
// ControlNet one-off procressing graph is just he processor node, no edges if (!controlNet.controlImage) {
moduleLog.error('Unable to process ControlNet image');
return;
}
// ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image.
const graph: Graph = { const graph: Graph = {
nodes: { [processorNode.id]: processorNode }, nodes: {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, [
'image_name',
'image_origin',
]),
},
},
}; };
// Create a session to run the graph & wait til it's ready to invoke // Create a session to run the graph & wait til it's ready to invoke

View File

@ -0,0 +1,27 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import {
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
} from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ namespace: 'controlNet' });
export const addControlNetProcessorParamsChangedListener = () => {
startAppListening({
predicate: (action) =>
controlNetProcessorParamsChanged.match(action) ||
controlNetProcessorTypeChanged.match(action),
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const { controlNetId } = action.payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
// Delay before starting actual work
await delay(1000);
dispatch(controlNetImageProcessed({ controlNetId }));
},
});
};

View File

@ -1,51 +1,33 @@
import { memo, useCallback, useState } from 'react'; import { memo, useCallback } from 'react';
import { ControlNetProcessorNode } from '../store/types'; import { RequiredControlNetProcessorNode } from '../store/types';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import CannyProcessor from './processors/CannyProcessor'; import CannyProcessor from './processors/CannyProcessor';
import { import {
CONTROLNET_PROCESSORS,
ControlNet, ControlNet,
ControlNetModel,
ControlNetProcessor,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetImageChanged, controlNetImageChanged,
controlNetModelChanged,
controlNetProcessedImageChanged, controlNetProcessedImageChanged,
controlNetProcessorChanged,
controlNetRemoved, controlNetRemoved,
controlNetToggled,
controlNetWeightChanged,
isControlNetImageProcessedToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import ParamControlNetIsEnabled from './parameters/ParamControlNetIsEnabled';
import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import ParamControlNetBeginStepPct from './parameters/ParamControlNetBeginStepPct'; import ParamControlNetBeginStepPct from './parameters/ParamControlNetBeginStepPct';
import ParamControlNetEndStepPct from './parameters/ParamControlNetEndStepPct'; import ParamControlNetEndStepPct from './parameters/ParamControlNetEndStepPct';
import { import {
Box,
Flex, Flex,
HStack,
Tab, Tab,
TabList, TabList,
TabPanel, TabPanel,
TabPanels, TabPanels,
Tabs, Tabs,
VStack,
useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import IAISelectableImage from './parameters/IAISelectableImage'; import IAISelectableImage from './parameters/IAISelectableImage';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton'; import { controlNetImageProcessed } from '../store/actions';
import IAISwitch from 'common/components/IAISwitch'; import { FaUndo } from 'react-icons/fa';
import ParamControlNetIsPreprocessed from './parameters/ParamControlNetIsPreprocessed'; import HedProcessor from './processors/HedProcessor';
import IAICollapse from 'common/components/IAICollapse'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorCollapse from './ControlNetProcessorCollapse'; import ProcessorComponent from './ProcessorComponent';
import IAICustomSelect from 'common/components/IAICustomSelect';
type ControlNetProps = { type ControlNetProps = {
controlNet: ControlNet; controlNet: ControlNet;
@ -62,22 +44,10 @@ const ControlNet = (props: ControlNetProps) => {
controlImage, controlImage,
isControlImageProcessed, isControlImageProcessed,
processedControlImage, processedControlImage,
processor, processorNode,
} = props.controlNet; } = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleProcessorTypeChanged = useCallback(
(processor: string | null | undefined) => {
dispatch(
controlNetProcessorChanged({
controlNetId,
processor: processor as ControlNetProcessor,
})
);
},
[controlNetId, dispatch]
);
const handleControlImageChanged = useCallback( const handleControlImageChanged = useCallback(
(controlImage: ImageDTO) => { (controlImage: ImageDTO) => {
dispatch(controlNetImageChanged({ controlNetId, controlImage })); dispatch(controlNetImageChanged({ controlNetId, controlImage }));
@ -85,6 +55,23 @@ const ControlNet = (props: ControlNetProps) => {
[controlNetId, dispatch] [controlNetId, dispatch]
); );
const handleProcess = useCallback(() => {
dispatch(
controlNetImageProcessed({
controlNetId,
})
);
}, [controlNetId, dispatch]);
const handleReset = useCallback(() => {
dispatch(
controlNetProcessedImageChanged({
controlNetId,
processedControlImage: null,
})
);
}, [controlNetId, dispatch]);
const handleControlImageReset = useCallback(() => { const handleControlImageReset = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null })); dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -137,18 +124,29 @@ const ControlNet = (props: ControlNetProps) => {
/> />
</TabPanel> </TabPanel>
<TabPanel sx={{ p: 0 }}> <TabPanel sx={{ p: 0 }}>
<IAICustomSelect <ParamControlNetProcessorSelect
label="Processor" controlNetId={controlNetId}
items={CONTROLNET_PROCESSORS} processorNode={processorNode}
selectedItem={processor}
setSelectedItem={handleProcessorTypeChanged}
/> />
<ProcessorComponent <ProcessorComponent
controlNetId={controlNetId} controlNetId={controlNetId}
controlImage={controlImage} processorNode={processorNode}
processedControlImage={processedControlImage}
type={processor}
/> />
<IAIButton
size="sm"
onClick={handleProcess}
isDisabled={Boolean(!controlImage)}
>
Preprocess
</IAIButton>
<IAIButton
size="sm"
leftIcon={<FaUndo />}
onClick={handleReset}
isDisabled={Boolean(!processedControlImage)}
>
Reset Processing
</IAIButton>
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</Tabs> </Tabs>
@ -158,18 +156,3 @@ const ControlNet = (props: ControlNetProps) => {
}; };
export default memo(ControlNet); export default memo(ControlNet);
export type ControlNetProcessorProps = {
controlNetId: string;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
type: ControlNetProcessor;
};
const ProcessorComponent = (props: ControlNetProcessorProps) => {
const { type } = props;
if (type === 'canny') {
return <CannyProcessor {...props} />;
}
return null;
};

View File

@ -1,76 +0,0 @@
// import { Collapse, Flex, useDisclosure } from '@chakra-ui/react';
// import { memo, useState } from 'react';
// import CannyProcessor from './processors/CannyProcessor';
// import { ImageDTO } from 'services/api';
// import IAICustomSelect from 'common/components/IAICustomSelect';
// import {
// CONTROLNET_PROCESSORS,
// ControlNetProcessor,
// } from '../store/controlNetSlice';
// import IAISwitch from 'common/components/IAISwitch';
// export type ControlNetProcessorProps = {
// controlNetId: string;
// controlImage: ImageDTO | null;
// processedControlImage: ImageDTO | null;
// type: ControlNetProcessor;
// };
// const ProcessorComponent = (props: ControlNetProcessorProps) => {
// const { type } = props;
// if (type === 'canny') {
// return <CannyProcessor {...props} />;
// }
// return null;
// };
// type ControlNetProcessorCollapseProps = {
// isOpen: boolean;
// controlNetId: string;
// controlImage: ImageDTO | null;
// processedControlImage: ImageDTO | null;
// };
// const ControlNetProcessorCollapse = (
// props: ControlNetProcessorCollapseProps
// ) => {
// const { isOpen, controlImage, controlNetId, processedControlImage } = props;
// const [processorType, setProcessorType] =
// useState<ControlNetProcessor>('canny');
// const handleProcessorTypeChanged = (type: string | null | undefined) => {
// setProcessorType(type as ControlNetProcessor);
// };
// return (
// <Flex
// sx={{
// gap: 2,
// p: 4,
// mt: 2,
// bg: 'base.850',
// borderRadius: 'base',
// flexDirection: 'column',
// }}
// >
// <IAICustomSelect
// label="Processor"
// items={CONTROLNET_PROCESSORS}
// selectedItem={processorType}
// setSelectedItem={handleProcessorTypeChanged}
// />
// {controlImage && (
// <ProcessorComponent
// controlNetId={controlNetId}
// controlImage={controlImage}
// processedControlImage={processedControlImage}
// type={processorType}
// />
// )}
// </Flex>
// );
// };
// export default memo(ControlNetProcessorCollapse);
export default {};

View File

@ -0,0 +1,131 @@
import { memo } from 'react';
import { RequiredControlNetProcessorNode } from '../store/types';
import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor';
import NormalBaeProcessor from './processors/NormalBaeProcessor';
import OpenposeProcessor from './processors/OpenposeProcessor';
import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
};
const ProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props;
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
);
}
if (processorNode.type === 'lineart_image_processor') {
return (
<LineartProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'content_shuffle_image_processor') {
return (
<ContentShuffleProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'lineart_anime_image_processor') {
return (
<LineartAnimeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'mediapipe_face_processor') {
return (
<MediapipeFaceProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'midas_depth_image_processor') {
return (
<MidasDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'mlsd_image_processor') {
return (
<MlsdImageProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'normalbae_image_processor') {
return (
<NormalBaeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'openpose_image_processor') {
return (
<OpenposeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'pidi_image_processor') {
return (
<PidiProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'zoe_depth_image_processor') {
return (
<ZoeDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
return null;
};
export default memo(ProcessorComponent);

View File

@ -0,0 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetProcessorParamsChanged } from 'features/controlNet/store/controlNetSlice';
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
import { useCallback } from 'react';
export const useProcessorNodeChanged = () => {
const dispatch = useAppDispatch();
const handleProcessorNodeChanged = useCallback(
(controlNetId: string, changes: Partial<ControlNetProcessorNode>) => {
dispatch(
controlNetProcessorParamsChanged({
controlNetId,
changes,
})
);
},
[dispatch]
);
return handleProcessorNodeChanged;
};

View File

@ -0,0 +1,46 @@
import IAICustomSelect from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
CONTROLNET_PROCESSORS
) as ControlNetProcessorType[];
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch();
const handleProcessorTypeChanged = useCallback(
(v: string | null | undefined) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: v as ControlNetProcessorType,
})
);
},
[controlNetId, dispatch]
);
return (
<IAICustomSelect
label="Processor"
items={CONTROLNET_PROCESSOR_TYPES}
selectedItem={processorNode.type ?? 'canny_image_processor'}
setSelectedItem={handleProcessorTypeChanged}
/>
);
};
export default memo(ParamControlNetProcessorSelect);

View File

@ -1,75 +1,61 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { useAppDispatch } from 'app/store/storeHooks'; import { memo, useCallback } from 'react';
import { controlNetImageProcessed } from 'features/controlNet/store/actions'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice'; import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import ControlNetProcessorButtons from './common/ControlNetProcessorButtons';
import { memo, useCallback, useState } from 'react';
import { ControlNetProcessorProps } from '../ControlNet';
export const CANNY_PROCESSOR = 'canny_image_processor'; type CannyProcessorProps = {
controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation;
};
const CannyProcessor = (props: ControlNetProcessorProps) => { const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, controlImage, processedControlImage, type } = props; const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch(); const { low_threshold, high_threshold } = processorNode;
const [lowThreshold, setLowThreshold] = useState(100); const processorChanged = useProcessorNodeChanged();
const [highThreshold, setHighThreshold] = useState(200);
const handleProcess = useCallback(() => { const handleLowThresholdChanged = useCallback(
if (!controlImage) { (v: number) => {
return; processorChanged(controlNetId, { low_threshold: v });
}
dispatch(
controlNetImageProcessed({
controlNetId,
processorNode: {
id: CANNY_PROCESSOR,
type: 'canny_image_processor',
image: {
image_name: controlImage.image_name,
image_origin: controlImage.image_origin,
}, },
low_threshold: lowThreshold, [controlNetId, processorChanged]
high_threshold: highThreshold,
},
})
); );
}, [controlNetId, dispatch, highThreshold, controlImage, lowThreshold]);
const handleReset = useCallback(() => { const handleLowThresholdReset = useCallback(() => {
dispatch( processorChanged(controlNetId, { low_threshold: 100 });
controlNetProcessedImageChanged({ }, [controlNetId, processorChanged]);
controlNetId,
processedControlImage: null, const handleHighThresholdChanged = useCallback(
}) (v: number) => {
processorChanged(controlNetId, { high_threshold: v });
},
[controlNetId, processorChanged]
); );
}, [controlNetId, dispatch]);
const handleHighThresholdReset = useCallback(() => {
processorChanged(controlNetId, { high_threshold: 200 });
}, [controlNetId, processorChanged]);
return ( return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}> <Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider <IAISlider
label="Low Threshold" label="Low Threshold"
value={lowThreshold} value={low_threshold}
onChange={setLowThreshold} onChange={handleLowThresholdChanged}
handleReset={handleLowThresholdReset}
min={0} min={0}
max={255} max={255}
withInput withInput
/> />
<IAISlider <IAISlider
label="High Threshold" label="High Threshold"
value={highThreshold} value={high_threshold}
onChange={setHighThreshold} onChange={handleHighThresholdChanged}
handleReset={handleHighThresholdReset}
min={0} min={0}
max={255} max={255}
withInput withInput
/> />
<ControlNetProcessorButtons
handleProcess={handleProcess}
isProcessDisabled={Boolean(!controlImage)}
handleReset={handleReset}
isResetDisabled={Boolean(!processedControlImage)}
/>
</Flex> </Flex>
); );
}; };

View File

@ -0,0 +1,98 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlNet/store/types';
type Props = {
controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation;
};
const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleWChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { w: v });
},
[controlNetId, processorChanged]
);
const handleHChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { h: v });
},
[controlNetId, processorChanged]
);
const handleFChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { f: v });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="W"
value={w}
onChange={handleWChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="H"
value={h}
onChange={handleHChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="F"
value={f}
onChange={handleFChanged}
min={0}
max={4096}
withInput
/>
</Flex>
);
};
export default memo(ContentShuffleProcessor);

View File

@ -1,39 +1,66 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, memo, useState } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
const HedPreprocessor = () => { type HedProcessorProps = {
const [detectResolution, setDetectResolution] = useState(512); controlNetId: string;
const [imageResolution, setImageResolution] = useState(512); processorNode: RequiredHedImageProcessorInvocation;
const [isScribbleEnabled, setIsScribbleEnabled] = useState(false);
const handleChangeScribble = (e: ChangeEvent<HTMLInputElement>) => {
setIsScribbleEnabled(e.target.checked);
}; };
const HedPreprocessor = (props: HedProcessorProps) => {
const {
controlNetId,
processorNode: { detect_resolution, image_resolution, scribble },
} = props;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { scribble: e.target.checked });
},
[controlNetId, processorChanged]
);
return ( return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}> <Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider <IAISlider
label="Detect Resolution" label="Detect Resolution"
value={detectResolution} value={detect_resolution}
onChange={setDetectResolution} onChange={handleDetectResolutionChanged}
min={0} min={0}
max={4096} max={4096}
withInput withInput
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
value={imageResolution} value={image_resolution}
onChange={setImageResolution} onChange={handleImageResolutionChanged}
min={0} min={0}
max={4096} max={4096}
withInput withInput
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={isScribbleEnabled} isChecked={scribble}
onChange={handleChangeScribble} onChange={handleScribbleChanged}
/> />
</Flex> </Flex>
); );

View File

@ -1,25 +1,47 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { memo, useState } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
const LineartPreprocessor = () => { type Props = {
const [detectResolution, setDetectResolution] = useState(512); controlNetId: string;
const [imageResolution, setImageResolution] = useState(512); processorNode: RequiredLineartAnimeImageProcessorInvocation;
};
const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
return ( return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}> <Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider <IAISlider
label="Detect Resolution" label="Detect Resolution"
value={detectResolution} value={detect_resolution}
onChange={setDetectResolution} onChange={handleDetectResolutionChanged}
min={0} min={0}
max={4096} max={4096}
withInput withInput
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
value={imageResolution} value={image_resolution}
onChange={setImageResolution} onChange={handleImageResolutionChanged}
min={0} min={0}
max={4096} max={4096}
withInput withInput
@ -28,4 +50,4 @@ const LineartPreprocessor = () => {
); );
}; };
export default memo(LineartPreprocessor); export default memo(LineartAnimeProcessor);

View File

@ -1,42 +1,66 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, memo, useState } from 'react';
const LineartPreprocessor = () => { type LineartProcessorProps = {
const [detectResolution, setDetectResolution] = useState(512); controlNetId: string;
const [imageResolution, setImageResolution] = useState(512); processorNode: RequiredLineartImageProcessorInvocation;
const [isCoarseEnabled, setIsCoarseEnabled] = useState(false);
const handleChangeScribble = (e: ChangeEvent<HTMLInputElement>) => {
setIsCoarseEnabled(e.target.checked);
}; };
const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleCoarseChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { coarse: e.target.checked });
},
[controlNetId, processorChanged]
);
return ( return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}> <Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider <IAISlider
label="Detect Resolution" label="Detect Resolution"
value={detectResolution} value={detect_resolution}
onChange={setDetectResolution} onChange={handleDetectResolutionChanged}
min={0} min={0}
max={4096} max={4096}
withInput withInput
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
value={imageResolution} value={image_resolution}
onChange={setImageResolution} onChange={handleImageResolutionChanged}
min={0} min={0}
max={4096} max={4096}
withInput withInput
/> />
<IAISwitch <IAISwitch
label="Coarse" label="Coarse"
isChecked={isCoarseEnabled} isChecked={coarse}
onChange={handleChangeScribble} onChange={handleCoarseChanged}
/> />
</Flex> </Flex>
); );
}; };
export default memo(LineartPreprocessor); export default memo(LineartProcessor);

View File

@ -0,0 +1,57 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import {
RequiredContentShuffleImageProcessorInvocation,
RequiredMediapipeFaceProcessorInvocation,
} from 'features/controlNet/store/types';
type Props = {
controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation;
};
const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleMaxFacesChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { max_faces: v });
},
[controlNetId, processorChanged]
);
const handleMinConfidenceChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { min_confidence: v });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Max Faces"
value={max_faces}
onChange={handleMaxFacesChanged}
min={1}
max={20}
withInput
/>
<IAISlider
label="Min Confidence"
value={min_confidence}
onChange={handleMinConfidenceChanged}
min={0}
max={1}
step={0.01}
withInput
/>
</Flex>
);
};
export default memo(MediapipeFaceProcessor);

View File

@ -0,0 +1,55 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
type Props = {
controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation;
};
const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleAMultChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { a_mult: v });
},
[controlNetId, processorChanged]
);
const handleBgThChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { bg_th: v });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="a_mult"
value={a_mult}
onChange={handleAMultChanged}
min={0}
max={20}
step={0.01}
withInput
/>
<IAISlider
label="bg_th"
value={bg_th}
onChange={handleBgThChanged}
min={0}
max={20}
step={0.01}
withInput
/>
</Flex>
);
};
export default memo(MidasDepthProcessor);

View File

@ -0,0 +1,85 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
type Props = {
controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation;
};
const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleThrDChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { thr_d: v });
},
[controlNetId, processorChanged]
);
const handleThrVChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { thr_v: v });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="W"
value={thr_d}
onChange={handleThrDChanged}
min={0}
max={1}
step={0.01}
withInput
/>
<IAISlider
label="H"
value={thr_v}
onChange={handleThrVChanged}
min={0}
max={1}
step={0.01}
withInput
/>
</Flex>
);
};
export default memo(MlsdImageProcessor);

View File

@ -0,0 +1,53 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
type Props = {
controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation;
};
const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
min={0}
max={4096}
withInput
/>
</Flex>
);
};
export default memo(NormalBaeProcessor);

View File

@ -0,0 +1,66 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import IAISwitch from 'common/components/IAISwitch';
type Props = {
controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation;
};
const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleHandAndFaceChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { hand_and_face: e.target.checked });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Hand and Face"
isChecked={hand_and_face}
onChange={handleHandAndFaceChanged}
/>
</Flex>
);
};
export default memo(OpenposeProcessor);

View File

@ -0,0 +1,74 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import IAISwitch from 'common/components/IAISwitch';
type Props = {
controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation;
};
const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { scribble: e.target.checked });
},
[controlNetId, processorChanged]
);
const handleSafeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { safe: e.target.checked });
},
[controlNetId, processorChanged]
);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
/>
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} />
</Flex>
);
};
export default memo(PidiProcessor);

View File

@ -0,0 +1,14 @@
import { memo } from 'react';
import { RequiredZoeDepthImageProcessorInvocation } from 'features/controlNet/store/types';
type Props = {
controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation;
};
const ZoeDepthProcessor = (props: Props) => {
// Has no parameters?
return null;
};
export default memo(ZoeDepthProcessor);

View File

@ -21,23 +21,7 @@ const ControlNetProcessorButtons = (props: ControlNetProcessorButtonsProps) => {
alignItems: 'center', alignItems: 'center',
justifyContent: 'stretch', justifyContent: 'stretch',
}} }}
> ></Flex>
<IAIButton
size="sm"
onClick={handleProcess}
isDisabled={isProcessDisabled}
>
Preprocess
</IAIButton>
<IAIButton
size="sm"
leftIcon={<FaUndo />}
onClick={handleReset}
isDisabled={isResetDisabled}
>
Reset Processing
</IAIButton>
</Flex>
); );
}; };

View File

@ -1,7 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ControlNetProcessorNode } from './types';
export const controlNetImageProcessed = createAction<{ export const controlNetImageProcessed = createAction<{
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode;
}>('controlNet/imageProcessed'); }>('controlNet/imageProcessed');

View File

@ -0,0 +1,166 @@
import {
ControlNetProcessorType,
RequiredControlNetProcessorNode,
} from './types';
type ControlNetProcessorsDict = Record<
ControlNetProcessorType,
{
type: ControlNetProcessorType;
label: string;
description: string;
default: RequiredControlNetProcessorNode;
}
>;
/**
* A dict of ControlNet processors, including:
* - type
* - label
* - description
* - default values
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
canny_image_processor: {
type: 'canny_image_processor',
label: 'Canny',
description: '',
default: {
id: 'canny_image_processor',
type: 'canny_image_processor',
low_threshold: 100,
high_threshold: 200,
},
},
content_shuffle_image_processor: {
type: 'content_shuffle_image_processor',
label: 'Content Shuffle',
description: '',
default: {
id: 'content_shuffle_image_processor',
type: 'content_shuffle_image_processor',
detect_resolution: 512,
image_resolution: 512,
h: 512,
w: 512,
f: 256,
},
},
hed_image_processor: {
type: 'hed_image_processor',
label: 'HED',
description: '',
default: {
id: 'hed_image_processor',
type: 'hed_image_processor',
detect_resolution: 512,
image_resolution: 512,
scribble: false,
},
},
lineart_anime_image_processor: {
type: 'lineart_anime_image_processor',
label: 'Lineart Anime',
description: '',
default: {
id: 'lineart_anime_image_processor',
type: 'lineart_anime_image_processor',
detect_resolution: 512,
image_resolution: 512,
},
},
lineart_image_processor: {
type: 'lineart_image_processor',
label: 'Lineart',
description: '',
default: {
id: 'lineart_image_processor',
type: 'lineart_image_processor',
detect_resolution: 512,
image_resolution: 512,
coarse: false,
},
},
mediapipe_face_processor: {
type: 'mediapipe_face_processor',
label: 'Mediapipe Face',
description: '',
default: {
id: 'mediapipe_face_processor',
type: 'mediapipe_face_processor',
max_faces: 1,
min_confidence: 0.5,
},
},
midas_depth_image_processor: {
type: 'midas_depth_image_processor',
label: 'Depth (Midas)',
description: '',
default: {
id: 'midas_depth_image_processor',
type: 'midas_depth_image_processor',
a_mult: 2,
bg_th: 0.1,
},
},
mlsd_image_processor: {
type: 'mlsd_image_processor',
label: 'MLSD',
description: '',
default: {
id: 'mlsd_image_processor',
type: 'mlsd_image_processor',
detect_resolution: 512,
image_resolution: 512,
thr_d: 0.1,
thr_v: 0.1,
},
},
normalbae_image_processor: {
type: 'normalbae_image_processor',
label: 'NormalBae',
description: '',
default: {
id: 'normalbae_image_processor',
type: 'normalbae_image_processor',
detect_resolution: 512,
image_resolution: 512,
},
},
openpose_image_processor: {
type: 'openpose_image_processor',
label: 'Openpose',
description: '',
default: {
id: 'openpose_image_processor',
type: 'openpose_image_processor',
detect_resolution: 512,
image_resolution: 512,
hand_and_face: false,
},
},
pidi_image_processor: {
type: 'pidi_image_processor',
label: 'PIDI',
description: '',
default: {
id: 'pidi_image_processor',
type: 'pidi_image_processor',
detect_resolution: 512,
image_resolution: 512,
scribble: false,
safe: false,
},
},
zoe_depth_image_processor: {
type: 'zoe_depth_image_processor',
label: 'Depth (Zoe)',
description: '',
default: {
id: 'zoe_depth_image_processor',
type: 'zoe_depth_image_processor',
},
},
};

View File

@ -1,11 +1,12 @@
import { import { PayloadAction } from '@reduxjs/toolkit';
$CombinedState,
PayloadAction,
createSelector,
} from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import {
ControlNetProcessorType,
RequiredControlNetProcessorNode,
} from './types';
import { CONTROLNET_PROCESSORS } from './constants';
export const CONTROLNET_MODELS = [ export const CONTROLNET_MODELS = [
'lllyasviel/sd-controlnet-canny', 'lllyasviel/sd-controlnet-canny',
@ -18,23 +19,6 @@ export const CONTROLNET_MODELS = [
'lllyasviel/sd-controlnet-mlsd', 'lllyasviel/sd-controlnet-mlsd',
]; ];
export const CONTROLNET_PROCESSORS = [
'canny',
'contentShuffle',
'hed',
'lineart',
'lineartAnime',
'mediapipeFace',
'midasDepth',
'mlsd',
'normalBae',
'openpose',
'pidi',
'zoeDepth',
];
export type ControlNetProcessor = (typeof CONTROLNET_PROCESSORS)[number];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number]; export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const initialControlNet: Omit<ControlNet, 'controlNetId'> = { export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
@ -46,7 +30,7 @@ export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
controlImage: null, controlImage: null,
isControlImageProcessed: false, isControlImageProcessed: false,
processedControlImage: null, processedControlImage: null,
processor: 'canny', processorNode: CONTROLNET_PROCESSORS.canny_image_processor.default,
}; };
export type ControlNet = { export type ControlNet = {
@ -59,17 +43,19 @@ export type ControlNet = {
controlImage: ImageDTO | null; controlImage: ImageDTO | null;
isControlImageProcessed: boolean; isControlImageProcessed: boolean;
processedControlImage: ImageDTO | null; processedControlImage: ImageDTO | null;
processor: ControlNetProcessor; processorNode: RequiredControlNetProcessorNode;
}; };
export type ControlNetState = { export type ControlNetState = {
controlNets: Record<string, ControlNet>; controlNets: Record<string, ControlNet>;
isEnabled: boolean; isEnabled: boolean;
shouldAutoProcess: boolean;
}; };
export const initialControlNetState: ControlNetState = { export const initialControlNetState: ControlNetState = {
controlNets: {}, controlNets: {},
isEnabled: false, isEnabled: false,
shouldAutoProcess: true,
}; };
export const controlNetSlice = createSlice({ export const controlNetSlice = createSlice({
@ -169,15 +155,36 @@ export const controlNetSlice = createSlice({
const { controlNetId, endStepPct } = action.payload; const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct; state.controlNets[controlNetId].endStepPct = endStepPct;
}, },
controlNetProcessorChanged: ( controlNetProcessorParamsChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
processor: ControlNetProcessor; changes: Omit<
Partial<RequiredControlNetProcessorNode>,
'id' | 'type' | 'is_intermediate'
>;
}> }>
) => { ) => {
const { controlNetId, processor } = action.payload; const { controlNetId, changes } = action.payload;
state.controlNets[controlNetId].processor = processor; const processorNode = state.controlNets[controlNetId].processorNode;
state.controlNets[controlNetId].processorNode = {
...processorNode,
...changes,
};
},
controlNetProcessorTypeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processorType: ControlNetProcessorType;
}>
) => {
const { controlNetId, processorType } = action.payload;
state.controlNets[controlNetId].processorNode =
CONTROLNET_PROCESSORS[processorType].default;
},
shouldAutoProcessToggled: (state) => {
state.shouldAutoProcess = !state.shouldAutoProcess;
}, },
}, },
}); });
@ -195,7 +202,9 @@ export const {
controlNetWeightChanged, controlNetWeightChanged,
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
controlNetProcessorChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
shouldAutoProcessToggled,
} = controlNetSlice.actions; } = controlNetSlice.actions;
export default controlNetSlice.reducer; export default controlNetSlice.reducer;

View File

@ -1,7 +1,8 @@
import { isObject } from 'lodash-es';
import { import {
CannyImageProcessorInvocation, CannyImageProcessorInvocation,
ContentShuffleImageProcessorInvocation, ContentShuffleImageProcessorInvocation,
HedImageprocessorInvocation, HedImageProcessorInvocation,
LineartAnimeImageProcessorInvocation, LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation, LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation, MediapipeFaceProcessorInvocation,
@ -12,17 +13,317 @@ import {
PidiImageProcessorInvocation, PidiImageProcessorInvocation,
ZoeDepthImageProcessorInvocation, ZoeDepthImageProcessorInvocation,
} from 'services/api'; } from 'services/api';
import { O } from 'ts-toolbelt';
/**
* Any ControlNet processor node
*/
export type ControlNetProcessorNode = export type ControlNetProcessorNode =
| CannyImageProcessorInvocation | CannyImageProcessorInvocation
| HedImageprocessorInvocation
| LineartImageProcessorInvocation
| LineartAnimeImageProcessorInvocation
| OpenposeImageProcessorInvocation
| MidasDepthImageProcessorInvocation
| NormalbaeImageProcessorInvocation
| MlsdImageProcessorInvocation
| PidiImageProcessorInvocation
| ContentShuffleImageProcessorInvocation | ContentShuffleImageProcessorInvocation
| ZoeDepthImageProcessorInvocation | HedImageProcessorInvocation
| MediapipeFaceProcessorInvocation; | LineartAnimeImageProcessorInvocation
| LineartImageProcessorInvocation
| MediapipeFaceProcessorInvocation
| MidasDepthImageProcessorInvocation
| MlsdImageProcessorInvocation
| NormalbaeImageProcessorInvocation
| OpenposeImageProcessorInvocation
| PidiImageProcessorInvocation
| ZoeDepthImageProcessorInvocation;
/**
* Any ControlNet processor type
*/
export type ControlNetProcessorType = NonNullable<
ControlNetProcessorNode['type']
>;
/**
* The Canny processor node, with parameters flagged as required
*/
export type RequiredCannyImageProcessorInvocation = O.Required<
CannyImageProcessorInvocation,
'type' | 'low_threshold' | 'high_threshold'
>;
/**
* The ContentShuffle processor node, with parameters flagged as required
*/
export type RequiredContentShuffleImageProcessorInvocation = O.Required<
ContentShuffleImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
>;
/**
* The HED processor node, with parameters flagged as required
*/
export type RequiredHedImageProcessorInvocation = O.Required<
HedImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'scribble'
>;
/**
* The Lineart Anime processor node, with parameters flagged as required
*/
export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
LineartAnimeImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution'
>;
/**
* The Lineart processor node, with parameters flagged as required
*/
export type RequiredLineartImageProcessorInvocation = O.Required<
LineartImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'coarse'
>;
/**
* The MediapipeFace processor node, with parameters flagged as required
*/
export type RequiredMediapipeFaceProcessorInvocation = O.Required<
MediapipeFaceProcessorInvocation,
'type' | 'max_faces' | 'min_confidence'
>;
/**
* The MidasDepth processor node, with parameters flagged as required
*/
export type RequiredMidasDepthImageProcessorInvocation = O.Required<
MidasDepthImageProcessorInvocation,
'type' | 'a_mult' | 'bg_th'
>;
/**
* The MLSD processor node, with parameters flagged as required
*/
export type RequiredMlsdImageProcessorInvocation = O.Required<
MlsdImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'thr_v' | 'thr_d'
>;
/**
* The NormalBae processor node, with parameters flagged as required
*/
export type RequiredNormalbaeImageProcessorInvocation = O.Required<
NormalbaeImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution'
>;
/**
* The Openpose processor node, with parameters flagged as required
*/
export type RequiredOpenposeImageProcessorInvocation = O.Required<
OpenposeImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'hand_and_face'
>;
/**
* The Pidi processor node, with parameters flagged as required
*/
export type RequiredPidiImageProcessorInvocation = O.Required<
PidiImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'safe' | 'scribble'
>;
/**
* The ZoeDepth processor node, with parameters flagged as required
*/
export type RequiredZoeDepthImageProcessorInvocation = O.Required<
ZoeDepthImageProcessorInvocation,
'type'
>;
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
export type RequiredControlNetProcessorNode =
| RequiredCannyImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
| RequiredLineartAnimeImageProcessorInvocation
| RequiredLineartImageProcessorInvocation
| RequiredMediapipeFaceProcessorInvocation
| RequiredMidasDepthImageProcessorInvocation
| RequiredMlsdImageProcessorInvocation
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation;
/**
* Type guard for CannyImageProcessorInvocation
*/
export const isCannyImageProcessorInvocation = (
obj: unknown
): obj is CannyImageProcessorInvocation => {
if (isObject(obj) && 'type' in obj && obj.type === 'canny_image_processor') {
return true;
}
return false;
};
/**
* Type guard for ContentShuffleImageProcessorInvocation
*/
export const isContentShuffleImageProcessorInvocation = (
obj: unknown
): obj is ContentShuffleImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'content_shuffle_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for HedImageprocessorInvocation
*/
export const isHedImageprocessorInvocation = (
obj: unknown
): obj is HedImageProcessorInvocation => {
if (isObject(obj) && 'type' in obj && obj.type === 'hed_image_processor') {
return true;
}
return false;
};
/**
* Type guard for LineartAnimeImageProcessorInvocation
*/
export const isLineartAnimeImageProcessorInvocation = (
obj: unknown
): obj is LineartAnimeImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'lineart_anime_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for LineartImageProcessorInvocation
*/
export const isLineartImageProcessorInvocation = (
obj: unknown
): obj is LineartImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'lineart_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for MediapipeFaceProcessorInvocation
*/
export const isMediapipeFaceProcessorInvocation = (
obj: unknown
): obj is MediapipeFaceProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'mediapipe_face_processor'
) {
return true;
}
return false;
};
/**
* Type guard for MidasDepthImageProcessorInvocation
*/
export const isMidasDepthImageProcessorInvocation = (
obj: unknown
): obj is MidasDepthImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'midas_depth_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for MlsdImageProcessorInvocation
*/
export const isMlsdImageProcessorInvocation = (
obj: unknown
): obj is MlsdImageProcessorInvocation => {
if (isObject(obj) && 'type' in obj && obj.type === 'mlsd_image_processor') {
return true;
}
return false;
};
/**
* Type guard for NormalbaeImageProcessorInvocation
*/
export const isNormalbaeImageProcessorInvocation = (
obj: unknown
): obj is NormalbaeImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'normalbae_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for OpenposeImageProcessorInvocation
*/
export const isOpenposeImageProcessorInvocation = (
obj: unknown
): obj is OpenposeImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'openpose_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for PidiImageProcessorInvocation
*/
export const isPidiImageProcessorInvocation = (
obj: unknown
): obj is PidiImageProcessorInvocation => {
if (isObject(obj) && 'type' in obj && obj.type === 'pidi_image_processor') {
return true;
}
return false;
};
/**
* Type guard for ZoeDepthImageProcessorInvocation
*/
export const isZoeDepthImageProcessorInvocation = (
obj: unknown
): obj is ZoeDepthImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'zoe_depth_image_processor'
) {
return true;
}
return false;
};

View File

@ -12,9 +12,7 @@ import {
TextToLatentsInvocation, TextToLatentsInvocation,
} from 'services/api'; } from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, map, size } from 'lodash-es'; import { forEach, size } from 'lodash-es';
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
import { ControlNetModel } from 'features/controlNet/store/controlNetSlice';
const POSITIVE_CONDITIONING = 'positive_conditioning'; const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning'; const NEGATIVE_CONDITIONING = 'negative_conditioning';
@ -344,7 +342,7 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
beginStepPct, beginStepPct,
endStepPct, endStepPct,
model, model,
processor, processorNode,
weight, weight,
} = controlNet; } = controlNet;