feat(ui): cache control layer mask images

When invoking with control layers, we were creating and uploading the mask images on every enqueue, even when the mask didn't change. The mask image can be cached to greatly reduce the number of uploads.

With this change, we are a bit smarter about the mask images:
- Check if there is an uploaded mask image name
- If so, attempt to retrieve its DTO. Typically it will be in the RTKQ cache, so there is no network request, but it will make a network request if not cached to confirm the image actually exists on the server.
- If we don't have an uploaded mask image name, or the request fails, we go ahead and upload the generated blob
- Update the layer's state with a reference to this uploaded image for next time
- Continue as before

Any time we modify the mask (drawing/erasing, resetting the layer), we invalidate that cached image name (set it to null).

We now only upload images when we need to and generation starts faster.
This commit is contained in:
psychedelicious 2024-05-03 18:43:45 +10:00
parent 3cba53533d
commit af9f0e0963
3 changed files with 43 additions and 19 deletions

View File

@ -86,6 +86,7 @@ const resetLayer = (layer: Layer) => {
layer.isEnabled = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
layer.uploadedMaskImage = null;
return;
}
};
@ -173,6 +174,7 @@ export const controlLayersSlice = createSlice({
if (bbox === null && layer.type === 'regional_guidance_layer') {
// The layer was fully erased, empty its objects to prevent accumulation of invisible objects
layer.maskObjects = [];
layer.uploadedMaskImage = null;
layer.needsPixelBbox = false;
}
}
@ -456,6 +458,7 @@ export const controlLayersSlice = createSlice({
negativePrompt: null,
ipAdapters: [],
isSelected: true,
uploadedMaskImage: null,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
@ -505,6 +508,7 @@ export const controlLayersSlice = createSlice({
strokeWidth: state.brushSize,
});
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
if (!layer.needsPixelBbox && tool === 'eraser') {
layer.needsPixelBbox = true;
}
@ -524,6 +528,7 @@ export const controlLayersSlice = createSlice({
// TODO: Handle this in the event listener
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
rgLayerRectAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect; rectUuid: string }>) => {
@ -543,9 +548,15 @@ export const controlLayersSlice = createSlice({
height: rect.height,
});
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
},
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
const { layerId, imageDTO } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId);
layer.uploadedMaskImage = imageDTOToImageWithDims(imageDTO);
},
rgLayerAutoNegativeChanged: (
state,
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
@ -825,6 +836,7 @@ export const {
rgLayerLineAdded,
rgLayerPointsAdded,
rgLayerRectAdded,
rgLayerMaskImageUploaded,
rgLayerAutoNegativeChanged,
rgLayerIPAdapterAdded,
rgLayerIPAdapterDeleted,

View File

@ -72,6 +72,7 @@ export type RegionalGuidanceLayer = RenderableLayerBase & {
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
uploadedMaskImage: ImageWithDims | null;
};
export type InitialImageLayer = RenderableLayerBase & {

View File

@ -4,7 +4,9 @@ import {
isControlAdapterLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
rgLayerMaskImageUploaded,
} from 'features/controlLayers/store/controlLayersSlice';
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
import {
type ControlNetConfigV2,
type ImageWithDims,
@ -32,12 +34,13 @@ import {
} from 'features/nodes/util/graph/constants';
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
import { size } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
Edge,
ImageDTO,
IPAdapterInvocation,
NonNullableGraph,
S,
@ -337,7 +340,6 @@ const addGlobalIPAdaptersToGraph = async (
};
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
const { dispatch } = getStore();
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
@ -404,10 +406,6 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
return hasTextPrompt || hasIPAdapter;
});
const layerIds = rgLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.
@ -470,22 +468,15 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
},
});
// Upload the blobs to the backend, add each to graph
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
const layerIds = rgLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
for (const layer of rgLayers) {
const blob = blobs[layer.id];
assert(blob, `Blob for layer ${layer.id} not found`);
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
// TODO: This will raise on network error
const { image_name } = await req.unwrap();
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(layer, blob);
// The main mask-to-tensor node
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
@ -679,3 +670,23 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
}
}
};
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};