feat(ui): store only image name in parameters

Images that are used as parameters (e.g. init image, canvas images) are stored as full `ImageDTO` objects in state, separate from and duplicating any object representing those same objects in the `imagesSlice`.

We cannot store only image names as parameters, then pull the full `ImageDTO` from `imagesSlice`, because if an image is not on a loaded page, it doesn't exist in `imagesSlice`. For example, if you scroll down a few pages in the gallery and send that image to canvas, on reloading the app, the canvas will be unable to load that image.

We solved this temporarily by storing the full `ImageDTO` object wherever it was needed, but this is both inefficient and allows for stale `ImageDTO`s across the app.

One other possible solution was to just fetch the `ImageDTO` for all images at startup, and insert them into the `imagesSlice`, but then we run into an issue where we are displaying images in the gallery totally out of context.

For example, if an image from several pages into the gallery was sent to canvas, and the user refreshes, we'd display the first 20 images in gallery. Then to populate the canvas, we'd fetch that image we sent to canvas and add it to `imagesSlice`. Now we'd have 21 images in the gallery: 1 to 20 and whichever image we sent to canvas. Weird.

Using `rtk-query` solves this by allowing us to very easily fetch individual images in the components that need them, and not directly interact with `imagesSlice`.

This commit changes all references to images-as-parameters to store only the name of the image, and not the full `ImageDTO` object. Then, we use an `rtk-query` generated `useGetImageDTOQuery()` hook in each of those components to fetch the image.

We can use cache invalidation when we mutate any image to trigger automated re-running of the query and all the images are automatically kept up to date.

This also obviates the need for the convoluted URL fetching scheme for images that are used as parameters. The `imagesSlice` still need this handling unfortunately.
This commit is contained in:
psychedelicious 2023-06-21 13:48:59 +10:00
parent cfda128e06
commit 8d3bec57d5
21 changed files with 630 additions and 140 deletions

View File

@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
const isInitialImage = generation.initialImage?.image_name === image_name;
const isInitialImage = generation.initialImage === image_name;
const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
(input) => input.type === 'image' && input.value === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage?.image_name === image_name ||
c.processedControlImage?.image_name === image_name
c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {

View File

@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, ['image_name']),
image: { image_name: controlNet.controlImage },
},
},
};
@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
dispatch(
controlNetProcessedImageChanged({
controlNetId,
processedControlImage,
processedControlImage: processedControlImage.image_name,
})
);
}

View File

