feat(ui): wip controlnet ui

This commit is contained in:
psychedelicious 2023-06-01 14:17:32 +10:00
parent d6a959b000
commit e2e07696fc
16 changed files with 579 additions and 15 deletions

View File

@ -70,6 +70,7 @@ import {
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
export const listenerMiddleware = createListenerMiddleware();
@ -173,3 +174,6 @@ addReceivedPageOfImagesRejectedListener();
// Gallery
addImageCategoriesChangedListener();
// ControlNet
addControlNetImageProcessedListener();

View File

@ -0,0 +1,60 @@
import { startAppListening } from '..';
import { imageMetadataReceived, imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { Graph } from 'services/api';
import { sessionCreated } from 'services/thunks/session';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { appSocketInvocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
import { selectImagesById } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'controlNet' });
export const addControlNetImageProcessedListener = () => {
startAppListening({
actionCreator: controlNetImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const { controlNetId, processorNode } = action.payload;
const { id } = processorNode;
const graph: Graph = {
nodes: { [id]: processorNode },
};
const sessionCreatedAction = dispatch(sessionCreated({ graph }));
const [sessionCreatedFulfilledAction] = await take(
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
sessionCreated.fulfilled.match(action) &&
action.meta.requestId === sessionCreatedAction.requestId
);
const sessionId = sessionCreatedFulfilledAction.payload.id;
dispatch(sessionReadyToInvoke());
const [processorAction] = await take(
(action): action is ReturnType<typeof appSocketInvocationComplete> =>
appSocketInvocationComplete.match(action) &&
action.payload.data.graph_execution_state_id === sessionId
);
if (isImageOutput(processorAction.payload.data.result)) {
const { image_name } = processorAction.payload.data.result.image;
const [imageMetadataReceivedAction] = await take(
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
const processedControlImage = imageMetadataReceivedAction.payload;
dispatch(
controlNetProcessedImageChanged({
controlNetId,
processedControlImage,
})
);
}
},
});
};

View File

@ -49,7 +49,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
/>
)}
</Flex>
<Collapse in={isOpen} animateOpacity>
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
{children}
</Box>

View File

@ -1,5 +1,5 @@
import { Badge, Flex } from '@chakra-ui/react';
import { isNumber, isString } from 'lodash-es';
import { isString } from 'lodash-es';
import { useMemo } from 'react';
import { ImageDTO } from 'services/api';
@ -8,14 +8,6 @@ type ImageMetadataOverlayProps = {
};
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
const dimensions = useMemo(() => {
if (!isNumber(image.metadata?.width) || isNumber(!image.metadata?.height)) {
return;
}
return `${image.metadata?.width} × ${image.metadata?.height}`;
}, [image.metadata]);
const model = useMemo(() => {
if (!isString(image.metadata?.model)) {
return;
@ -37,11 +29,9 @@ const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
gap: 2,
}}
>
{dimensions && (
<Badge variant="solid" colorScheme="base">
{dimensions}
</Badge>
)}
<Badge variant="solid" colorScheme="base">
{image.width} × {image.height}
</Badge>
{model && (
<Badge variant="solid" colorScheme="base">
{model}

View File

@ -0,0 +1,27 @@
import { memo } from 'react';
import { ControlNetProcessorNode } from '../store/types';
import { ImageDTO } from 'services/api';
import CannyProcessor from './processors/CannyProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
image: ImageDTO;
type: ControlNetProcessorNode['type'];
};
const renderProcessorComponent = (props: ControlNetProcessorProps) => {
const { type } = props;
if (type === 'canny_image_processor') {
return <CannyProcessor {...props} />;
}
};
const ControlNet = () => {
return (
<div>
<h1>ControlNet</h1>
</div>
);
};
export default memo(ControlNet);

View File

@ -0,0 +1,64 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback, useState } from 'react';
import ControlNetProcessButton from './common/ControlNetProcessButton';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { ImageDTO } from 'services/api';
import ControlNetProcessorImage from './common/ControlNetProcessorImage';
import { ControlNetProcessorProps } from '../ControlNet';
export const CANNY_PROCESSOR = 'canny_processor';
const CannyProcessor = (props: ControlNetProcessorProps) => {
const { controlNetId, image, type } = props;
const dispatch = useAppDispatch();
const [lowThreshold, setLowThreshold] = useState(100);
const [highThreshold, setHighThreshold] = useState(200);
const handleProcess = useCallback(() => {
if (!image) {
return;
}
dispatch(
controlNetImageProcessed({
controlNetId,
processorNode: {
id: CANNY_PROCESSOR,
type: 'canny_image_processor',
image: {
image_name: image.image_name,
image_origin: image.image_origin,
},
low_threshold: lowThreshold,
high_threshold: highThreshold,
},
})
);
}, [controlNetId, dispatch, highThreshold, image, lowThreshold]);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Low Threshold"
value={lowThreshold}
onChange={setLowThreshold}
min={0}
max={255}
withInput
/>
<IAISlider
label="High Threshold"
value={highThreshold}
onChange={setHighThreshold}
min={0}
max={255}
withInput
/>
<ControlNetProcessButton onClick={handleProcess} />
</Flex>
);
};
export default memo(CannyProcessor);

