Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/controlnet-control-modes

Only "real" conflicts were in:
     invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx
     invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
This commit is contained in:
user1
2023-06-24 17:05:57 -07:00
341 changed files with 16419 additions and 11561 deletions

View File

@ -189,7 +189,7 @@ const ControlNet = (props: ControlNetProps) => {
<Box mt={2}>
<ControlNetImagePreview
controlNet={props.controlNet}
imageSx={expandedControlImageSx}
height={96}
/>
</Box>
<ParamControlNetProcessorSelect

View File

@ -1,19 +1,20 @@
import { memo, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api';
import { ImageDTO } from 'services/api/types';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box, ChakraProps, Flex } from '@chakra-ui/react';
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnimatePresence, motion } from 'framer-motion';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
controlNetSelector,
@ -26,29 +27,50 @@ const selector = createSelector(
type Props = {
controlNet: ControlNetConfig;
imageSx?: ChakraProps['sx'];
height: SystemStyleObject['h'];
};
const ControlNetImagePreview = (props: Props) => {
const { imageSx } = props;
const { controlNetId, controlImage, processedControlImage, processorType } =
props.controlNet;
const { height } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const {
currentData: controlImage,
isLoading: isLoadingControlImage,
isError: isErrorControlImage,
isSuccess: isSuccessControlImage,
} = useGetImageDTOQuery(controlImageName ?? skipToken);
const {
currentData: processedControlImage,
isLoading: isLoadingProcessedControlImage,
isError: isErrorProcessedControlImage,
isSuccess: isSuccessProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (controlImage?.image_name === droppedImage.image_name) {
if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
controlNetImageChanged({
controlNetId,
controlImage: droppedImage.image_name,
})
);
},
[controlImage, controlNetId, dispatch]
[controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => {
@ -62,10 +84,6 @@ const ControlNetImagePreview = (props: Props) => {
setIsMouseOverImage(false);
}, []);
const shouldShowProcessedImageBackdrop =
Number(controlImage?.width) > Number(processedControlImage?.width) ||
Number(controlImage?.height) > Number(processedControlImage?.height);
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
@ -74,72 +92,51 @@ const ControlNetImagePreview = (props: Props) => {
processorType !== 'none';
return (
<Box
<Flex
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
sx={{ position: 'relative', w: 'full', h: 'full' }}
sx={{
position: 'relative',
w: 'full',
h: height,
alignItems: 'center',
justifyContent: 'center',
}}
>
<IAIDndImage
image={controlImage}
onDrop={handleDrop}
isDropDisabled={Boolean(
processedControlImage && processorType !== 'none'
)}
isUploadDisabled={Boolean(controlImage)}
isDropDisabled={shouldShowProcessedImage}
postUploadAction={{ type: 'SET_CONTROLNET_IMAGE', controlNetId }}
imageSx={imageSx}
imageSx={{
w: 'full',
h: 'full',
}}
/>
<AnimatePresence>
{shouldShowProcessedImage && (
<motion.div
style={{ width: '100%' }}
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
>
<>
{shouldShowProcessedImageBackdrop && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
w: 'full',
h: 'full',
bg: 'base.900',
opacity: 0.7,
}}
/>
)}
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
w: 'full',
h: 'full',
}}
>
<IAIDndImage
image={processedControlImage}
onDrop={handleDrop}
payloadImage={controlImage}
isUploadDisabled={true}
imageSx={imageSx}
/>
</Box>
</>
</motion.div>
)}
</AnimatePresence>
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
w: 'full',
h: 'full',
opacity: shouldShowProcessedImage ? 1 : 0,
transitionProperty: 'common',
transitionDuration: 'normal',
pointerEvents: 'none',
}}
>
<IAIDndImage
image={processedControlImage}
onDrop={handleDrop}
payloadImage={controlImage}
isUploadDisabled={true}
imageSx={{
w: 'full',
h: 'full',
}}
/>
</Box>
{pendingControlImages.includes(controlNetId) && (
<Box
sx={{
@ -150,7 +147,7 @@ const ControlNetImagePreview = (props: Props) => {
h: 'full',
}}
>
<IAIImageFallback />
<IAIImageLoadingFallback />
</Box>
)}
{controlImage && (
@ -169,7 +166,7 @@ const ControlNetImagePreview = (props: Props) => {
/>
</Flex>
)}
</Box>
</Flex>
);
};

View File

@ -1,21 +1,22 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { forEach } from 'lodash-es';
import { ImageDTO } from 'services/api';
import { appSocketInvocationError } from 'services/events/actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { isAnySessionRejected } from 'services/thunks/session';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { ImageDTO } from 'services/api/types';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
@ -23,21 +24,6 @@ export type ControlModes =
| 'more_control'
| 'unbalanced';
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelName;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
};
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
@ -53,6 +39,21 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
shouldAutoConfig: true,
};
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelName;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
@ -87,7 +88,7 @@ export const controlNetSlice = createSlice({
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
@ -115,7 +116,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
controlImage: ImageDTO | null;
controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
@ -132,7 +133,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: ImageDTO | null;
processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
@ -154,13 +155,11 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType =
CONTROLNET_MODELS[model as keyof typeof CONTROLNET_MODELS]
.defaultProcessor;
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType as keyof typeof CONTROLNET_PROCESSORS
processorType
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
@ -226,7 +225,7 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType as keyof typeof CONTROLNET_PROCESSORS
processorType
].default as RequiredControlNetProcessorNode;
state.controlNets[controlNetId].shouldAutoConfig = false;
},
@ -243,14 +242,12 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODELS[
state.controlNets[controlNetId]
.model as keyof typeof CONTROLNET_MODELS
].defaultProcessor;
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType as keyof typeof CONTROLNET_PROCESSORS
processorType
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
@ -276,38 +273,38 @@ export const controlNetSlice = createSlice({
builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery
const { imageName } = action.meta.arg;
const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === imageName) {
if (c.controlImage === image_name) {
c.controlImage = null;
c.processedControlImage = null;
}
if (c.processedControlImage?.image_name === imageName) {
if (c.processedControlImage === image_name) {
c.processedControlImage = null;
}
});
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === image_name) {
c.controlImage.image_url = image_url;
c.controlImage.thumbnail_url = thumbnail_url;
}
if (c.processedControlImage?.image_name === image_name) {
c.processedControlImage.image_url = image_url;
c.processedControlImage.thumbnail_url = thumbnail_url;
}
});
});
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state) => {
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
});
builder.addMatcher(isAnySessionRejected, (state) => {
builder.addMatcher(isAnySessionRejected, (state, action) => {
state.pendingControlImages = [];
});
},

View File

@ -12,7 +12,7 @@ import {
OpenposeImageProcessorInvocation,
PidiImageProcessorInvocation,
ZoeDepthImageProcessorInvocation,
} from 'services/api';
} from 'services/api/types';
import { O } from 'ts-toolbelt';
/**