mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): generation mode calculation, fudged graphs
This commit is contained in:
parent
7dd11bd60a
commit
e9204b87e3
@ -23,6 +23,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
const manager = $nodeManager.get();
|
const manager = $nodeManager.get();
|
||||||
assert(manager, 'Konva node manager not initialized');
|
assert(manager, 'Konva node manager not initialized');
|
||||||
|
|
||||||
|
console.log('generation mode', manager.util.getGenerationMode());
|
||||||
|
|
||||||
if (model?.base === 'sdxl') {
|
if (model?.base === 'sdxl') {
|
||||||
graph = await buildGenerationTabSDXLGraph(state, manager);
|
graph = await buildGenerationTabSDXLGraph(state, manager);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
|
export const getImageDataTransparency = (imageData: ImageData) => {
|
||||||
let isFullyTransparent = true;
|
let isFullyTransparent = true;
|
||||||
let isPartiallyTransparent = false;
|
let isPartiallyTransparent = false;
|
||||||
const len = pixels.length;
|
const len = imageData.data.length;
|
||||||
let i = 3;
|
for (let i = 3; i < len; i += 4) {
|
||||||
for (i; i < len; i += 4) {
|
if (imageData.data[i] === 255) {
|
||||||
if (pixels[i] === 255) {
|
|
||||||
isFullyTransparent = false;
|
isFullyTransparent = false;
|
||||||
} else {
|
} else {
|
||||||
isPartiallyTransparent = true;
|
isPartiallyTransparent = true;
|
||||||
@ -18,8 +17,8 @@ export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
|
|||||||
|
|
||||||
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
|
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
|
||||||
const len = pixels.length;
|
const len = pixels.length;
|
||||||
let i = 0;
|
const i = 0;
|
||||||
for (i; i < len; ) {
|
for (let i = 0; i < len; i) {
|
||||||
if (pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 255) {
|
if (pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 255) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
import { getImageDataTransparency } from 'common/util/arrayBuffer';
|
||||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
import { konvaNodeToBlob, konvaNodeToImageData, previewBlob } from 'features/controlLayers/konva/util';
|
||||||
import type {
|
import type {
|
||||||
BrushLine,
|
BrushLine,
|
||||||
BrushLineAddedArg,
|
BrushLineAddedArg,
|
||||||
@ -7,6 +7,7 @@ import type {
|
|||||||
CanvasV2State,
|
CanvasV2State,
|
||||||
EraserLine,
|
EraserLine,
|
||||||
EraserLineAddedArg,
|
EraserLineAddedArg,
|
||||||
|
GenerationMode,
|
||||||
ImageObject,
|
ImageObject,
|
||||||
PointAddedToLineArg,
|
PointAddedToLineArg,
|
||||||
PosChangedArg,
|
PosChangedArg,
|
||||||
@ -157,6 +158,9 @@ type Util = {
|
|||||||
getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
||||||
getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
||||||
getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
||||||
|
getMaskLayerClone: (arg: { id: string }) => Konva.Layer;
|
||||||
|
getCompositeLayerStageClone: () => Konva.Stage;
|
||||||
|
getGenerationMode: () => GenerationMode;
|
||||||
};
|
};
|
||||||
|
|
||||||
export class KonvaNodeManager {
|
export class KonvaNodeManager {
|
||||||
@ -183,6 +187,9 @@ export class KonvaNodeManager {
|
|||||||
getRegionMaskImage: this._getRegionMaskImage.bind(this),
|
getRegionMaskImage: this._getRegionMaskImage.bind(this),
|
||||||
getInpaintMaskImage: this._getInpaintMaskImage.bind(this),
|
getInpaintMaskImage: this._getInpaintMaskImage.bind(this),
|
||||||
getImageSourceImage: this._getImageSourceImage.bind(this),
|
getImageSourceImage: this._getImageSourceImage.bind(this),
|
||||||
|
getMaskLayerClone: this._getMaskLayerClone.bind(this),
|
||||||
|
getCompositeLayerStageClone: this._getCompositeLayerStageClone.bind(this),
|
||||||
|
getGenerationMode: this._getGenerationMode.bind(this),
|
||||||
};
|
};
|
||||||
this._konvaApi = null;
|
this._konvaApi = null;
|
||||||
this._preview = null;
|
this._preview = null;
|
||||||
@ -254,112 +261,34 @@ export class KonvaNodeManager {
|
|||||||
return this._stateApi;
|
return this._stateApi;
|
||||||
}
|
}
|
||||||
|
|
||||||
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
_getMaskLayerClone(arg: { id: string }): Konva.Layer {
|
||||||
const { id, bbox, preview = false } = arg;
|
const { id } = arg;
|
||||||
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
|
const adapter = this.get(id);
|
||||||
assert(region, `Region entity state with id ${id} not found`);
|
assert(adapter, `Adapter for entity ${id} not found`);
|
||||||
const adapter = this.get(region.id);
|
|
||||||
assert(adapter, `Adapter for region ${region.id} not found`);
|
|
||||||
|
|
||||||
if (region.imageCache) {
|
const layerClone = adapter.konvaLayer.clone();
|
||||||
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
|
const objectGroupClone = adapter.konvaObjectGroup.clone();
|
||||||
if (imageDTO) {
|
|
||||||
return imageDTO;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const layer = adapter.konvaLayer.clone();
|
layerClone.destroyChildren();
|
||||||
const objectGroup = adapter.konvaObjectGroup.clone();
|
layerClone.add(objectGroupClone);
|
||||||
layer.destroyChildren();
|
|
||||||
layer.add(objectGroup);
|
|
||||||
objectGroup.opacity(1);
|
|
||||||
objectGroup.cache();
|
|
||||||
|
|
||||||
const blob = await new Promise<Blob>((resolve) => {
|
objectGroupClone.opacity(1);
|
||||||
layer.toBlob({
|
objectGroupClone.cache();
|
||||||
callback: (blob) => {
|
|
||||||
assert(blob, 'Blob is null');
|
|
||||||
resolve(blob);
|
|
||||||
},
|
|
||||||
...bbox,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
if (preview) {
|
return layerClone;
|
||||||
const base64 = await blobToDataURL(blob);
|
|
||||||
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
|
|
||||||
openBase64ImageInTab([{ base64, caption }]);
|
|
||||||
}
|
|
||||||
|
|
||||||
layer.destroy();
|
|
||||||
|
|
||||||
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
|
|
||||||
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
|
|
||||||
return imageDTO;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
_getCompositeLayerStageClone(): Konva.Stage {
|
||||||
const { bbox, preview = false } = arg;
|
|
||||||
const inpaintMask = this.stateApi.getInpaintMaskState();
|
|
||||||
const adapter = this.get(inpaintMask.id);
|
|
||||||
assert(adapter, `Adapter for ${inpaintMask.id} not found`);
|
|
||||||
|
|
||||||
if (inpaintMask.imageCache) {
|
|
||||||
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
|
|
||||||
if (imageDTO) {
|
|
||||||
return imageDTO;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const layer = adapter.konvaLayer.clone();
|
|
||||||
const objectGroup = adapter.konvaObjectGroup.clone();
|
|
||||||
layer.destroyChildren();
|
|
||||||
layer.add(objectGroup);
|
|
||||||
objectGroup.opacity(1);
|
|
||||||
objectGroup.cache();
|
|
||||||
|
|
||||||
const blob = await new Promise<Blob>((resolve) => {
|
|
||||||
layer.toBlob({
|
|
||||||
callback: (blob) => {
|
|
||||||
assert(blob, 'Blob is null');
|
|
||||||
resolve(blob);
|
|
||||||
},
|
|
||||||
...bbox,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
if (preview) {
|
|
||||||
const base64 = await blobToDataURL(blob);
|
|
||||||
const caption = 'inpaint mask';
|
|
||||||
openBase64ImageInTab([{ base64, caption }]);
|
|
||||||
}
|
|
||||||
|
|
||||||
layer.destroy();
|
|
||||||
|
|
||||||
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
|
|
||||||
this.stateApi.onInpaintMaskImageCached(imageDTO);
|
|
||||||
return imageDTO;
|
|
||||||
}
|
|
||||||
|
|
||||||
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
|
||||||
const { bbox, preview = false } = arg;
|
|
||||||
const layersState = this.stateApi.getLayersState();
|
const layersState = this.stateApi.getLayersState();
|
||||||
const { entities, imageCache } = layersState;
|
|
||||||
if (imageCache) {
|
|
||||||
const imageDTO = await this.util.getImageDTO(imageCache.name);
|
|
||||||
if (imageDTO) {
|
|
||||||
return imageDTO;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const stage = this.stage.clone();
|
const stageClone = this.stage.clone();
|
||||||
|
|
||||||
stage.scaleX(1);
|
stageClone.scaleX(1);
|
||||||
stage.scaleY(1);
|
stageClone.scaleY(1);
|
||||||
stage.x(0);
|
stageClone.x(0);
|
||||||
stage.y(0);
|
stageClone.y(0);
|
||||||
|
|
||||||
const validLayers = entities.filter(isValidLayer);
|
const validLayers = layersState.entities.filter(isValidLayer);
|
||||||
|
|
||||||
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
|
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
|
||||||
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
|
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
|
||||||
@ -367,7 +296,7 @@ export class KonvaNodeManager {
|
|||||||
// TODO(psyche): Maybe report this?
|
// TODO(psyche): Maybe report this?
|
||||||
const toDelete: Konva.Layer[] = [];
|
const toDelete: Konva.Layer[] = [];
|
||||||
|
|
||||||
for (const konvaLayer of stage.getLayers()) {
|
for (const konvaLayer of stageClone.getLayers()) {
|
||||||
const layer = validLayers.find((l) => l.id === konvaLayer.id());
|
const layer = validLayers.find((l) => l.id === konvaLayer.id());
|
||||||
if (!layer) {
|
if (!layer) {
|
||||||
toDelete.push(konvaLayer);
|
toDelete.push(konvaLayer);
|
||||||
@ -378,22 +307,100 @@ export class KonvaNodeManager {
|
|||||||
konvaLayer.destroy();
|
konvaLayer.destroy();
|
||||||
}
|
}
|
||||||
|
|
||||||
const blob = await new Promise<Blob>((resolve) => {
|
return stageClone;
|
||||||
stage.toBlob({
|
}
|
||||||
callback: (blob) => {
|
|
||||||
assert(blob, 'Blob is null');
|
|
||||||
resolve(blob);
|
|
||||||
},
|
|
||||||
...bbox,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
if (preview) {
|
_getGenerationMode(): GenerationMode {
|
||||||
const base64 = await blobToDataURL(blob);
|
const { x, y, width, height } = this.stateApi.getBbox();
|
||||||
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
|
const inpaintMaskLayer = this.util.getMaskLayerClone({ id: 'inpaint_mask' });
|
||||||
|
const inpaintMaskImageData = konvaNodeToImageData(inpaintMaskLayer, { x, y, width, height });
|
||||||
|
const inpaintMaskTransparency = getImageDataTransparency(inpaintMaskImageData);
|
||||||
|
const compositeLayer = this.util.getCompositeLayerStageClone();
|
||||||
|
const compositeLayerImageData = konvaNodeToImageData(compositeLayer, { x, y, width, height });
|
||||||
|
const compositeLayerTransparency = getImageDataTransparency(compositeLayerImageData);
|
||||||
|
if (compositeLayerTransparency.isPartiallyTransparent) {
|
||||||
|
if (compositeLayerTransparency.isFullyTransparent) {
|
||||||
|
return 'txt2img';
|
||||||
|
}
|
||||||
|
return 'outpaint';
|
||||||
|
} else {
|
||||||
|
if (!inpaintMaskTransparency.isFullyTransparent) {
|
||||||
|
return 'inpaint';
|
||||||
|
}
|
||||||
|
return 'img2img';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
||||||
|
const { id, bbox, preview = false } = arg;
|
||||||
|
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
|
||||||
|
assert(region, `Region entity state with id ${id} not found`);
|
||||||
|
|
||||||
|
if (region.imageCache) {
|
||||||
|
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stage.destroy();
|
const layerClone = this.util.getMaskLayerClone({ id });
|
||||||
|
const blob = await konvaNodeToBlob(layerClone, bbox);
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
previewBlob(blob, `region ${region.id} mask`);
|
||||||
|
}
|
||||||
|
|
||||||
|
layerClone.destroy();
|
||||||
|
|
||||||
|
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
|
||||||
|
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
|
||||||
|
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
||||||
|
const { bbox, preview = false } = arg;
|
||||||
|
const inpaintMask = this.stateApi.getInpaintMaskState();
|
||||||
|
|
||||||
|
if (inpaintMask.imageCache) {
|
||||||
|
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const layerClone = this.util.getMaskLayerClone({ id: inpaintMask.id });
|
||||||
|
const blob = await konvaNodeToBlob(layerClone, bbox);
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
previewBlob(blob, 'inpaint mask');
|
||||||
|
}
|
||||||
|
|
||||||
|
layerClone.destroy();
|
||||||
|
|
||||||
|
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
|
||||||
|
this.stateApi.onInpaintMaskImageCached(imageDTO);
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
|
||||||
|
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
||||||
|
const { bbox, preview = false } = arg;
|
||||||
|
const { imageCache } = this.stateApi.getLayersState();
|
||||||
|
if (imageCache) {
|
||||||
|
const imageDTO = await this.util.getImageDTO(imageCache.name);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const stageClone = this.util.getCompositeLayerStageClone();
|
||||||
|
|
||||||
|
const blob = await konvaNodeToBlob(stageClone, bbox);
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
previewBlob(blob, 'image source');
|
||||||
|
}
|
||||||
|
|
||||||
|
stageClone.destroy();
|
||||||
|
|
||||||
const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true);
|
const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true);
|
||||||
this.stateApi.onLayerImageCached(imageDTO);
|
this.stateApi.onLayerImageCached(imageDTO);
|
||||||
|
@ -11,10 +11,11 @@ import {
|
|||||||
RG_LAYER_NAME,
|
RG_LAYER_NAME,
|
||||||
RG_LAYER_RECT_SHAPE_NAME,
|
RG_LAYER_RECT_SHAPE_NAME,
|
||||||
} from 'features/controlLayers/konva/naming';
|
} from 'features/controlLayers/konva/naming';
|
||||||
import type { RgbaColor } from 'features/controlLayers/store/types';
|
import type { Rect, RgbaColor } from 'features/controlLayers/store/types';
|
||||||
import Konva from 'konva';
|
import Konva from 'konva';
|
||||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
import type { Vector2d } from 'konva/lib/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null.
|
* Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null.
|
||||||
@ -203,24 +204,33 @@ export const dataURLToImageData = async (dataURL: string, width: number, height:
|
|||||||
/**
|
/**
|
||||||
* Converts a Konva node to a Blob
|
* Converts a Konva node to a Blob
|
||||||
* @param node - The Konva node to convert to a Blob
|
* @param node - The Konva node to convert to a Blob
|
||||||
* @param boundingBox - The bounding box to crop to
|
* @param bbox - The bounding box to crop to
|
||||||
* @returns A Promise that resolves with Blob of the node cropped to the bounding box
|
* @returns A Promise that resolves with Blob of the node cropped to the bounding box
|
||||||
*/
|
*/
|
||||||
export const konvaNodeToBlob = async (node: Konva.Node, boundingBox: IRect): Promise<Blob> => {
|
export const konvaNodeToBlob = async (node: Konva.Node, bbox?: Rect): Promise<Blob> => {
|
||||||
return await canvasToBlob(node.toCanvas(boundingBox));
|
return await new Promise<Blob>((resolve) => {
|
||||||
|
node.toBlob({
|
||||||
|
callback: (blob) => {
|
||||||
|
assert(blob, 'Blob is null');
|
||||||
|
resolve(blob);
|
||||||
|
},
|
||||||
|
...(bbox ?? {}),
|
||||||
|
});
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converts a Konva node to an ImageData object
|
* Converts a Konva node to an ImageData object
|
||||||
* @param node - The Konva node to convert to an ImageData object
|
* @param node - The Konva node to convert to an ImageData object
|
||||||
* @param boundingBox - The bounding box to crop to
|
* @param bbox - The bounding box to crop to
|
||||||
* @returns A Promise that resolves with ImageData object of the node cropped to the bounding box
|
* @returns A Promise that resolves with ImageData object of the node cropped to the bounding box
|
||||||
*/
|
*/
|
||||||
export const konvaNodeToImageData = async (node: Konva.Node, boundingBox: IRect): Promise<ImageData> => {
|
export const konvaNodeToImageData = (node: Konva.Node, bbox?: Rect): ImageData => {
|
||||||
// get a dataURL of the bbox'd region
|
// get a dataURL of the bbox'd region
|
||||||
const dataURL = node.toDataURL(boundingBox);
|
const canvas = node.toCanvas({ ...(bbox ?? {}) });
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
return await dataURLToImageData(dataURL, boundingBox.width, boundingBox.height);
|
assert(ctx, 'ctx is null');
|
||||||
|
return ctx.getImageData(0, 0, canvas.width, canvas.height);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -246,3 +256,16 @@ export const getPixelUnderCursor = (stage: Konva.Stage): RgbaColor | null => {
|
|||||||
|
|
||||||
return { r, g, b, a };
|
return { r, g, b, a };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const previewBlob = async (blob: Blob, label?: string) => {
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const w = window.open('');
|
||||||
|
if (!w) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (label) {
|
||||||
|
w.document.write(label);
|
||||||
|
w.document.write('</br>');
|
||||||
|
}
|
||||||
|
w.document.write(`<img src="${url}" style="border: 1px solid red;" />`);
|
||||||
|
};
|
||||||
|
@ -906,3 +906,5 @@ export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine =>
|
|||||||
export type RemoveIndexString<T> = {
|
export type RemoveIndexString<T> = {
|
||||||
[K in keyof T as string extends K ? never : K]: T[K];
|
[K in keyof T as string extends K ? never : K]: T[K];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import {
|
import {
|
||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
@ -29,7 +30,7 @@ import { assert } from 'tsafe';
|
|||||||
|
|
||||||
import { addRegions } from './addRegions';
|
import { addRegions } from './addRegions';
|
||||||
|
|
||||||
export const buildGenerationTabGraph = async (state: RootState): Promise<GraphType> => {
|
export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
|
||||||
const {
|
const {
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
@ -159,6 +160,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
|
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
|
||||||
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
||||||
const _addedRegions = await addRegions(
|
const _addedRegions = await addRegions(
|
||||||
|
manager,
|
||||||
state.canvasV2.regions.entities,
|
state.canvasV2.regions.entities,
|
||||||
g,
|
g,
|
||||||
state.canvasV2.document,
|
state.canvasV2.document,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import {
|
import {
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
@ -27,7 +28,10 @@ import { assert } from 'tsafe';
|
|||||||
|
|
||||||
import { addRegions } from './addRegions';
|
import { addRegions } from './addRegions';
|
||||||
|
|
||||||
export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
export const buildGenerationTabSDXLGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
manager: KonvaNodeManager
|
||||||
|
): Promise<NonNullableGraph> => {
|
||||||
const {
|
const {
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
@ -42,6 +46,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
refinerModel,
|
refinerModel,
|
||||||
refinerStart,
|
refinerStart,
|
||||||
|
img2imgStrength,
|
||||||
} = state.canvasV2.params;
|
} = state.canvasV2.params;
|
||||||
const { width, height } = state.canvasV2.bbox;
|
const { width, height } = state.canvasV2.bbox;
|
||||||
|
|
||||||
@ -76,6 +81,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
id: NEGATIVE_CONDITIONING_COLLECT,
|
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||||
});
|
});
|
||||||
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
|
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
|
||||||
|
const i2l = g.addNode({ type: 'i2l', id: 'i2l' });
|
||||||
const denoise = g.addNode({
|
const denoise = g.addNode({
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
id: SDXL_DENOISE_LATENTS,
|
id: SDXL_DENOISE_LATENTS,
|
||||||
@ -83,7 +89,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
cfg_rescale_multiplier,
|
cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
denoising_start: 0,
|
denoising_start: refinerModel ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength,
|
||||||
denoising_end: refinerModel ? refinerStart : 1,
|
denoising_end: refinerModel ? refinerStart : 1,
|
||||||
});
|
});
|
||||||
const l2i = g.addNode({
|
const l2i = g.addNode({
|
||||||
@ -116,6 +122,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
||||||
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
|
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
|
||||||
g.addEdge(noise, 'noise', denoise, 'noise');
|
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||||
|
g.addEdge(i2l, 'latents', denoise, 'latents');
|
||||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
@ -146,6 +153,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
// We might get the VAE from the main model, custom VAE, or seamless node.
|
// We might get the VAE from the main model, custom VAE, or seamless node.
|
||||||
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
||||||
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
||||||
|
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
@ -155,6 +163,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
|
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
|
||||||
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
||||||
const _addedRegions = await addRegions(
|
const _addedRegions = await addRegions(
|
||||||
|
manager,
|
||||||
state.canvasV2.regions.entities,
|
state.canvasV2.regions.entities,
|
||||||
g,
|
g,
|
||||||
state.canvasV2.document,
|
state.canvasV2.document,
|
||||||
@ -166,6 +175,9 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
posCondCollect,
|
posCondCollect,
|
||||||
negCondCollect
|
negCondCollect
|
||||||
);
|
);
|
||||||
|
const { image_name } = await manager.util.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
|
||||||
|
await manager.util.getInpaintMaskImage({ bbox: state.canvasV2.bbox, preview: true });
|
||||||
|
i2l.image = { image_name };
|
||||||
|
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
imageOutput = addNSFWChecker(g, imageOutput);
|
imageOutput = addNSFWChecker(g, imageOutput);
|
||||||
|
Loading…
Reference in New Issue
Block a user