View File

@ -0,0 +1,42 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, memo, useState } from 'react';
const HedPreprocessor = () => {
const [detectResolution, setDetectResolution] = useState(512);
const [imageResolution, setImageResolution] = useState(512);
const [isScribbleEnabled, setIsScribbleEnabled] = useState(false);
const handleChangeScribble = (e: ChangeEvent<HTMLInputElement>) => {
setIsScribbleEnabled(e.target.checked);
};
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detectResolution}
onChange={setDetectResolution}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={imageResolution}
onChange={setImageResolution}
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Scribble"
isChecked={isScribbleEnabled}
onChange={handleChangeScribble}
/>
</Flex>
);
};
export default memo(HedPreprocessor);

View File

@ -0,0 +1,31 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import { memo, useState } from 'react';
const LineartPreprocessor = () => {
const [detectResolution, setDetectResolution] = useState(512);
const [imageResolution, setImageResolution] = useState(512);
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detectResolution}
onChange={setDetectResolution}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={imageResolution}
onChange={setImageResolution}
min={0}
max={4096}
withInput
/>
</Flex>
);
};
export default memo(LineartPreprocessor);

View File

@ -0,0 +1,42 @@
import { Flex } from '@chakra-ui/react';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, memo, useState } from 'react';
const LineartPreprocessor = () => {
const [detectResolution, setDetectResolution] = useState(512);
const [imageResolution, setImageResolution] = useState(512);
const [isCoarseEnabled, setIsCoarseEnabled] = useState(false);
const handleChangeScribble = (e: ChangeEvent<HTMLInputElement>) => {
setIsCoarseEnabled(e.target.checked);
};
return (
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
<IAISlider
label="Detect Resolution"
value={detectResolution}
onChange={setDetectResolution}
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={imageResolution}
onChange={setImageResolution}
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Coarse"
isChecked={isCoarseEnabled}
onChange={handleChangeScribble}
/>
</Flex>
);
};
export default memo(LineartPreprocessor);

View File

@ -0,0 +1,13 @@
import IAIButton from 'common/components/IAIButton';
import { memo } from 'react';
type ControlNetProcessButtonProps = {
onClick: () => void;
};
const ControlNetProcessButton = (props: ControlNetProcessButtonProps) => {
const { onClick } = props;
return <IAIButton onClick={onClick}>Process Control Image</IAIButton>;
};
export default memo(ControlNetProcessButton);

View File

@ -0,0 +1,33 @@
import { Flex, Image } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectImagesById } from 'features/gallery/store/imagesSlice';
import { DragEvent, memo, useCallback } from 'react';
import { ImageDTO } from 'services/api';
type ControlNetProcessorImageProps = {
image: ImageDTO | undefined;
setImage: (image: ImageDTO) => void;
};
const ControlNetProcessorImage = (props: ControlNetProcessorImageProps) => {
const { image, setImage } = props;
const state = useAppSelector((state) => state);
const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => {
const name = e.dataTransfer.getData('invokeai/imageName');
const droppedImage = selectImagesById(state, name);
if (droppedImage) {
setImage(droppedImage);
}
},
[setImage, state]
);
if (!image) {
return <Flex onDrop={handleDrop}>Upload Image</Flex>;
}
return <Image src={image.image_url} />;
};
export default memo(ControlNetProcessorImage);

View File

@ -0,0 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetProcessorNode } from './types';
export const controlNetImageProcessed = createAction<{
controlNetId: string;
processorNode: ControlNetProcessorNode;
}>('controlNet/imageProcessed');

View File