@ -10,6 +10,7 @@ import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { imageAddedToBoard } from '../../../../../../services/thunks/board';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image'];
@ -42,11 +43,9 @@ export const addInvocationCompleteEventListener = () => {
if (boardIdToAddTo) {
dispatch(
imageAddedToBoard({
requestBody: {
board_id: boardIdToAddTo,
image_name,
},
api.endpoints.addImageToBoard.initiate({
board_id: boardIdToAddTo,
image_name,
})
);
}

View File

@ -22,7 +22,7 @@ const selectAllUsedImages = createSelector(
selectImagesEntities,
],
(generation, canvas, nodes, controlNet, imageEntities) => {
const allUsedImages: ImageDTO[] = [];
const allUsedImages: string[] = [];
if (generation.initialImage) {
allUsedImages.push(generation.initialImage);
@ -30,30 +30,30 @@ const selectAllUsedImages = createSelector(
canvas.layerState.objects.forEach((obj) => {
if (obj.kind === 'image') {
allUsedImages.push(obj.image);
allUsedImages.push(obj.image.image_name);
}
});
nodes.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image' && input.value) {
allUsedImages.push(input.value);
allUsedImages.push(input.value.image_name);
}
});
});
forEach(controlNet.controlNets, (c) => {
if (c.controlImage) {
allUsedImages.push(c.controlImage);
allUsedImages.push(c.controlImage.image_name);
}
if (c.processedControlImage) {
allUsedImages.push(c.processedControlImage);
allUsedImages.push(c.processedControlImage.image_name);
}
});
forEach(imageEntities, (image) => {
if (image) {
allUsedImages.push(image);
allUsedImages.push(image.image_name);
}
});
@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);
allUsedImages.forEach(({ image_name }) => {
allUsedImages.forEach((image_name) => {
dispatch(
imageUrlsReceived({
imageName: image_name,

View File

@ -1,14 +1,21 @@
import { Image } from 'react-konva';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Image, Rect } from 'react-konva';
import { useGetImageDTOQuery } from 'services/apiSlice';
import useImage from 'use-image';
import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageProps = {
url: string;
x: number;
y: number;
canvasImage: CanvasImage;
};
const IAICanvasImage = (props: IAICanvasImageProps) => {
const { url, x, y } = props;
const [image] = useImage(url, 'anonymous');
const { width, height, x, y, imageName } = props.canvasImage;
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
}
return <Image x={x} y={y} image={image} listening={false} />;
};

View File

@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
<Group name="outpainting-objects" listening={false}>
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={obj.image.image_url}
/>
);
return <IAICanvasImage key={i} canvasImage={obj} />;
} else if (isCanvasBaseLine(obj)) {
const line = (
<Line

View File

@ -59,11 +59,7 @@ const IAICanvasStagingArea = (props: Props) => {
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage
url={currentStagingAreaImage.image.image_url}
x={x}
y={y}
/>
<IAICanvasImage canvasImage={currentStagingAreaImage} />
)}
{shouldShowStagingOutline && (
<Group>

View File

@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
y: 0,
width: width,
height: height,
image: image,
imageName: image.image_name,
},
],
};
@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
kind: 'image',
layer: 'base',
...state.layerState.stagingArea.boundingBox,
image,
imageName: image.image_name,
});
state.layerState.stagingArea.selectedImageIndex =
@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
state.doesCanvasNeedScaling = true;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
state.layerState.objects.forEach((object) => {
if (object.kind === 'image') {
if (object.image.image_name === image_name) {
object.image.image_url = image_url;
object.image.thumbnail_url = thumbnail_url;
}
}
});
// state.layerState.objects.forEach((object) => {
// if (object.kind === 'image') {
// if (object.image.image_name === image_name) {
// object.image.image_url = image_url;
// object.image.thumbnail_url = thumbnail_url;
// }
// }
// });
state.layerState.stagingArea.images.forEach((stagedImage) => {
if (stagedImage.image.image_name === image_name) {
stagedImage.image.image_url = image_url;
stagedImage.image.thumbnail_url = thumbnail_url;
}
});
});
// state.layerState.stagingArea.images.forEach((stagedImage) => {
// if (stagedImage.image.image_name === image_name) {
// stagedImage.image.image_url = image_url;
// stagedImage.image.thumbnail_url = thumbnail_url;
// }
// });
// });
},
});

View File

@ -38,7 +38,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: ImageDTO;
imageName: string;
};
export type CanvasMaskLine = {

View File

@ -14,6 +14,8 @@ import { AnimatePresence, motion } from 'framer-motion';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
controlNetSelector,
@ -31,24 +33,45 @@ type Props = {
const ControlNetImagePreview = (props: Props) => {
const { imageSx } = props;
const { controlNetId, controlImage, processedControlImage, processorType } =
props.controlNet;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const {
data: controlImage,
isLoading: isLoadingControlImage,
isError: isErrorControlImage,
isSuccess: isSuccessControlImage,
} = useGetImageDTOQuery(controlImageName ?? skipToken);
const {
data: processedControlImage,
isLoading: isLoadingProcessedControlImage,
isError: isErrorProcessedControlImage,
isSuccess: isSuccessProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (controlImage?.image_name === droppedImage.image_name) {
if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
controlNetImageChanged({
controlNetId,
controlImage: droppedImage.image_name,
})
);
},
[controlImage, controlNetId, dispatch]
[controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => {

View File

@ -39,8 +39,8 @@ export type ControlNetConfig = {
weight: number;
beginStepPct: number;
endStepPct: number;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
controlImage: ImageDTO | null;
controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: ImageDTO | null;
processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
// Preemptively remove the image from the gallery
const { imageName } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === imageName) {
if (c.controlImage === imageName) {
c.controlImage = null;
c.processedControlImage = null;
}
if (c.processedControlImage?.image_name === imageName) {
if (c.processedControlImage === imageName) {
c.processedControlImage = null;
}
});
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === image_name) {
c.controlImage.image_url = image_url;
c.controlImage.thumbnail_url = thumbnail_url;
}
if (c.processedControlImage?.image_name === image_name) {
c.processedControlImage.image_url = image_url;
c.processedControlImage.thumbnail_url = thumbnail_url;
}
});
});
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];

