mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): cache control layer mask images
When invoking with control layers, we were creating and uploading the mask images on every enqueue, even when the mask didn't change. The mask image can be cached to greatly reduce the number of uploads. With this change, we are a bit smarter about the mask images: - Check if there is an uploaded mask image name - If so, attempt to retrieve its DTO. Typically it will be in the RTKQ cache, so there is no network request, but it will make a network request if not cached to confirm the image actually exists on the server. - If we don't have an uploaded mask image name, or the request fails, we go ahead and upload the generated blob - Update the layer's state with a reference to this uploaded image for next time - Continue as before Any time we modify the mask (drawing/erasing, resetting the layer), we invalidate that cached image name (set it to null). We now only upload images when we need to and generation starts faster.
This commit is contained in:
parent
3cba53533d
commit
af9f0e0963
@ -86,6 +86,7 @@ const resetLayer = (layer: Layer) => {
|
||||
layer.isEnabled = true;
|
||||
layer.needsPixelBbox = false;
|
||||
layer.bboxNeedsUpdate = false;
|
||||
layer.uploadedMaskImage = null;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@ -173,6 +174,7 @@ export const controlLayersSlice = createSlice({
|
||||
if (bbox === null && layer.type === 'regional_guidance_layer') {
|
||||
// The layer was fully erased, empty its objects to prevent accumulation of invisible objects
|
||||
layer.maskObjects = [];
|
||||
layer.uploadedMaskImage = null;
|
||||
layer.needsPixelBbox = false;
|
||||
}
|
||||
}
|
||||
@ -456,6 +458,7 @@ export const controlLayersSlice = createSlice({
|
||||
negativePrompt: null,
|
||||
ipAdapters: [],
|
||||
isSelected: true,
|
||||
uploadedMaskImage: null,
|
||||
};
|
||||
state.layers.push(layer);
|
||||
state.selectedLayerId = layer.id;
|
||||
@ -505,6 +508,7 @@ export const controlLayersSlice = createSlice({
|
||||
strokeWidth: state.brushSize,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
layer.uploadedMaskImage = null;
|
||||
if (!layer.needsPixelBbox && tool === 'eraser') {
|
||||
layer.needsPixelBbox = true;
|
||||
}
|
||||
@ -524,6 +528,7 @@ export const controlLayersSlice = createSlice({
|
||||
// TODO: Handle this in the event listener
|
||||
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
|
||||
layer.bboxNeedsUpdate = true;
|
||||
layer.uploadedMaskImage = null;
|
||||
},
|
||||
rgLayerRectAdded: {
|
||||
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect; rectUuid: string }>) => {
|
||||
@ -543,9 +548,15 @@ export const controlLayersSlice = createSlice({
|
||||
height: rect.height,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
layer.uploadedMaskImage = null;
|
||||
},
|
||||
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
|
||||
},
|
||||
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
|
||||
const { layerId, imageDTO } = action.payload;
|
||||
const layer = selectRGLayerOrThrow(state, layerId);
|
||||
layer.uploadedMaskImage = imageDTOToImageWithDims(imageDTO);
|
||||
},
|
||||
rgLayerAutoNegativeChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
|
||||
@ -825,6 +836,7 @@ export const {
|
||||
rgLayerLineAdded,
|
||||
rgLayerPointsAdded,
|
||||
rgLayerRectAdded,
|
||||
rgLayerMaskImageUploaded,
|
||||
rgLayerAutoNegativeChanged,
|
||||
rgLayerIPAdapterAdded,
|
||||
rgLayerIPAdapterDeleted,
|
||||
|
@ -72,6 +72,7 @@ export type RegionalGuidanceLayer = RenderableLayerBase & {
|
||||
previewColor: RgbColor;
|
||||
autoNegative: ParameterAutoNegative;
|
||||
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
|
||||
uploadedMaskImage: ImageWithDims | null;
|
||||
};
|
||||
|
||||
export type InitialImageLayer = RenderableLayerBase & {
|
||||
|
@ -4,7 +4,9 @@ import {
|
||||
isControlAdapterLayer,
|
||||
isIPAdapterLayer,
|
||||
isRegionalGuidanceLayer,
|
||||
rgLayerMaskImageUploaded,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
type ControlNetConfigV2,
|
||||
type ImageWithDims,
|
||||
@ -32,12 +34,13 @@ import {
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
||||
import { size } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
ControlNetInvocation,
|
||||
CoreMetadataInvocation,
|
||||
Edge,
|
||||
ImageDTO,
|
||||
IPAdapterInvocation,
|
||||
NonNullableGraph,
|
||||
S,
|
||||
@ -337,7 +340,6 @@ const addGlobalIPAdaptersToGraph = async (
|
||||
};
|
||||
|
||||
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
const { dispatch } = getStore();
|
||||
const mainModel = state.generation.model;
|
||||
assert(mainModel, 'Missing main model when building graph');
|
||||
const isSDXL = mainModel.base === 'sdxl';
|
||||
@ -404,10 +406,6 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
});
|
||||
|
||||
const layerIds = rgLayers.map((l) => l.id);
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||
// the existing conditioning nodes.
|
||||
|
||||
@ -470,22 +468,15 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
},
|
||||
});
|
||||
|
||||
// Upload the blobs to the backend, add each to graph
|
||||
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
|
||||
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
|
||||
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
|
||||
const layerIds = rgLayers.map((l) => l.id);
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
for (const layer of rgLayers) {
|
||||
const blob = blobs[layer.id];
|
||||
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||
|
||||
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
// TODO: This will raise on network error
|
||||
const { image_name } = await req.unwrap();
|
||||
// Upload the mask image, or get the cached image if it exists
|
||||
const { image_name } = await getMaskImage(layer, blob);
|
||||
|
||||
// The main mask-to-tensor node
|
||||
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
@ -679,3 +670,23 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||
if (layer.uploadedMaskImage) {
|
||||
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user