@ -0,0 +1,159 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api';
export const CONTROLNET_MODELS = [
'lllyasviel/sd-controlnet-canny',
'lllyasviel/sd-controlnet-depth',
'lllyasviel/sd-controlnet-hed',
'lllyasviel/sd-controlnet-seg',
'lllyasviel/sd-controlnet-openpose',
'lllyasviel/sd-controlnet-scribble',
'lllyasviel/sd-controlnet-normal',
'lllyasviel/sd-controlnet-mlsd',
] as const;
export const CONTROLNET_PROCESSORS = [
'canny',
'contentShuffle',
'hed',
'lineart',
'lineartAnime',
'mediapipeFace',
'midasDepth',
'mlsd',
'normalBae',
'openpose',
'pidi',
'zoeDepth',
] as const;
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS[0],
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlImage: null,
processedControlImage: null,
};
export type ControlNet = {
controlNetId: string;
isEnabled: boolean;
model: string;
weight: number;
beginStepPct: number;
endStepPct: number;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
};
export type ControlNetState = {
controlNets: Record<string, ControlNet>;
};
export const initialControlNetState: ControlNetState = {
controlNets: {},
};
export const controlNetSlice = createSlice({
name: 'controlNet',
initialState: initialControlNetState,
reducers: {
controlNetAddedFromModel: (
state,
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }>
) => {
const { controlNetId, model } = action.payload;
state.controlNets[controlNetId] = {
...initialControlNet,
controlNetId,
model,
};
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
...initialControlNet,
controlNetId,
controlImage,
};
},
controlNetRemoved: (state, action: PayloadAction<string>) => {
const controlNetId = action.payload;
delete state.controlNets[controlNetId];
},
controlNetToggled: (state, action: PayloadAction<string>) => {
const controlNetId = action.payload;
state.controlNets[controlNetId].isEnabled =
!state.controlNets[controlNetId].isEnabled;
},
controlNetImageChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId].controlImage = controlImage;
},
controlNetProcessedImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: ImageDTO | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
state.controlNets[controlNetId].processedControlImage =
processedControlImage;
},
controlNetModelChanged: (
state,
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }>
) => {
const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model;
},
controlNetWeightChanged: (
state,
action: PayloadAction<{ controlNetId: string; weight: number }>
) => {
const { controlNetId, weight } = action.payload;
state.controlNets[controlNetId].weight = weight;
},
controlNetBeginStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
) => {
const { controlNetId, beginStepPct } = action.payload;
state.controlNets[controlNetId].beginStepPct = beginStepPct;
},
controlNetEndStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
) => {
const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct;
},
},
});
export const {
controlNetAddedFromModel,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,
controlNetProcessedImageChanged,
controlNetToggled,
controlNetModelChanged,
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;

View File

@ -0,0 +1,28 @@
import {
CannyImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
HedImageprocessorInvocation,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation,
MidasDepthImageProcessorInvocation,
MlsdImageProcessorInvocation,
NormalbaeImageProcessorInvocation,
OpenposeImageProcessorInvocation,
PidiImageProcessorInvocation,
ZoeDepthImageProcessorInvocation,
} from 'services/api';
export type ControlNetProcessorNode =
| CannyImageProcessorInvocation
| HedImageprocessorInvocation
| LineartImageProcessorInvocation
| LineartAnimeImageProcessorInvocation
| OpenposeImageProcessorInvocation
| MidasDepthImageProcessorInvocation
| NormalbaeImageProcessorInvocation
| MlsdImageProcessorInvocation
| PidiImageProcessorInvocation
| ContentShuffleImageProcessorInvocation
| ZoeDepthImageProcessorInvocation
| MediapipeFaceProcessorInvocation;

View File

@ -0,0 +1,62 @@
import { Flex, Text, useDisclosure } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse';
import { memo, useCallback, useState } from 'react';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaPlus } from 'react-icons/fa';
import CannyProcessor from 'features/controlNet/components/processors/CannyProcessor';
import ControlNet from 'features/controlNet/components/ControlNet';
const CONTROLNET_MODELS = [
'lllyasviel/sd-controlnet-canny',
'lllyasviel/sd-controlnet-depth',
'lllyasviel/sd-controlnet-hed',
'lllyasviel/sd-controlnet-seg',
'lllyasviel/sd-controlnet-openpose',
'lllyasviel/sd-controlnet-scribble',
'lllyasviel/sd-controlnet-normal',
'lllyasviel/sd-controlnet-mlsd',
];
const ParamControlNetCollapse = () => {
const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
const [model, setModel] = useState<string>(CONTROLNET_MODELS[0]);
const handleSetControlNet = useCallback(
(model: string | null | undefined) => {
if (model) {
setModel(model);
}
},
[]
);
return (
<ControlNet />
// <IAICollapse
// label={'ControlNet'}
// // label={t('parameters.seamCorrectionHeader')}
// isOpen={isOpen}
// onToggle={onToggle}
// >
// <Flex sx={{ alignItems: 'flex-end' }}>
// <IAICustomSelect
// label="ControlNet Model"
// items={CONTROLNET_MODELS}
// selectedItem={model}
// setSelectedItem={handleSetControlNet}
// />
// <IAIIconButton
// size="sm"
// aria-label="Add ControlNet"
// icon={<FaPlus />}
// />
// </Flex>
// <CannyProcessor />
// </IAICollapse>
);
};
export default memo(ParamControlNetCollapse);

View File

@ -9,6 +9,7 @@ import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Sym
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
const TextToImageTabParameters = () => {
return (
@ -18,6 +19,7 @@ const TextToImageTabParameters = () => {
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSeedCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />