From 8d3bec57d52cadf1c4037d788d2a1bdb2e13d982 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 21 Jun 2023 13:48:59 +1000
Subject: [PATCH] feat(ui): store only image name in parameters
Images that are used as parameters (e.g. init image, canvas images) are stored as full `ImageDTO` objects in state, separate from and duplicating any object representing those same objects in the `imagesSlice`.
We cannot store only image names as parameters, then pull the full `ImageDTO` from `imagesSlice`, because if an image is not on a loaded page, it doesn't exist in `imagesSlice`. For example, if you scroll down a few pages in the gallery and send that image to canvas, on reloading the app, the canvas will be unable to load that image.
We solved this temporarily by storing the full `ImageDTO` object wherever it was needed, but this is both inefficient and allows for stale `ImageDTO`s across the app.
One other possible solution was to just fetch the `ImageDTO` for all images at startup, and insert them into the `imagesSlice`, but then we run into an issue where we are displaying images in the gallery totally out of context.
For example, if an image from several pages into the gallery was sent to canvas, and the user refreshes, we'd display the first 20 images in gallery. Then to populate the canvas, we'd fetch that image we sent to canvas and add it to `imagesSlice`. Now we'd have 21 images in the gallery: 1 to 20 and whichever image we sent to canvas. Weird.
Using `rtk-query` solves this by allowing us to very easily fetch individual images in the components that need them, and not directly interact with `imagesSlice`.
This commit changes all references to images-as-parameters to store only the name of the image, and not the full `ImageDTO` object. Then, we use an `rtk-query` generated `useGetImageDTOQuery()` hook in each of those components to fetch the image.
We can use cache invalidation when we mutate any image to trigger automated re-running of the query and all the images are automatically kept up to date.
This also obviates the need for the convoluted URL fetching scheme for images that are used as parameters. The `imagesSlice` still need this handling unfortunately.
---
.../src/app/contexts/DeleteImageContext.tsx | 10 +-
.../listeners/controlNetImageProcessed.ts | 4 +-
.../socketio/socketInvocationComplete.ts | 9 +-
.../listeners/updateImageUrlsOnConnect.ts | 14 +-
.../canvas/components/IAICanvasImage.tsx | 19 +-
.../components/IAICanvasObjectRenderer.tsx | 9 +-
.../components/IAICanvasStagingArea.tsx | 6 +-
.../src/features/canvas/store/canvasSlice.ts | 38 +-
.../src/features/canvas/store/canvasTypes.ts | 2 +-
.../components/ControlNetImagePreview.tsx | 33 +-
.../controlNet/store/controlNetSlice.ts | 40 +-
.../gallery/components/Boards/BoardsList.tsx | 2 +-
.../components/CurrentImagePreview.tsx | 15 +-
.../fields/ImageInputFieldComponent.tsx | 17 +-
.../web/src/features/nodes/types/types.ts | 2 +-
.../nodes/util/addControlNetToLinearGraph.ts | 6 +-
.../graphBuilders/buildImageToImageGraph.ts | 416 ++++++++++++++++++
.../nodeBuilders/buildImageToImageNode.ts | 2 +-
.../ImageToImage/InitialImagePreview.tsx | 19 +-
.../parameters/store/generationSlice.ts | 18 +-
.../frontend/web/src/services/apiSlice.ts | 89 ++--
21 files changed, 630 insertions(+), 140 deletions(-)
create mode 100644 invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts
diff --git a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
index 8263b48114..50d80dcf28 100644
--- a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
+++ b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
@@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
- const isInitialImage = generation.initialImage?.image_name === image_name;
+ const isInitialImage = generation.initialImage === image_name;
const isCanvasImage = canvas.layerState.objects.some(
- (obj) => obj.kind === 'image' && obj.image.image_name === image_name
+ (obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
- (input) =>
- input.type === 'image' && input.value?.image_name === image_name
+ (input) => input.type === 'image' && input.value === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
- c.controlImage?.image_name === image_name ||
- c.processedControlImage?.image_name === image_name
+ c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts
index ce1b515b84..7ff9a5118c 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts
@@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
- image: pick(controlNet.controlImage, ['image_name']),
+ image: { image_name: controlNet.controlImage },
},
},
};
@@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
dispatch(
controlNetProcessedImageChanged({
controlNetId,
- processedControlImage,
+ processedControlImage: processedControlImage.image_name,
})
);
}
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
index 24e8eb312f..680f9c7041 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
@@ -10,6 +10,7 @@ import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { imageAddedToBoard } from '../../../../../../services/thunks/board';
+import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image'];
@@ -42,11 +43,9 @@ export const addInvocationCompleteEventListener = () => {
if (boardIdToAddTo) {
dispatch(
- imageAddedToBoard({
- requestBody: {
- board_id: boardIdToAddTo,
- image_name,
- },
+ api.endpoints.addImageToBoard.initiate({
+ board_id: boardIdToAddTo,
+ image_name,
})
);
}
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts
index 7cb8012848..22182833b0 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts
@@ -22,7 +22,7 @@ const selectAllUsedImages = createSelector(
selectImagesEntities,
],
(generation, canvas, nodes, controlNet, imageEntities) => {
- const allUsedImages: ImageDTO[] = [];
+ const allUsedImages: string[] = [];
if (generation.initialImage) {
allUsedImages.push(generation.initialImage);
@@ -30,30 +30,30 @@ const selectAllUsedImages = createSelector(
canvas.layerState.objects.forEach((obj) => {
if (obj.kind === 'image') {
- allUsedImages.push(obj.image);
+ allUsedImages.push(obj.image.image_name);
}
});
nodes.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image' && input.value) {
- allUsedImages.push(input.value);
+ allUsedImages.push(input.value.image_name);
}
});
});
forEach(controlNet.controlNets, (c) => {
if (c.controlImage) {
- allUsedImages.push(c.controlImage);
+ allUsedImages.push(c.controlImage.image_name);
}
if (c.processedControlImage) {
- allUsedImages.push(c.processedControlImage);
+ allUsedImages.push(c.processedControlImage.image_name);
}
});
forEach(imageEntities, (image) => {
if (image) {
- allUsedImages.push(image);
+ allUsedImages.push(image.image_name);
}
});
@@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);
- allUsedImages.forEach(({ image_name }) => {
+ allUsedImages.forEach((image_name) => {
dispatch(
imageUrlsReceived({
imageName: image_name,
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx
index b8757eff0c..c3132f0285 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx
@@ -1,14 +1,21 @@
-import { Image } from 'react-konva';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
+import { Image, Rect } from 'react-konva';
+import { useGetImageDTOQuery } from 'services/apiSlice';
import useImage from 'use-image';
+import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageProps = {
- url: string;
- x: number;
- y: number;
+ canvasImage: CanvasImage;
};
const IAICanvasImage = (props: IAICanvasImageProps) => {
- const { url, x, y } = props;
- const [image] = useImage(url, 'anonymous');
+ const { width, height, x, y, imageName } = props.canvasImage;
+ const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
+ const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
+
+ if (!imageDTO) {
+ return ;
+ }
+
return ;
};
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx
index ea04aa95c8..ec1e87cca7 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx
@@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
- return (
-
- );
+ return ;
} else if (isCanvasBaseLine(obj)) {
const line = (
{
return (
{shouldShowStagingImage && currentStagingAreaImage && (
-
+
)}
{shouldShowStagingOutline && (
diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
index b7092bf7e0..3e40c1211d 100644
--- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
+++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
@@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
y: 0,
width: width,
height: height,
- image: image,
+ imageName: image.image_name,
},
],
};
@@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
kind: 'image',
layer: 'base',
...state.layerState.stagingArea.boundingBox,
- image,
+ imageName: image.image_name,
});
state.layerState.stagingArea.selectedImageIndex =
@@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
state.doesCanvasNeedScaling = true;
});
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
+ // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
+ // const { image_name, image_url, thumbnail_url } = action.payload;
- state.layerState.objects.forEach((object) => {
- if (object.kind === 'image') {
- if (object.image.image_name === image_name) {
- object.image.image_url = image_url;
- object.image.thumbnail_url = thumbnail_url;
- }
- }
- });
+ // state.layerState.objects.forEach((object) => {
+ // if (object.kind === 'image') {
+ // if (object.image.image_name === image_name) {
+ // object.image.image_url = image_url;
+ // object.image.thumbnail_url = thumbnail_url;
+ // }
+ // }
+ // });
- state.layerState.stagingArea.images.forEach((stagedImage) => {
- if (stagedImage.image.image_name === image_name) {
- stagedImage.image.image_url = image_url;
- stagedImage.image.thumbnail_url = thumbnail_url;
- }
- });
- });
+ // state.layerState.stagingArea.images.forEach((stagedImage) => {
+ // if (stagedImage.image.image_name === image_name) {
+ // stagedImage.image.image_url = image_url;
+ // stagedImage.image.thumbnail_url = thumbnail_url;
+ // }
+ // });
+ // });
},
});
diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
index ae78287a7b..9294e10d32 100644
--- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
+++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
@@ -38,7 +38,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
- image: ImageDTO;
+ imageName: string;
};
export type CanvasMaskLine = {
diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
index b8d8896dad..a121875f59 100644
--- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
@@ -14,6 +14,8 @@ import { AnimatePresence, motion } from 'framer-motion';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
controlNetSelector,
@@ -31,24 +33,45 @@ type Props = {
const ControlNetImagePreview = (props: Props) => {
const { imageSx } = props;
- const { controlNetId, controlImage, processedControlImage, processorType } =
- props.controlNet;
+ const {
+ controlNetId,
+ controlImage: controlImageName,
+ processedControlImage: processedControlImageName,
+ processorType,
+ } = props.controlNet;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
+ const {
+ data: controlImage,
+ isLoading: isLoadingControlImage,
+ isError: isErrorControlImage,
+ isSuccess: isSuccessControlImage,
+ } = useGetImageDTOQuery(controlImageName ?? skipToken);
+
+ const {
+ data: processedControlImage,
+ isLoading: isLoadingProcessedControlImage,
+ isError: isErrorProcessedControlImage,
+ isSuccess: isSuccessProcessedControlImage,
+ } = useGetImageDTOQuery(processedControlImageName ?? skipToken);
+
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
- if (controlImage?.image_name === droppedImage.image_name) {
+ if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
- controlNetImageChanged({ controlNetId, controlImage: droppedImage })
+ controlNetImageChanged({
+ controlNetId,
+ controlImage: droppedImage.image_name,
+ })
);
},
- [controlImage, controlNetId, dispatch]
+ [controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => {
diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
index f1b62cd997..5a54bdcd74 100644
--- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
+++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
@@ -39,8 +39,8 @@ export type ControlNetConfig = {
weight: number;
beginStepPct: number;
endStepPct: number;
- controlImage: ImageDTO | null;
- processedControlImage: ImageDTO | null;
+ controlImage: string | null;
+ processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
@@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
},
controlNetAddedFromImage: (
state,
- action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
+ action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
@@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
- controlImage: ImageDTO | null;
+ controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
@@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
- processedControlImage: ImageDTO | null;
+ processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
@@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
// Preemptively remove the image from the gallery
const { imageName } = action.meta.arg;
forEach(state.controlNets, (c) => {
- if (c.controlImage?.image_name === imageName) {
+ if (c.controlImage === imageName) {
c.controlImage = null;
c.processedControlImage = null;
}
- if (c.processedControlImage?.image_name === imageName) {
+ if (c.processedControlImage === imageName) {
c.processedControlImage = null;
}
});
});
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
+ // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
+ // const { image_name, image_url, thumbnail_url } = action.payload;
- forEach(state.controlNets, (c) => {
- if (c.controlImage?.image_name === image_name) {
- c.controlImage.image_url = image_url;
- c.controlImage.thumbnail_url = thumbnail_url;
- }
- if (c.processedControlImage?.image_name === image_name) {
- c.processedControlImage.image_url = image_url;
- c.processedControlImage.thumbnail_url = thumbnail_url;
- }
- });
- });
+ // forEach(state.controlNets, (c) => {
+ // if (c.controlImage?.image_name === image_name) {
+ // c.controlImage.image_url = image_url;
+ // c.controlImage.thumbnail_url = thumbnail_url;
+ // }
+ // if (c.processedControlImage?.image_name === image_name) {
+ // c.processedControlImage.image_url = image_url;
+ // c.processedControlImage.thumbnail_url = thumbnail_url;
+ // }
+ // });
+ // });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx
index be849e625e..5854c3fe7c 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx
@@ -50,7 +50,7 @@ const BoardsList = () => {
? data?.items.filter((board) =>
board.board_name.toLowerCase().includes(searchText.toLowerCase())
)
- : data.items;
+ : data?.items;
const [searchMode, setSearchMode] = useState(false);
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
index 649cae7682..bff32f1d78 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
@@ -17,6 +17,8 @@ import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { RootState } from 'app/store/store';
import { selectImagesById } from '../store/imagesSlice';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector],
@@ -53,9 +55,16 @@ const CurrentImagePreview = () => {
shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector);
- const image = useAppSelector((state: RootState) =>
- selectImagesById(state, selectedImage ?? '')
- );
+ // const image = useAppSelector((state: RootState) =>
+ // selectImagesById(state, selectedImage ?? '')
+ // );
+
+ const {
+ data: image,
+ isLoading,
+ isError,
+ isSuccess,
+ } = useGetImageDTOQuery(selectedImage ?? skipToken);
const dispatch = useAppDispatch();
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
index dc4590e6ca..c5a3a1970b 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
@@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { Flex } from '@chakra-ui/react';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
const ImageInputFieldComponent = (
props: FieldComponentProps
@@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
const dispatch = useAppDispatch();
+ const {
+ data: image,
+ isLoading,
+ isError,
+ isSuccess,
+ } = useGetImageDTOQuery(field.value ?? skipToken);
+
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
- if (field.value?.image_name === droppedImage.image_name) {
+ if (field.value === droppedImage.image_name) {
return;
}
@@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
fieldValueChanged({
nodeId,
fieldName: field.name,
- value: droppedImage,
+ value: droppedImage.image_name,
})
);
},
- [dispatch, field.name, field.value?.image_name, nodeId]
+ [dispatch, field.name, field.value, nodeId]
);
const handleReset = useCallback(() => {
@@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
}}
>
{
+ const {
+ positivePrompt,
+ negativePrompt,
+ model,
+ cfgScale: cfg_scale,
+ scheduler,
+ steps,
+ initialImage,
+ img2imgStrength: strength,
+ shouldFitToWidthHeight,
+ width,
+ height,
+ iterations,
+ seed,
+ shouldRandomizeSeed,
+ } = state.generation;
+
+ if (!initialImage) {
+ moduleLog.error('No initial image found in state');
+ throw new Error('No initial image found in state');
+ }
+
+ const graph: NonNullableGraph = {
+ nodes: {},
+ edges: [],
+ };
+
+ // Create the positive conditioning (prompt) node
+ const positiveConditioningNode: CompelInvocation = {
+ id: POSITIVE_CONDITIONING,
+ type: 'compel',
+ prompt: positivePrompt,
+ model,
+ };
+
+ // Negative conditioning
+ const negativeConditioningNode: CompelInvocation = {
+ id: NEGATIVE_CONDITIONING,
+ type: 'compel',
+ prompt: negativePrompt,
+ model,
+ };
+
+ // This will encode the raster image to latents - but it may get its `image` from a resize node,
+ // so we do not set its `image` property yet
+ const imageToLatentsNode: ImageToLatentsInvocation = {
+ id: IMAGE_TO_LATENTS,
+ type: 'i2l',
+ model,
+ };
+
+ // This does the actual img2img inference
+ const latentsToLatentsNode: LatentsToLatentsInvocation = {
+ id: LATENTS_TO_LATENTS,
+ type: 'l2l',
+ cfg_scale,
+ model,
+ scheduler,
+ steps,
+ strength,
+ };
+
+ // Finally we decode the latents back to an image
+ const latentsToImageNode: LatentsToImageInvocation = {
+ id: LATENTS_TO_IMAGE,
+ type: 'l2i',
+ model,
+ };
+
+ // Add all those nodes to the graph
+ graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
+ graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
+ graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
+ graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
+ graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
+
+ // Connect the prompt nodes to the imageToLatents node
+ graph.edges.push({
+ source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'positive_conditioning',
+ },
+ });
+ graph.edges.push({
+ source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'negative_conditioning',
+ },
+ });
+
+ // Connect the image-encoding node
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'latents',
+ },
+ });
+
+ // Connect the image-decoding node
+ graph.edges.push({
+ source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
+ destination: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'latents',
+ },
+ });
+
+ /**
+ * Now we need to handle iterations and random seeds. There are four possible scenarios:
+ * - Single iteration, explicit seed
+ * - Single iteration, random seed
+ * - Multiple iterations, explicit seed
+ * - Multiple iterations, random seed
+ *
+ * They all have different graphs and connections.
+ */
+
+ // Single iteration, explicit seed
+ if (!shouldRandomizeSeed && iterations === 1) {
+ // Noise node using the explicit seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ seed: seed,
+ };
+
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Single iteration, random seed
+ if (shouldRandomizeSeed && iterations === 1) {
+ // Random int node to generate the seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ };
+
+ graph.nodes[RANDOM_INT] = randomIntNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect random int to the seed of the noise node
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Multiple iterations, explicit seed
+ if (!shouldRandomizeSeed && iterations > 1) {
+ // Range of size node to generate `iterations` count of seeds - range of size generates a collection
+ // of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
+ // iterations.
+ const rangeOfSizeNode: RangeOfSizeInvocation = {
+ id: RANGE_OF_SIZE,
+ type: 'range_of_size',
+ start: seed,
+ size: iterations,
+ };
+
+ // Iterate node to iterate over the seeds generated by the range of size node
+ const iterateNode: IterateInvocation = {
+ id: ITERATE,
+ type: 'iterate',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ };
+
+ // Adding to the graph
+ graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+ graph.nodes[ITERATE] = iterateNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect range of size to iterate
+ graph.edges.push({
+ source: { node_id: RANGE_OF_SIZE, field: 'collection' },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ });
+
+ // Connect iterate to noise
+ graph.edges.push({
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Multiple iterations, random seed
+ if (shouldRandomizeSeed && iterations > 1) {
+ // Random int node to generate the seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ // Range of size node to generate `iterations` count of seeds - range of size generates a collection
+ const rangeOfSizeNode: RangeOfSizeInvocation = {
+ id: RANGE_OF_SIZE,
+ type: 'range_of_size',
+ size: iterations,
+ };
+
+ // Iterate node to iterate over the seeds generated by the range of size node
+ const iterateNode: IterateInvocation = {
+ id: ITERATE,
+ type: 'iterate',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ width,
+ height,
+ };
+
+ // Adding to the graph
+ graph.nodes[RANDOM_INT] = randomIntNode;
+ graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+ graph.nodes[ITERATE] = iterateNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect random int to the start of the range of size so the range starts on the random first seed
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: { node_id: RANGE_OF_SIZE, field: 'start' },
+ });
+
+ // Connect range of size to iterate
+ graph.edges.push({
+ source: { node_id: RANGE_OF_SIZE, field: 'collection' },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ });
+
+ // Connect iterate to noise
+ graph.edges.push({
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ if (
+ shouldFitToWidthHeight &&
+ (initialImage.width !== width || initialImage.height !== height)
+ ) {
+ // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
+
+ // Create a resize node, explicitly setting its image
+ const resizeNode: ImageResizeInvocation = {
+ id: RESIZE,
+ type: 'img_resize',
+ image: {
+ image_name: initialImage,
+ },
+ is_intermediate: true,
+ height,
+ width,
+ };
+
+ graph.nodes[RESIZE] = resizeNode;
+
+ // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'image' },
+ destination: {
+ node_id: IMAGE_TO_LATENTS,
+ field: 'image',
+ },
+ });
+
+ // The `RESIZE` node also passes its width and height to `NOISE`
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'width' },
+ destination: {
+ node_id: NOISE,
+ field: 'width',
+ },
+ });
+
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'height' },
+ destination: {
+ node_id: NOISE,
+ field: 'height',
+ },
+ });
+ } else {
+ // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
+ set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
+ image_name: initialImage,
+ });
+
+ // Pass the image's dimensions to the `NOISE` node
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
+ destination: {
+ node_id: NOISE,
+ field: 'width',
+ },
+ });
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
+ destination: {
+ node_id: NOISE,
+ field: 'height',
+ },
+ });
+ }
+
+ addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
+
+ return graph;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
index e29b46af70..cc88328729 100644
--- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
@@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
}
imageToImageNode.image = {
- image_name: initialImage.image_name,
+ image_name: initialImage,
};
}
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
index fa415074e6..d1f473b833 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
@@ -11,6 +11,8 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
[generationSelector],
@@ -27,14 +29,21 @@ const InitialImagePreview = () => {
const { initialImage } = useAppSelector(selector);
const dispatch = useAppDispatch();
+ const {
+ data: image,
+ isLoading,
+ isError,
+ isSuccess,
+ } = useGetImageDTOQuery(initialImage ?? skipToken);
+
const handleDrop = useCallback(
- (droppedImage: ImageDTO) => {
- if (droppedImage.image_name === initialImage?.image_name) {
+ ({ image_name }: ImageDTO) => {
+ if (image_name === initialImage) {
return;
}
- dispatch(initialImageChanged(droppedImage));
+ dispatch(initialImageChanged(image_name));
},
- [dispatch, initialImage?.image_name]
+ [dispatch, initialImage]
);
const handleReset = useCallback(() => {
@@ -53,7 +62,7 @@ const InitialImagePreview = () => {
}}
>
}
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
index 961ea1b8af..001fc35138 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
@@ -24,7 +24,7 @@ export interface GenerationState {
height: HeightParam;
img2imgStrength: StrengthParam;
infillMethod: string;
- initialImage?: ImageDTO;
+ initialImage?: string;
iterations: number;
perlin: number;
positivePrompt: PositivePromptParam;
@@ -211,7 +211,7 @@ export const generationSlice = createSlice({
setShouldUseNoiseSettings: (state, action: PayloadAction) => {
state.shouldUseNoiseSettings = action.payload;
},
- initialImageChanged: (state, action: PayloadAction) => {
+ initialImageChanged: (state, action: PayloadAction) => {
state.initialImage = action.payload;
},
modelSelected: (state, action: PayloadAction) => {
@@ -233,14 +233,14 @@ export const generationSlice = createSlice({
}
});
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
+ // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
+ // const { image_name, image_url, thumbnail_url } = action.payload;
- if (state.initialImage?.image_name === image_name) {
- state.initialImage.image_url = image_url;
- state.initialImage.thumbnail_url = thumbnail_url;
- }
- });
+ // if (state.initialImage?.image_name === image_name) {
+ // state.initialImage.image_url = image_url;
+ // state.initialImage.thumbnail_url = thumbnail_url;
+ // }
+ // });
},
});
diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts
index 09eb061e29..9a1521ce5a 100644
--- a/invokeai/frontend/web/src/services/apiSlice.ts
+++ b/invokeai/frontend/web/src/services/apiSlice.ts
@@ -1,8 +1,18 @@
-import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';
+import {
+ TagDescription,
+ createApi,
+ fetchBaseQuery,
+} from '@reduxjs/toolkit/query/react';
import { BoardDTO } from './api/models/BoardDTO';
import { OffsetPaginatedResults_BoardDTO_ } from './api/models/OffsetPaginatedResults_BoardDTO_';
import { BoardChanges } from './api/models/BoardChanges';
import { OffsetPaginatedResults_ImageDTO_ } from './api/models/OffsetPaginatedResults_ImageDTO_';
+import { ImageDTO } from './api/models/ImageDTO';
+import {
+ FullTagDescription,
+ TagTypesFrom,
+ TagTypesFromApi,
+} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
type ListBoardsArg = { offset: number; limit: number };
type UpdateBoardArg = { board_id: string; changes: BoardChanges };
@@ -10,10 +20,15 @@ type AddImageToBoardArg = { board_id: string; image_name: string };
type RemoveImageFromBoardArg = { board_id: string; image_name: string };
type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
+const tagTypes = ['Board', 'Image'];
+type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
+
+const LIST = 'LIST';
+
export const api = createApi({
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
reducerPath: 'api',
- tagTypes: ['Board'],
+ tagTypes,
endpoints: (build) => ({
/**
* Boards Queries
@@ -21,19 +36,20 @@ export const api = createApi({
listBoards: build.query({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
- if (!result) {
- // Provide the broad 'Board' tag until there is a response
- return ['Board'];
+ // any list of boards
+ const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }];
+
+ if (result) {
+ // and individual tags for each board
+ tags.push(
+ ...result.items.map(({ board_id }) => ({
+ type: 'Board' as const,
+ id: board_id,
+ }))
+ );
}
- // Provide the broad 'Board' tab, and individual tags for each board
- return [
- ...result.items.map(({ board_id }) => ({
- type: 'Board' as const,
- id: board_id,
- })),
- 'Board',
- ];
+ return tags;
},
}),
@@ -43,19 +59,20 @@ export const api = createApi({
params: { all: true },
}),
providesTags: (result, error, arg) => {
- if (!result) {
- // Provide the broad 'Board' tag until there is a response
- return ['Board'];
+ // any list of boards
+ const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }];
+
+ if (result) {
+ // and individual tags for each board
+ tags.push(
+ ...result.map(({ board_id }) => ({
+ type: 'Board' as const,
+ id: board_id,
+ }))
+ );
}
- // Provide the broad 'Board' tab, and individual tags for each board
- return [
- ...result.map(({ board_id }) => ({
- type: 'Board' as const,
- id: board_id,
- })),
- 'Board',
- ];
+ return tags;
},
}),
@@ -113,10 +130,10 @@ export const api = createApi({
method: 'POST',
body: { board_id, image_name },
}),
- invalidatesTags: ['Board'],
- // invalidatesTags: (result, error, arg) => [
- // { type: 'Board', id: arg.board_id },
- // ],
+ invalidatesTags: (result, error, arg) => [
+ { type: 'Board', id: arg.board_id },
+ { type: 'Image', id: arg.image_name },
+ ],
}),
removeImageFromBoard: build.mutation({
@@ -127,8 +144,23 @@ export const api = createApi({
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg.board_id },
+ { type: 'Image', id: arg.image_name },
],
}),
+
+ /**
+ * Image Queries
+ */
+ getImageDTO: build.query({
+ query: (image_name) => ({ url: `images/${image_name}/metadata` }),
+ providesTags: (result, error, arg) => {
+ const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }];
+ if (result?.board_id) {
+ tags.push({ type: 'Board', id: result.board_id });
+ }
+ return tags;
+ },
+ }),
}),
});
@@ -141,4 +173,5 @@ export const {
useAddImageToBoardMutation,
useRemoveImageFromBoardMutation,
useListBoardImagesQuery,
+ useGetImageDTOQuery,
} = api;