View File

@ -50,7 +50,7 @@ const BoardsList = () => {
? data?.items.filter((board) =>
board.board_name.toLowerCase().includes(searchText.toLowerCase())
)
: data.items;
: data?.items;
const [searchMode, setSearchMode] = useState(false);

View File

@ -17,6 +17,8 @@ import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { RootState } from 'app/store/store';
import { selectImagesById } from '../store/imagesSlice';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector],
@ -53,9 +55,16 @@ const CurrentImagePreview = () => {
shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector);
const image = useAppSelector((state: RootState) =>
selectImagesById(state, selectedImage ?? '')
);
// const image = useAppSelector((state: RootState) =>
// selectImagesById(state, selectedImage ?? '')
// );
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(selectedImage ?? skipToken);
const dispatch = useAppDispatch();

View File

@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { Flex } from '@chakra-ui/react';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
const dispatch = useAppDispatch();
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(field.value ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (field.value?.image_name === droppedImage.image_name) {
if (field.value === droppedImage.image_name) {
return;
}
@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
fieldValueChanged({
nodeId,
fieldName: field.name,
value: droppedImage,
value: droppedImage.image_name,
})
);
},
[dispatch, field.name, field.value?.image_name, nodeId]
[dispatch, field.name, field.value, nodeId]
);
const handleReset = useCallback(() => {
@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
}}
>
<IAIDndImage
image={field.value}
image={image}
onDrop={handleDrop}
onReset={handleReset}
resetIconSize="sm"

View File

