feat(ui): updated controlnet logic/ui

This commit is contained in:
psychedelicious 2023-06-03 22:48:16 +10:00
parent 2270c270ef
commit 03f3ad435a
9 changed files with 73 additions and 94 deletions

View File

@ -6,7 +6,6 @@ import {
controlNetImageChanged, controlNetImageChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
@ -16,25 +15,22 @@ const predicate = (action: AnyAction, state: RootState) => {
const isActionMatched = const isActionMatched =
controlNetProcessorParamsChanged.match(action) || controlNetProcessorParamsChanged.match(action) ||
controlNetImageChanged.match(action) || controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action) || controlNetProcessorTypeChanged.match(action);
isControlNetImagePreprocessedToggled.match(action);
if (!isActionMatched) { if (!isActionMatched) {
return false; return false;
} }
const { controlNetId } = action.payload; const { controlImage, processorType } =
state.controlNet.controlNets[action.payload.controlNetId];
const shouldAutoProcess = const isProcessorSelected = processorType !== 'none';
!state.controlNet.controlNets[controlNetId].isPreprocessed;
const isBusy = state.system.isProcessing; const isBusy = state.system.isProcessing;
const hasControlImage = Boolean( const hasControlImage = Boolean(controlImage);
state.controlNet.controlNets[controlNetId].controlImage
);
return shouldAutoProcess && !isBusy && hasControlImage; return isProcessorSelected && !isBusy && hasControlImage;
}; };
/** /**

View File

@ -4,7 +4,6 @@ import {
controlNetAdded, controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled, controlNetToggled,
isControlNetImagePreprocessedToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetModel from './parameters/ParamControlNetModel';
@ -22,7 +21,7 @@ import {
TabPanel, TabPanel,
Box, Box,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa'; import { FaCopy, FaPlus, FaTrash, FaWrench } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview'; import ControlNetImagePreview from './ControlNetImagePreview';
@ -34,6 +33,7 @@ import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ControlNetPreprocessButton from './ControlNetPreprocessButton'; import ControlNetPreprocessButton from './ControlNetPreprocessButton';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { ChevronDownIcon, ChevronUpIcon } from '@chakra-ui/icons';
type ControlNetProps = { type ControlNetProps = {
controlNet: ControlNetConfig; controlNet: ControlNetConfig;
@ -48,12 +48,12 @@ const ControlNet = (props: ControlNetProps) => {
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlImage, controlImage,
isPreprocessed,
processedControlImage, processedControlImage,
processorNode, processorNode,
processorType,
} = props.controlNet; } = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [shouldShowAdvanced, onToggleAdvanced] = useToggle(true); const [shouldShowAdvanced, onToggleAdvanced] = useToggle(false);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId })); dispatch(controlNetRemoved({ controlNetId }));
@ -69,23 +69,20 @@ const ControlNet = (props: ControlNetProps) => {
dispatch(controlNetToggled({ controlNetId })); dispatch(controlNetToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleToggleIsPreprocessed = useCallback(() => {
dispatch(isControlNetImagePreprocessedToggled({ controlNetId }));
}, [controlNetId, dispatch]);
return ( return (
<Flex <Flex
sx={{ sx={{
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
p: 2, p: 3,
bg: 'base.850', bg: 'base.850',
borderRadius: 'base', borderRadius: 'base',
}} }}
> >
<Flex sx={{ gap: 2 }}> <Flex sx={{ gap: 2 }}>
<IAISwitch <IAISwitch
aria-label="Toggle ControlNet" tooltip="Toggle"
aria-label="Toggle"
isChecked={isEnabled} isChecked={isEnabled}
onChange={handleToggleIsEnabled} onChange={handleToggleIsEnabled}
/> />
@ -103,19 +100,38 @@ const ControlNet = (props: ControlNetProps) => {
</Box> </Box>
<IAIIconButton <IAIIconButton
size="sm" size="sm"
tooltip="Duplicate ControlNet" tooltip="Duplicate"
aria-label="Duplicate ControlNet" aria-label="Duplicate"
onClick={handleDuplicate} onClick={handleDuplicate}
icon={<FaCopy />} icon={<FaCopy />}
/> />
<IAIIconButton <IAIIconButton
size="sm" size="sm"
tooltip="Delete ControlNet" tooltip="Delete"
aria-label="Delete ControlNet" aria-label="Delete"
colorScheme="error" colorScheme="error"
onClick={handleDelete} onClick={handleDelete}
icon={<FaTrash />} icon={<FaTrash />}
/> />
<IAIIconButton
size="sm"
aria-label="Expand"
onClick={onToggleAdvanced}
variant="link"
icon={
<ChevronUpIcon
sx={{
boxSize: 4,
color: 'base.300',
transform: shouldShowAdvanced
? 'rotate(0deg)'
: 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
}
/>
</Flex> </Flex>
{isEnabled && ( {isEnabled && (
<> <>
@ -125,38 +141,13 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
w: 'full', w: 'full',
h: 32, h: 24,
paddingInlineStart: 2, paddingInlineStart: 1,
paddingInlineEnd: shouldShowAdvanced ? 2 : 0, paddingInlineEnd: shouldShowAdvanced ? 1 : 0,
pb: 2, pb: 2,
justifyContent: 'space-between', justifyContent: 'space-between',
}} }}
> >
<Flex
sx={{
justifyContent: 'space-between',
w: 'full',
}}
>
<FormControl>
<HStack>
<Checkbox
isChecked={isPreprocessed}
onChange={handleToggleIsPreprocessed}
/>
<FormLabel>Preprocessed</FormLabel>
</HStack>
</FormControl>
<FormControl>
<HStack>
<Checkbox
isChecked={shouldShowAdvanced}
onChange={onToggleAdvanced}
/>
<FormLabel>Advanced</FormLabel>
</HStack>
</FormControl>
</Flex>
<ParamControlNetWeight <ParamControlNetWeight
controlNetId={controlNetId} controlNetId={controlNetId}
weight={weight} weight={weight}
@ -174,8 +165,8 @@ const ControlNet = (props: ControlNetProps) => {
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
h: 32, h: 24,
w: 32, w: 24,
aspectRatio: '1/1', aspectRatio: '1/1',
}} }}
> >
@ -188,8 +179,6 @@ const ControlNet = (props: ControlNetProps) => {
<Box pt={2}> <Box pt={2}>
<ControlNetImagePreview controlNet={props.controlNet} /> <ControlNetImagePreview controlNet={props.controlNet} />
</Box> </Box>
{!isPreprocessed && (
<>
<ParamControlNetProcessorSelect <ParamControlNetProcessorSelect
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
@ -202,8 +191,6 @@ const ControlNet = (props: ControlNetProps) => {
)} )}
</> </>
)} )}
</>
)}
</Flex> </Flex>
); );

View File

@ -28,12 +28,8 @@ type Props = {
}; };
const ControlNetImagePreview = (props: Props) => { const ControlNetImagePreview = (props: Props) => {
const { const { controlNetId, controlImage, processedControlImage, processorType } =
controlNetId, props.controlNet;
controlImage,
processedControlImage,
isPreprocessed: isControlImageProcessed,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isProcessingControlImage } = useAppSelector(selector); const { isProcessingControlImage } = useAppSelector(selector);
const containerRef = useRef<HTMLDivElement>(null); const containerRef = useRef<HTMLDivElement>(null);
@ -56,7 +52,7 @@ const ControlNetImagePreview = (props: Props) => {
processedControlImage && processedControlImage &&
!isMouseOverImage && !isMouseOverImage &&
!isProcessingControlImage && !isProcessingControlImage &&
!isControlImageProcessed; processorType !== 'none';
return ( return (
<Box ref={containerRef} sx={{ position: 'relative', w: 'full', h: 'full' }}> <Box ref={containerRef} sx={{ position: 'relative', w: 'full', h: 'full' }}>
@ -64,7 +60,7 @@ const ControlNetImagePreview = (props: Props) => {
image={controlImage} image={controlImage}
onDrop={handleControlImageChanged} onDrop={handleControlImageChanged}
isDropDisabled={Boolean( isDropDisabled={Boolean(
processedControlImage && !isControlImageProcessed processedControlImage && processorType !== 'none'
)} )}
/> />
<AnimatePresence> <AnimatePresence>

View File

@ -27,9 +27,12 @@ const ParamIsControlNetModel = (props: ParamIsControlNetModelProps) => {
return ( return (
<IAICustomSelect <IAICustomSelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
items={CONTROLNET_MODELS} items={CONTROLNET_MODELS}
selectedItem={model} selectedItem={model}
setSelectedItem={handleModelChanged} setSelectedItem={handleModelChanged}
ellipsisPosition="start"
withCheckIcon withCheckIcon
/> />
); );

View File

@ -4,7 +4,5 @@ import { PropsWithChildren } from 'react';
type Props = PropsWithChildren; type Props = PropsWithChildren;
export default function ProcessorWrapper(props: Props) { export default function ProcessorWrapper(props: Props) {
return ( return <Flex sx={{ flexDirection: 'column', gap: 2 }}>{props.children}</Flex>;
<Flex sx={{ flexDirection: 'column', gap: 2, p: 2 }}>{props.children}</Flex>
);
} }

View File

@ -24,6 +24,14 @@ type ControlNetProcessorsDict = Record<
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CONTROLNET_PROCESSORS = { export const CONTROLNET_PROCESSORS = {
none: {
type: 'none',
label: 'None',
description: '',
default: {
type: 'none',
},
},
canny_image_processor: { canny_image_processor: {
type: 'canny_image_processor', type: 'canny_image_processor',
label: 'Canny', label: 'Canny',

View File

@ -21,8 +21,8 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
controlImage: null, controlImage: null,
isPreprocessed: false,
processedControlImage: null, processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation, .default as RequiredCannyImageProcessorInvocation,
}; };
@ -35,8 +35,8 @@ export type ControlNetConfig = {
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlImage: ImageDTO | null; controlImage: ImageDTO | null;
isPreprocessed: boolean;
processedControlImage: ImageDTO | null; processedControlImage: ImageDTO | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode; processorNode: RequiredControlNetProcessorNode;
}; };
@ -110,19 +110,11 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if ( if (
controlImage !== null && controlImage !== null &&
!state.controlNets[controlNetId].isPreprocessed state.controlNets[controlNetId].processorType !== 'none'
) { ) {
state.isProcessingControlImage = true; state.isProcessingControlImage = true;
} }
}, },
isControlNetImagePreprocessedToggled: (
state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
state.controlNets[controlNetId].isPreprocessed =
!state.controlNets[controlNetId].isPreprocessed;
},
controlNetProcessedImageChanged: ( controlNetProcessedImageChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -188,6 +180,7 @@ export const controlNetSlice = createSlice({
}> }>
) => { ) => {
const { controlNetId, processorType } = action.payload; const { controlNetId, processorType } = action.payload;
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType processorType
].default as RequiredControlNetProcessorNode; ].default as RequiredControlNetProcessorNode;
@ -210,7 +203,6 @@ export const {
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,
controlNetImageChanged, controlNetImageChanged,
isControlNetImagePreprocessedToggled,
controlNetProcessedImageChanged, controlNetProcessedImageChanged,
controlNetToggled, controlNetToggled,
controlNetModelChanged, controlNetModelChanged,

View File

@ -36,7 +36,7 @@ export type ControlNetProcessorNode =
* Any ControlNet processor type * Any ControlNet processor type
*/ */
export type ControlNetProcessorType = NonNullable< export type ControlNetProcessorType = NonNullable<
ControlNetProcessorNode['type'] ControlNetProcessorNode['type'] | 'none'
>; >;
/** /**

View File

@ -33,13 +33,12 @@ export const addControlNetToLinearGraph = (
const { const {
controlNetId, controlNetId,
isEnabled, isEnabled,
isPreprocessed: isControlImageProcessed,
controlImage, controlImage,
processedControlImage, processedControlImage,
beginStepPct, beginStepPct,
endStepPct, endStepPct,
model, model,
processorNode, processorType,
weight, weight,
} = controlNet; } = controlNet;
@ -57,14 +56,14 @@ export const addControlNetToLinearGraph = (
control_weight: weight, control_weight: weight,
}; };
if (processedControlImage && !isControlImageProcessed) { if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image // We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage; const { image_name, image_origin } = processedControlImage;
controlNetNode.image = { controlNetNode.image = {
image_name, image_name,
image_origin, image_origin,
}; };
} else if (controlImage && isControlImageProcessed) { } else if (controlImage && processorType !== 'none') {
// The control image is preprocessed // The control image is preprocessed
const { image_name, image_origin } = controlImage; const { image_name, image_origin } = controlImage;
controlNetNode.image = { controlNetNode.image = {