mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): get region and base layer canvas to blob logic working
This commit is contained in:
parent
3b864921ac
commit
0c9cf73702
@ -216,7 +216,7 @@ export const updateBboxes = (
|
|||||||
onBboxChanged({ id: entityState.id, bbox: getLayerBboxPixels(konvaLayer, filterLayerChildren) }, 'layer');
|
onBboxChanged({ id: entityState.id, bbox: getLayerBboxPixels(konvaLayer, filterLayerChildren) }, 'layer');
|
||||||
}
|
}
|
||||||
} else if (entityState.type === 'control_adapter') {
|
} else if (entityState.type === 'control_adapter') {
|
||||||
if (!entityState.image && !entityState.processedImage) {
|
if (!entityState.imageObject && !entityState.processedImageObject) {
|
||||||
// No objects - no bbox to calculate
|
// No objects - no bbox to calculate
|
||||||
onBboxChanged({ id: entityState.id, bbox: null }, 'control_adapter');
|
onBboxChanged({ id: entityState.id, bbox: null }, 'control_adapter');
|
||||||
} else {
|
} else {
|
||||||
|
@ -8,9 +8,9 @@ import {
|
|||||||
} from 'features/controlLayers/konva/naming';
|
} from 'features/controlLayers/konva/naming';
|
||||||
import type {
|
import type {
|
||||||
BrushLineObjectRecord,
|
BrushLineObjectRecord,
|
||||||
KonvaEntityAdapter,
|
|
||||||
EraserLineObjectRecord,
|
EraserLineObjectRecord,
|
||||||
ImageObjectRecord,
|
ImageObjectRecord,
|
||||||
|
KonvaEntityAdapter,
|
||||||
RectShapeObjectRecord,
|
RectShapeObjectRecord,
|
||||||
} from 'features/controlLayers/konva/nodeManager';
|
} from 'features/controlLayers/konva/nodeManager';
|
||||||
import type {
|
import type {
|
||||||
|
@ -53,8 +53,11 @@ import type {
|
|||||||
import type Konva from 'konva';
|
import type Konva from 'konva';
|
||||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||||
import { debounce } from 'lodash-es';
|
import { debounce } from 'lodash-es';
|
||||||
|
import { atom } from 'nanostores';
|
||||||
import type { RgbaColor } from 'react-colorful';
|
import type { RgbaColor } from 'react-colorful';
|
||||||
|
|
||||||
|
export const $nodeManager = atom<KonvaNodeManager | null>(null);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the
|
* Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the
|
||||||
* react rendering cycle entirely, improving canvas performance.
|
* react rendering cycle entirely, improving canvas performance.
|
||||||
@ -249,6 +252,8 @@ export const initializeRenderer = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
const manager = new KonvaNodeManager(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get);
|
const manager = new KonvaNodeManager(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get);
|
||||||
|
console.log(manager);
|
||||||
|
$nodeManager.set(manager);
|
||||||
|
|
||||||
const cleanupListeners = setStageEventHandlers({
|
const cleanupListeners = setStageEventHandlers({
|
||||||
manager,
|
manager,
|
||||||
@ -344,7 +349,7 @@ export const initializeRenderer = (
|
|||||||
canvasV2.controlAdapters !== prevCanvasV2.controlAdapters ||
|
canvasV2.controlAdapters !== prevCanvasV2.controlAdapters ||
|
||||||
canvasV2.regions !== prevCanvasV2.regions
|
canvasV2.regions !== prevCanvasV2.regions
|
||||||
) {
|
) {
|
||||||
logIfDebugging('Updating entity bboxes');
|
// logIfDebugging('Updating entity bboxes');
|
||||||
// debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged);
|
// debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,9 +16,10 @@ import { toolReducers } from 'features/controlLayers/store/toolReducers';
|
|||||||
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
|
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
|
||||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||||
import { atom } from 'nanostores';
|
import { atom } from 'nanostores';
|
||||||
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types';
|
import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types';
|
||||||
import { DEFAULT_RGBA_COLOR } from './types';
|
import { DEFAULT_RGBA_COLOR, imageDTOToImageWithDims } from './types';
|
||||||
|
|
||||||
const initialState: CanvasV2State = {
|
const initialState: CanvasV2State = {
|
||||||
_version: 3,
|
_version: 3,
|
||||||
@ -119,6 +120,7 @@ const initialState: CanvasV2State = {
|
|||||||
refinerNegativeAestheticScore: 2.5,
|
refinerNegativeAestheticScore: 2.5,
|
||||||
refinerStart: 0.8,
|
refinerStart: 0.8,
|
||||||
},
|
},
|
||||||
|
baseLayerImageCache: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const canvasV2Slice = createSlice({
|
export const canvasV2Slice = createSlice({
|
||||||
@ -164,6 +166,10 @@ export const canvasV2Slice = createSlice({
|
|||||||
state.layers = [];
|
state.layers = [];
|
||||||
state.ipAdapters = [];
|
state.ipAdapters = [];
|
||||||
state.controlAdapters = [];
|
state.controlAdapters = [];
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
|
},
|
||||||
|
baseLayerImageCacheChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||||
|
state.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@ -185,6 +191,7 @@ export const {
|
|||||||
scaledBboxChanged,
|
scaledBboxChanged,
|
||||||
bboxScaleMethodChanged,
|
bboxScaleMethodChanged,
|
||||||
clipToBboxChanged,
|
clipToBboxChanged,
|
||||||
|
baseLayerImageCacheChanged,
|
||||||
// layers
|
// layers
|
||||||
layerAdded,
|
layerAdded,
|
||||||
layerRecalled,
|
layerRecalled,
|
||||||
|
@ -39,6 +39,7 @@ export const layersReducers = {
|
|||||||
y: 0,
|
y: 0,
|
||||||
});
|
});
|
||||||
state.selectedEntityIdentifier = { type: 'layer', id };
|
state.selectedEntityIdentifier = { type: 'layer', id };
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
prepare: () => ({ payload: { id: uuidv4() } }),
|
prepare: () => ({ payload: { id: uuidv4() } }),
|
||||||
},
|
},
|
||||||
@ -46,6 +47,7 @@ export const layersReducers = {
|
|||||||
const { data } = action.payload;
|
const { data } = action.payload;
|
||||||
state.layers.push(data);
|
state.layers.push(data);
|
||||||
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
|
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
|
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
|
||||||
const { id } = action.payload;
|
const { id } = action.payload;
|
||||||
@ -54,6 +56,7 @@ export const layersReducers = {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
layer.isEnabled = !layer.isEnabled;
|
layer.isEnabled = !layer.isEnabled;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
|
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
|
||||||
const { id, x, y } = action.payload;
|
const { id, x, y } = action.payload;
|
||||||
@ -63,6 +66,7 @@ export const layersReducers = {
|
|||||||
}
|
}
|
||||||
layer.x = x;
|
layer.x = x;
|
||||||
layer.y = y;
|
layer.y = y;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
|
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
|
||||||
const { id, bbox } = action.payload;
|
const { id, bbox } = action.payload;
|
||||||
@ -88,13 +92,16 @@ export const layersReducers = {
|
|||||||
layer.objects = [];
|
layer.objects = [];
|
||||||
layer.bbox = null;
|
layer.bbox = null;
|
||||||
layer.bboxNeedsUpdate = false;
|
layer.bboxNeedsUpdate = false;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
|
layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
|
||||||
const { id } = action.payload;
|
const { id } = action.payload;
|
||||||
state.layers = state.layers.filter((l) => l.id !== id);
|
state.layers = state.layers.filter((l) => l.id !== id);
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerAllDeleted: (state) => {
|
layerAllDeleted: (state) => {
|
||||||
state.layers = [];
|
state.layers = [];
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
|
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
|
||||||
const { id, opacity } = action.payload;
|
const { id, opacity } = action.payload;
|
||||||
@ -103,6 +110,7 @@ export const layersReducers = {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
layer.opacity = opacity;
|
layer.opacity = opacity;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
|
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
|
||||||
const { id } = action.payload;
|
const { id } = action.payload;
|
||||||
@ -111,6 +119,7 @@ export const layersReducers = {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
moveOneToEnd(state.layers, layer);
|
moveOneToEnd(state.layers, layer);
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
|
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
|
||||||
const { id } = action.payload;
|
const { id } = action.payload;
|
||||||
@ -119,6 +128,7 @@ export const layersReducers = {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
moveToEnd(state.layers, layer);
|
moveToEnd(state.layers, layer);
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
|
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
|
||||||
const { id } = action.payload;
|
const { id } = action.payload;
|
||||||
@ -127,6 +137,7 @@ export const layersReducers = {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
moveOneToStart(state.layers, layer);
|
moveOneToStart(state.layers, layer);
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
|
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
|
||||||
const { id } = action.payload;
|
const { id } = action.payload;
|
||||||
@ -135,6 +146,7 @@ export const layersReducers = {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
moveToStart(state.layers, layer);
|
moveToStart(state.layers, layer);
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerBrushLineAdded: {
|
layerBrushLineAdded: {
|
||||||
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
|
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
|
||||||
@ -153,6 +165,7 @@ export const layersReducers = {
|
|||||||
clip,
|
clip,
|
||||||
});
|
});
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
prepare: (payload: BrushLineAddedArg) => ({
|
prepare: (payload: BrushLineAddedArg) => ({
|
||||||
payload: { ...payload, lineId: uuidv4() },
|
payload: { ...payload, lineId: uuidv4() },
|
||||||
@ -174,6 +187,7 @@ export const layersReducers = {
|
|||||||
clip,
|
clip,
|
||||||
});
|
});
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
prepare: (payload: EraserLineAddedArg) => ({
|
prepare: (payload: EraserLineAddedArg) => ({
|
||||||
payload: { ...payload, lineId: uuidv4() },
|
payload: { ...payload, lineId: uuidv4() },
|
||||||
@ -191,6 +205,7 @@ export const layersReducers = {
|
|||||||
}
|
}
|
||||||
lastObject.points.push(...point);
|
lastObject.points.push(...point);
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
layerRectAdded: {
|
layerRectAdded: {
|
||||||
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
|
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
|
||||||
@ -210,6 +225,7 @@ export const layersReducers = {
|
|||||||
color,
|
color,
|
||||||
});
|
});
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
|
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
|
||||||
},
|
},
|
||||||
@ -222,6 +238,7 @@ export const layersReducers = {
|
|||||||
}
|
}
|
||||||
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
|
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
|
||||||
layer.bboxNeedsUpdate = true;
|
layer.bboxNeedsUpdate = true;
|
||||||
|
state.baseLayerImageCache = null;
|
||||||
},
|
},
|
||||||
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
|
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
|
||||||
},
|
},
|
||||||
|
@ -872,6 +872,7 @@ export type CanvasV2State = {
|
|||||||
refinerNegativeAestheticScore: number;
|
refinerNegativeAestheticScore: number;
|
||||||
refinerStart: number;
|
refinerStart: number;
|
||||||
};
|
};
|
||||||
|
baseLayerImageCache: ImageWithDims | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number };
|
export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number };
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
|
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||||
|
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||||
|
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||||
|
import type { LayerEntity } from 'features/controlLayers/store/types';
|
||||||
|
import type Konva from 'konva';
|
||||||
|
import type { IRect } from 'konva/lib/types';
|
||||||
|
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
const isValidLayer = (entity: LayerEntity) => {
|
||||||
|
return (
|
||||||
|
entity.isEnabled &&
|
||||||
|
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
|
||||||
|
entity.objects.length > 0
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||||
|
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||||
|
* @param preview Whether to open a new tab displaying each layer.
|
||||||
|
* @returns A map of layer IDs to blobs.
|
||||||
|
*/
|
||||||
|
|
||||||
|
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
|
||||||
|
const manager = $nodeManager.get();
|
||||||
|
assert(manager, 'Node manager is null');
|
||||||
|
|
||||||
|
const stage = manager.stage.clone();
|
||||||
|
|
||||||
|
stage.scaleX(1);
|
||||||
|
stage.scaleY(1);
|
||||||
|
stage.x(0);
|
||||||
|
stage.y(0);
|
||||||
|
|
||||||
|
const validLayers = layers.filter(isValidLayer);
|
||||||
|
|
||||||
|
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
|
||||||
|
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
|
||||||
|
// to delete in a separate array and then destroy them.
|
||||||
|
// TODO(psyche): Maybe report this?
|
||||||
|
const toDelete: Konva.Layer[] = [];
|
||||||
|
|
||||||
|
for (const konvaLayer of stage.getLayers()) {
|
||||||
|
const layer = validLayers.find((l) => l.id === konvaLayer.id());
|
||||||
|
if (!layer) {
|
||||||
|
toDelete.push(konvaLayer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const konvaLayer of toDelete) {
|
||||||
|
konvaLayer.destroy();
|
||||||
|
}
|
||||||
|
|
||||||
|
const blob = await new Promise<Blob>((resolve) => {
|
||||||
|
stage.toBlob({
|
||||||
|
callback: (blob) => {
|
||||||
|
assert(blob, 'Blob is null');
|
||||||
|
resolve(blob);
|
||||||
|
},
|
||||||
|
...bbox,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
const base64 = await blobToDataURL(blob);
|
||||||
|
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
|
||||||
|
}
|
||||||
|
|
||||||
|
stage.destroy();
|
||||||
|
|
||||||
|
return blob;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
|
||||||
|
const { dispatch, getState } = getStore();
|
||||||
|
const state = getState();
|
||||||
|
if (state.canvasV2.baseLayerImageCache) {
|
||||||
|
const imageDTO = await getImageDTO(state.canvasV2.baseLayerImageCache.name);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const blob = await getBaseLayer(state.canvasV2.layers, state.canvasV2.bbox, true);
|
||||||
|
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||||
|
const req = dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
|
||||||
|
);
|
||||||
|
req.reset();
|
||||||
|
const imageDTO = await req.unwrap();
|
||||||
|
dispatch(baseLayerImageCacheChanged(imageDTO));
|
||||||
|
return imageDTO;
|
||||||
|
};
|
@ -1,8 +1,8 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
import { deepClone } from 'common/util/deepClone';
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
import { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
|
import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
|
||||||
import { renderRegions } from 'features/controlLayers/konva/renderers/regions';
|
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||||
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
|
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
|
||||||
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
|
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
|
||||||
@ -15,9 +15,7 @@ import {
|
|||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
|
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
|
||||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import Konva from 'konva';
|
|
||||||
import type { IRect } from 'konva/lib/types';
|
import type { IRect } from 'konva/lib/types';
|
||||||
import { size } from 'lodash-es';
|
|
||||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||||
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@ -50,38 +48,34 @@ export const addRegions = async (
|
|||||||
const isSDXL = base === 'sdxl';
|
const isSDXL = base === 'sdxl';
|
||||||
|
|
||||||
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
|
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
|
||||||
const blobs = await getRGMaskBlobs(validRegions, documentSize, bbox);
|
|
||||||
assert(size(blobs) === size(validRegions), 'Mismatch between layer IDs and blobs');
|
|
||||||
|
|
||||||
for (const rg of validRegions) {
|
for (const region of validRegions) {
|
||||||
const blob = blobs[rg.id];
|
|
||||||
assert(blob, `Blob for layer ${rg.id} not found`);
|
|
||||||
// Upload the mask image, or get the cached image if it exists
|
// Upload the mask image, or get the cached image if it exists
|
||||||
const { image_name } = await getMaskImage(rg, blob);
|
const { image_name } = await getRegionMaskImage(region, bbox, true);
|
||||||
|
|
||||||
// The main mask-to-tensor node
|
// The main mask-to-tensor node
|
||||||
const maskToTensor = g.addNode({
|
const maskToTensor = g.addNode({
|
||||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${region.id}`,
|
||||||
type: 'alpha_mask_to_tensor',
|
type: 'alpha_mask_to_tensor',
|
||||||
image: {
|
image: {
|
||||||
image_name,
|
image_name,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (rg.positivePrompt) {
|
if (region.positivePrompt) {
|
||||||
// The main positive conditioning node
|
// The main positive conditioning node
|
||||||
const regionalPosCond = g.addNode(
|
const regionalPosCond = g.addNode(
|
||||||
isSDXL
|
isSDXL
|
||||||
? {
|
? {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`,
|
||||||
prompt: rg.positivePrompt,
|
prompt: region.positivePrompt,
|
||||||
style: rg.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||||
}
|
}
|
||||||
: {
|
: {
|
||||||
type: 'compel',
|
type: 'compel',
|
||||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`,
|
||||||
prompt: rg.positivePrompt,
|
prompt: region.positivePrompt,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
// Connect the mask to the conditioning
|
// Connect the mask to the conditioning
|
||||||
@ -106,20 +100,20 @@ export const addRegions = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rg.negativePrompt) {
|
if (region.negativePrompt) {
|
||||||
// The main negative conditioning node
|
// The main negative conditioning node
|
||||||
const regionalNegCond = g.addNode(
|
const regionalNegCond = g.addNode(
|
||||||
isSDXL
|
isSDXL
|
||||||
? {
|
? {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`,
|
||||||
prompt: rg.negativePrompt,
|
prompt: region.negativePrompt,
|
||||||
style: rg.negativePrompt,
|
style: region.negativePrompt,
|
||||||
}
|
}
|
||||||
: {
|
: {
|
||||||
type: 'compel',
|
type: 'compel',
|
||||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`,
|
||||||
prompt: rg.negativePrompt,
|
prompt: region.negativePrompt,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
// Connect the mask to the conditioning
|
// Connect the mask to the conditioning
|
||||||
@ -143,10 +137,10 @@ export const addRegions = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||||
if (rg.autoNegative === 'invert' && rg.positivePrompt) {
|
if (region.autoNegative === 'invert' && region.positivePrompt) {
|
||||||
// We re-use the mask image, but invert it when converting to tensor
|
// We re-use the mask image, but invert it when converting to tensor
|
||||||
const invertTensorMask = g.addNode({
|
const invertTensorMask = g.addNode({
|
||||||
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${region.id}`,
|
||||||
type: 'invert_tensor_mask',
|
type: 'invert_tensor_mask',
|
||||||
});
|
});
|
||||||
// Connect the OG mask image to the inverted mask-to-tensor node
|
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||||
@ -156,14 +150,14 @@ export const addRegions = async (
|
|||||||
isSDXL
|
isSDXL
|
||||||
? {
|
? {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`,
|
||||||
prompt: rg.positivePrompt,
|
prompt: region.positivePrompt,
|
||||||
style: rg.positivePrompt,
|
style: region.positivePrompt,
|
||||||
}
|
}
|
||||||
: {
|
: {
|
||||||
type: 'compel',
|
type: 'compel',
|
||||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`,
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`,
|
||||||
prompt: rg.positivePrompt,
|
prompt: region.positivePrompt,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
// Connect the inverted mask to the conditioning
|
// Connect the inverted mask to the conditioning
|
||||||
@ -186,7 +180,7 @@ export const addRegions = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const validRGIPAdapters: IPAdapterEntity[] = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
|
const validRGIPAdapters: IPAdapterEntity[] = region.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
|
||||||
|
|
||||||
for (const ipa of validRGIPAdapters) {
|
for (const ipa of validRGIPAdapters) {
|
||||||
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||||
@ -245,6 +239,20 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageD
|
|||||||
return imageDTO;
|
return imageDTO;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
|
||||||
|
const { dispatch } = getStore();
|
||||||
|
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||||
|
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
|
||||||
|
const req = dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||||
|
);
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
const imageDTO = await req.unwrap();
|
||||||
|
dispatch(rgMaskImageUploaded({ id, imageDTO }));
|
||||||
|
return imageDTO;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||||
@ -252,33 +260,32 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageD
|
|||||||
* @returns A map of layer IDs to blobs.
|
* @returns A map of layer IDs to blobs.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export const getRGMaskBlobs = async (
|
export const getRegionMaskImage = async (
|
||||||
regions: RegionEntity[],
|
region: RegionEntity,
|
||||||
documentSize: Dimensions,
|
|
||||||
bbox: IRect,
|
bbox: IRect,
|
||||||
preview: boolean = false
|
preview: boolean = false
|
||||||
): Promise<Record<string, Blob>> => {
|
): Promise<ImageDTO> => {
|
||||||
const container = document.createElement('div');
|
const manager = $nodeManager.get();
|
||||||
const stage = new Konva.Stage({ container, ...documentSize });
|
assert(manager, 'Node manager is null');
|
||||||
const manager = new KonvaNodeManager(stage);
|
|
||||||
renderRegions(manager, regions, 1, 'brush', null);
|
|
||||||
const adapters = manager.getAll();
|
|
||||||
const blobs: Record<string, Blob> = {};
|
|
||||||
|
|
||||||
// First remove all layers
|
// TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
|
||||||
for (const adapter of adapters) {
|
const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
|
||||||
adapter.konvaLayer.remove();
|
assert(adapter, `Adapter for region ${region.id} not found`);
|
||||||
|
if (region.imageCache) {
|
||||||
|
const imageDTO = await getImageDTO(region.imageCache.name);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
const layer = adapter.konvaLayer.clone();
|
||||||
|
const objectGroup = adapter.konvaObjectGroup.clone();
|
||||||
|
layer.destroyChildren();
|
||||||
|
layer.add(objectGroup);
|
||||||
|
objectGroup.opacity(1);
|
||||||
|
objectGroup.cache();
|
||||||
|
|
||||||
// Next render each layer to a blob
|
|
||||||
for (const adapter of adapters) {
|
|
||||||
const region = regions.find((l) => l.id === adapter.id);
|
|
||||||
if (!region) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
stage.add(adapter.konvaLayer);
|
|
||||||
const blob = await new Promise<Blob>((resolve) => {
|
const blob = await new Promise<Blob>((resolve) => {
|
||||||
stage.toBlob({
|
layer.toBlob({
|
||||||
callback: (blob) => {
|
callback: (blob) => {
|
||||||
assert(blob, 'Blob is null');
|
assert(blob, 'Blob is null');
|
||||||
resolve(blob);
|
resolve(blob);
|
||||||
@ -289,16 +296,11 @@ export const getRGMaskBlobs = async (
|
|||||||
|
|
||||||
if (preview) {
|
if (preview) {
|
||||||
const base64 = await blobToDataURL(blob);
|
const base64 = await blobToDataURL(blob);
|
||||||
openBase64ImageInTab([
|
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
|
||||||
{
|
openBase64ImageInTab([{ base64, caption }]);
|
||||||
base64,
|
|
||||||
caption: `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`,
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
adapter.konvaLayer.remove();
|
|
||||||
blobs[adapter.id] = blob;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return blobs;
|
layer.destroy();
|
||||||
|
|
||||||
|
return await uploadMaskImage(region, blob);
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user