@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
value?: ImageDTO;
value?: string;
};
export type ModelInputFieldValue = FieldValueBase & {

View File

@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = (
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
const { image_name } = processedControlImage;
controlNetNode.image = {
image_name,
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
const { image_name } = controlImage;
controlNetNode.image = {
image_name,
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

View File

@ -0,0 +1,416 @@
import { RootState } from 'app/store/store';
import {
CompelInvocation,
Graph,
ImageResizeInvocation,
ImageToLatentsInvocation,
IterateInvocation,
LatentsToImageInvocation,
LatentsToLatentsInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
const moduleLog = log.child({ namespace: 'nodes' });
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
const IMAGE_TO_LATENTS = 'image_to_latents';
const LATENTS_TO_LATENTS = 'latents_to_latents';
const LATENTS_TO_IMAGE = 'latents_to_image';
const RESIZE = 'resize_image';
const NOISE = 'noise';
const RANDOM_INT = 'rand_int';
const RANGE_OF_SIZE = 'range_of_size';
const ITERATE = 'iterate';
/**
* Builds the Image to Image tab graph.
*/
export const buildImageToImageGraph = (state: RootState): Graph => {
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
initialImage,
img2imgStrength: strength,
shouldFitToWidthHeight,
width,
height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation;
if (!initialImage) {
moduleLog.error('No initial image found in state');
throw new Error('No initial image found in state');
}
const graph: NonNullableGraph = {
nodes: {},
edges: [],
};
// Create the positive conditioning (prompt) node
const positiveConditioningNode: CompelInvocation = {
id: POSITIVE_CONDITIONING,
type: 'compel',
prompt: positivePrompt,
model,
};
// Negative conditioning
const negativeConditioningNode: CompelInvocation = {
id: NEGATIVE_CONDITIONING,
type: 'compel',
prompt: negativePrompt,
model,
};
// This will encode the raster image to latents - but it may get its `image` from a resize node,
// so we do not set its `image` property yet
const imageToLatentsNode: ImageToLatentsInvocation = {
id: IMAGE_TO_LATENTS,
type: 'i2l',
model,
};
// This does the actual img2img inference
const latentsToLatentsNode: LatentsToLatentsInvocation = {
id: LATENTS_TO_LATENTS,
type: 'l2l',
cfg_scale,
model,
scheduler,
steps,
strength,
};
// Finally we decode the latents back to an image
const latentsToImageNode: LatentsToImageInvocation = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
model,
};
// Add all those nodes to the graph
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
// Connect the prompt nodes to the imageToLatents node
graph.edges.push({
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
});
// Connect the image-encoding node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'latents',
},
});
// Connect the image-decoding node
graph.edges.push({
source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
});
/**
* Now we need to handle iterations and random seeds. There are four possible scenarios:
* - Single iteration, explicit seed
* - Single iteration, random seed
* - Multiple iterations, explicit seed
* - Multiple iterations, random seed
*
* They all have different graphs and connections.
*/
// Single iteration, explicit seed
if (!shouldRandomizeSeed && iterations === 1) {
// Noise node using the explicit seed
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
seed: seed,
};
graph.nodes[NOISE] = noiseNode;
// Connect noise to l2l
graph.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'noise',
},
});
}
// Single iteration, random seed
if (shouldRandomizeSeed && iterations === 1) {
// Random int node to generate the seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
// Noise node without any seed
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
};
graph.nodes[RANDOM_INT] = randomIntNode;
graph.nodes[NOISE] = noiseNode;
// Connect random int to the seed of the noise node
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: {
node_id: NOISE,
field: 'seed',
},
});
// Connect noise to l2l
graph.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'noise',
},
});
}
// Multiple iterations, explicit seed
if (!shouldRandomizeSeed && iterations > 1) {
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
// of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
// iterations.
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
start: seed,
size: iterations,
};
// Iterate node to iterate over the seeds generated by the range of size node
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
// Noise node without any seed
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
};
// Adding to the graph
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graph.nodes[ITERATE] = iterateNode;
graph.nodes[NOISE] = noiseNode;
// Connect range of size to iterate
graph.edges.push({
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
destination: {
node_id: ITERATE,
field: 'collection',
},
});
// Connect iterate to noise
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
// Connect noise to l2l
graph.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'noise',
},
});
}
// Multiple iterations, random seed
if (shouldRandomizeSeed && iterations > 1) {
// Random int node to generate the seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
size: iterations,
};
// Iterate node to iterate over the seeds generated by the range of size node
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
// Noise node without any seed
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
width,
height,
};
// Adding to the graph
graph.nodes[RANDOM_INT] = randomIntNode;
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graph.nodes[ITERATE] = iterateNode;
graph.nodes[NOISE] = noiseNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
// Connect range of size to iterate
graph.edges.push({
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
destination: {
node_id: ITERATE,
field: 'collection',
},
});
// Connect iterate to noise
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
// Connect noise to l2l
graph.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'noise',
},
});
}
if (
shouldFitToWidthHeight &&
(initialImage.width !== width || initialImage.height !== height)
) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: initialImage,
},
is_intermediate: true,
height,
width,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage,
});
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph;
};

View File

@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
}
imageToImageNode.image = {
image_name: initialImage.image_name,
image_name: initialImage,
};
}

View File

