mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): cache control layer mask images
When invoking with control layers, we were creating and uploading the mask images on every enqueue, even when the mask didn't change. The mask image can be cached to greatly reduce the number of uploads. With this change, we are a bit smarter about the mask images: - Check if there is an uploaded mask image name - If so, attempt to retrieve its DTO. Typically it will be in the RTKQ cache, so there is no network request, but it will make a network request if not cached to confirm the image actually exists on the server. - If we don't have an uploaded mask image name, or the request fails, we go ahead and upload the generated blob - Update the layer's state with a reference to this uploaded image for next time - Continue as before Any time we modify the mask (drawing/erasing, resetting the layer), we invalidate that cached image name (set it to null). We now only upload images when we need to and generation starts faster.
This commit is contained in:
parent
3cba53533d
commit
af9f0e0963
@ -86,6 +86,7 @@ const resetLayer = (layer: Layer) => {
|
|||||||
layer.isEnabled = true;
|
layer.isEnabled = true;
|
||||||
layer.needsPixelBbox = false;
|
layer.needsPixelBbox = false;
|
||||||
layer.bboxNeedsUpdate = false;
|
layer.bboxNeedsUpdate = false;
|
||||||
|
layer.uploadedMaskImage = null;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -173,6 +174,7 @@ export const controlLayersSlice = createSlice({
|
|||||||
if (bbox === null && layer.type === 'regional_guidance_layer') {
|
if (bbox === null && layer.type === 'regional_guidance_layer') {
|
||||||
// The layer was fully erased, empty its objects to prevent accumulation of invisible objects
|
// The layer was fully erased, empty its objects to prevent accumulation of invisible objects
|
||||||
layer.maskObjects = [];
|
layer.maskObjects = [];
|
||||||
|
layer.uploadedMaskImage = null;
|
||||||
layer.needsPixelBbox = false;
|
layer.needsPixelBbox = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -456,6 +458,7 @@ export const controlLayersSlice = createSlice({
|
|||||||
negativePrompt: null,
|
negativePrompt: null,
|
||||||
ipAdapters: [],
|
ipAdapters: [],
|
||||||
isSelected: true,
|
isSelected: true,
|
||||||
|
uploadedMaskImage: null,
|
||||||
};
|
};
|
||||||
state.layers.push(layer);
|
state.layers.push(layer);
|
||||||
state.selectedLayerId = layer.id;
|
state.selectedLayerId = layer.id;
|
||||||
@ -505,6 +508,7 @@ export const controlLayersSlice = createSlice({
|
|||||||
strokeWidth: state.brushSize,
|
strokeWidth: state.brushSize,
|
||||||
});
|
});
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
layer.uploadedMaskImage = null;
|
||||||
if (!layer.needsPixelBbox && tool === 'eraser') {
|
if (!layer.needsPixelBbox && tool === 'eraser') {
|
||||||
layer.needsPixelBbox = true;
|
layer.needsPixelBbox = true;
|
||||||
}
|
}
|
||||||
@ -524,6 +528,7 @@ export const controlLayersSlice = createSlice({
|
|||||||
// TODO: Handle this in the event listener
|
// TODO: Handle this in the event listener
|
||||||
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
|
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
layer.uploadedMaskImage = null;
|
||||||
},
|
},
|
||||||
rgLayerRectAdded: {
|
rgLayerRectAdded: {
|
||||||
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect; rectUuid: string }>) => {
|
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect; rectUuid: string }>) => {
|
||||||
@ -543,9 +548,15 @@ export const controlLayersSlice = createSlice({
|
|||||||
height: rect.height,
|
height: rect.height,
|
||||||
});
|
});
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
layer.uploadedMaskImage = null;
|
||||||
},
|
},
|
||||||
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
|
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
|
||||||
},
|
},
|
||||||
|
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
|
||||||
|
const { layerId, imageDTO } = action.payload;
|
||||||
|
const layer = selectRGLayerOrThrow(state, layerId);
|
||||||
|
layer.uploadedMaskImage = imageDTOToImageWithDims(imageDTO);
|
||||||
|
},
|
||||||
rgLayerAutoNegativeChanged: (
|
rgLayerAutoNegativeChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
|
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
|
||||||
@ -825,6 +836,7 @@ export const {
|
|||||||
rgLayerLineAdded,
|
rgLayerLineAdded,
|
||||||
rgLayerPointsAdded,
|
rgLayerPointsAdded,
|
||||||
rgLayerRectAdded,
|
rgLayerRectAdded,
|
||||||
|
rgLayerMaskImageUploaded,
|
||||||
rgLayerAutoNegativeChanged,
|
rgLayerAutoNegativeChanged,
|
||||||
rgLayerIPAdapterAdded,
|
rgLayerIPAdapterAdded,
|
||||||
rgLayerIPAdapterDeleted,
|
rgLayerIPAdapterDeleted,
|
||||||
|
@ -72,6 +72,7 @@ export type RegionalGuidanceLayer = RenderableLayerBase & {
|
|||||||
previewColor: RgbColor;
|
previewColor: RgbColor;
|
||||||
autoNegative: ParameterAutoNegative;
|
autoNegative: ParameterAutoNegative;
|
||||||
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
|
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
|
||||||
|
uploadedMaskImage: ImageWithDims | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type InitialImageLayer = RenderableLayerBase & {
|
export type InitialImageLayer = RenderableLayerBase & {
|
||||||
|
@ -4,7 +4,9 @@ import {
|
|||||||
isControlAdapterLayer,
|
isControlAdapterLayer,
|
||||||
isIPAdapterLayer,
|
isIPAdapterLayer,
|
||||||
isRegionalGuidanceLayer,
|
isRegionalGuidanceLayer,
|
||||||
|
rgLayerMaskImageUploaded,
|
||||||
} from 'features/controlLayers/store/controlLayersSlice';
|
} from 'features/controlLayers/store/controlLayersSlice';
|
||||||
|
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||||
import {
|
import {
|
||||||
type ControlNetConfigV2,
|
type ControlNetConfigV2,
|
||||||
type ImageWithDims,
|
type ImageWithDims,
|
||||||
@ -32,12 +34,13 @@ import {
|
|||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||||
import type {
|
import type {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
ControlNetInvocation,
|
ControlNetInvocation,
|
||||||
CoreMetadataInvocation,
|
CoreMetadataInvocation,
|
||||||
Edge,
|
Edge,
|
||||||
|
ImageDTO,
|
||||||
IPAdapterInvocation,
|
IPAdapterInvocation,
|
||||||
NonNullableGraph,
|
NonNullableGraph,
|
||||||
S,
|
S,
|
||||||
@ -337,7 +340,6 @@ const addGlobalIPAdaptersToGraph = async (
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||||
const { dispatch } = getStore();
|
|
||||||
const mainModel = state.generation.model;
|
const mainModel = state.generation.model;
|
||||||
assert(mainModel, 'Missing main model when building graph');
|
assert(mainModel, 'Missing main model when building graph');
|
||||||
const isSDXL = mainModel.base === 'sdxl';
|
const isSDXL = mainModel.base === 'sdxl';
|
||||||
@ -404,10 +406,6 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
return hasTextPrompt || hasIPAdapter;
|
return hasTextPrompt || hasIPAdapter;
|
||||||
});
|
});
|
||||||
|
|
||||||
const layerIds = rgLayers.map((l) => l.id);
|
|
||||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
|
||||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
|
||||||
|
|
||||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||||
// the existing conditioning nodes.
|
// the existing conditioning nodes.
|
||||||
|
|
||||||
@ -470,22 +468,15 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Upload the blobs to the backend, add each to graph
|
const layerIds = rgLayers.map((l) => l.id);
|
||||||
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
|
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||||
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
|
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||||
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
|
|
||||||
for (const layer of rgLayers) {
|
for (const layer of rgLayers) {
|
||||||
const blob = blobs[layer.id];
|
const blob = blobs[layer.id];
|
||||||
assert(blob, `Blob for layer ${layer.id} not found`);
|
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||||
|
// Upload the mask image, or get the cached image if it exists
|
||||||
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
const { image_name } = await getMaskImage(layer, blob);
|
||||||
const req = dispatch(
|
|
||||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
|
||||||
);
|
|
||||||
req.reset();
|
|
||||||
|
|
||||||
// TODO: This will raise on network error
|
|
||||||
const { image_name } = await req.unwrap();
|
|
||||||
|
|
||||||
// The main mask-to-tensor node
|
// The main mask-to-tensor node
|
||||||
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||||
@ -679,3 +670,23 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||||
|
if (layer.uploadedMaskImage) {
|
||||||
|
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const { dispatch } = getStore();
|
||||||
|
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||||
|
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||||
|
const req = dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||||
|
);
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
const imageDTO = await req.unwrap();
|
||||||
|
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
|
||||||
|
return imageDTO;
|
||||||
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user