@ -11,6 +11,8 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
[generationSelector],
@ -27,14 +29,21 @@ const InitialImagePreview = () => {
const { initialImage } = useAppSelector(selector);
const dispatch = useAppDispatch();
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(initialImage ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (droppedImage.image_name === initialImage?.image_name) {
({ image_name }: ImageDTO) => {
if (image_name === initialImage) {
return;
}
dispatch(initialImageChanged(droppedImage));
dispatch(initialImageChanged(image_name));
},
[dispatch, initialImage?.image_name]
[dispatch, initialImage]
);
const handleReset = useCallback(() => {
@ -53,7 +62,7 @@ const InitialImagePreview = () => {
}}
>
<IAIDndImage
image={initialImage}
image={image}
onDrop={handleDrop}
onReset={handleReset}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}

View File

@ -24,7 +24,7 @@ export interface GenerationState {
height: HeightParam;
img2imgStrength: StrengthParam;
infillMethod: string;
initialImage?: ImageDTO;
initialImage?: string;
iterations: number;
perlin: number;
positivePrompt: PositivePromptParam;
@ -211,7 +211,7 @@ export const generationSlice = createSlice({
setShouldUseNoiseSettings: (state, action: PayloadAction<boolean>) => {
state.shouldUseNoiseSettings = action.payload;
},
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
initialImageChanged: (state, action: PayloadAction<string>) => {
state.initialImage = action.payload;
},
modelSelected: (state, action: PayloadAction<string>) => {
@ -233,14 +233,14 @@ export const generationSlice = createSlice({
}
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url;
state.initialImage.thumbnail_url = thumbnail_url;
}
});
// if (state.initialImage?.image_name === image_name) {
// state.initialImage.image_url = image_url;
// state.initialImage.thumbnail_url = thumbnail_url;
// }
// });
},
});

View File

@ -1,8 +1,18 @@
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';
import {
TagDescription,
createApi,
fetchBaseQuery,
} from '@reduxjs/toolkit/query/react';
import { BoardDTO } from './api/models/BoardDTO';
import { OffsetPaginatedResults_BoardDTO_ } from './api/models/OffsetPaginatedResults_BoardDTO_';
import { BoardChanges } from './api/models/BoardChanges';
import { OffsetPaginatedResults_ImageDTO_ } from './api/models/OffsetPaginatedResults_ImageDTO_';
import { ImageDTO } from './api/models/ImageDTO';
import {
FullTagDescription,
TagTypesFrom,
TagTypesFromApi,
} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
type ListBoardsArg = { offset: number; limit: number };
type UpdateBoardArg = { board_id: string; changes: BoardChanges };
@ -10,10 +20,15 @@ type AddImageToBoardArg = { board_id: string; image_name: string };
type RemoveImageFromBoardArg = { board_id: string; image_name: string };
type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
const tagTypes = ['Board', 'Image'];
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
const LIST = 'LIST';
export const api = createApi({
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
reducerPath: 'api',
tagTypes: ['Board'],
tagTypes,
endpoints: (build) => ({
/**
* Boards Queries
@ -21,19 +36,20 @@ export const api = createApi({
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
if (!result) {
// Provide the broad 'Board' tag until there is a response
return ['Board'];
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }];
if (result) {
// and individual tags for each board
tags.push(
...result.items.map(({ board_id }) => ({
type: 'Board' as const,
id: board_id,
}))
);
}
// Provide the broad 'Board' tab, and individual tags for each board
return [
...result.items.map(({ board_id }) => ({
type: 'Board' as const,
id: board_id,
})),
'Board',
];
return tags;
},
}),
@ -43,19 +59,20 @@ export const api = createApi({
params: { all: true },
}),
providesTags: (result, error, arg) => {
if (!result) {
// Provide the broad 'Board' tag until there is a response
return ['Board'];
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }];
if (result) {
// and individual tags for each board
tags.push(
...result.map(({ board_id }) => ({
type: 'Board' as const,
id: board_id,
}))
);
}
// Provide the broad 'Board' tab, and individual tags for each board
return [
...result.map(({ board_id }) => ({
type: 'Board' as const,
id: board_id,
})),
'Board',
];
return tags;
},
}),
@ -113,10 +130,10 @@ export const api = createApi({
method: 'POST',
body: { board_id, image_name },
}),
invalidatesTags: ['Board'],
// invalidatesTags: (result, error, arg) => [
// { type: 'Board', id: arg.board_id },
// ],
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg.board_id },
{ type: 'Image', id: arg.image_name },
],
}),
removeImageFromBoard: build.mutation<void, RemoveImageFromBoardArg>({
@ -127,8 +144,23 @@ export const api = createApi({
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg.board_id },
{ type: 'Image', id: arg.image_name },
],
}),
/**
* Image Queries
*/
getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: `images/${image_name}/metadata` }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }];
if (result?.board_id) {
tags.push({ type: 'Board', id: result.board_id });
}
return tags;
},
}),
}),
});
@ -141,4 +173,5 @@ export const {
useAddImageToBoardMutation,
useRemoveImageFromBoardMutation,
useListBoardImagesQuery,
useGetImageDTOQuery,
} = api;