feat(ui): img2img working

This commit is contained in:
psychedelicious 2024-07-11 20:37:00 +10:00
parent 551dd393aa
commit b1fe6f9853
13 changed files with 141 additions and 128 deletions

View File

@ -4,8 +4,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import {
layerAdded,
layerImageAdded,
sessionStagingCanceled,
sessionStagedImageAccepted,
sessionStagingCanceled,
} from 'features/controlLayers/store/canvasV2Slice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@ -67,7 +67,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
const { id } = layer;
api.dispatch(layerImageAdded({ id, imageDTO, pos: { x: bbox.x - layer.x, y: bbox.y - layer.y } }));
api.dispatch(layerImageAdded({ id, imageDTO, pos: { x: bbox.rect.x - layer.x, y: bbox.rect.y - layer.y } }));
},
});
};

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice';
import { $lastProgressEvent, sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice';
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
@ -44,9 +44,11 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe();
// handle tab-specific logic
if (data.origin === 'canvas') {
if (data.invocation_source_id === CANVAS_OUTPUT && canvasV2.session.isStaging) {
if (data.origin === 'canvas' && data.invocation_source_id === CANVAS_OUTPUT) {
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ imageDTO }));
} else if (!canvasV2.session.isActive) {
$lastProgressEvent.set(null);
}
} else if (data.origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);

View File

@ -24,6 +24,7 @@ const LAYER_TYPE_TO_TKEY: Record<CanvasEntity['type'], string> = {
regional_guidance: 'controlLayers.regionalGuidance',
layer: 'controlLayers.raster',
inpaint_mask: 'controlLayers.inpaintMask',
initial_image: 'controlLayers.initialImage',
};
const createSelector = (templates: Templates) =>
@ -149,7 +150,7 @@ const createSelector = (templates: Templates) =>
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
if (ca.adapterType === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64;
if (bbox.width % multiple !== 0 || bbox.height % multiple !== 0) {
if (bbox.rect.width % multiple !== 0 || bbox.rect.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
}
}

View File

@ -1,9 +1,11 @@
import { FILTER_MAP } from 'features/controlLayers/konva/filters';
import { loadImage } from 'features/controlLayers/konva/util';
import type { ImageObject } from 'features/controlLayers/store/types';
import { t } from 'i18next';
import Konva from 'konva';
import { getImageDTO as defaultGetImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
export class CanvasImage {
id: string;
@ -23,14 +25,14 @@ export class CanvasImage {
constructor(
imageObject: ImageObject,
options: {
options?: {
getImageDTO?: (imageName: string) => Promise<ImageDTO | null>;
onLoading?: () => void;
onLoad?: (konvaImage: Konva.Image) => void;
onError?: () => void;
}
) {
const { getImageDTO, onLoading, onLoad, onError } = options;
const { getImageDTO, onLoading, onLoad, onError } = options ?? {};
const { id, width, height, x, y, filters } = imageObject;
this.konvaImageGroup = new Konva.Group({ id, listening: false, x, y });
this.konvaPlaceholderGroup = new Konva.Group({ listening: false });
@ -124,21 +126,10 @@ export class CanvasImage {
async updateImageSource(imageName: string) {
try {
this.onLoading();
const imageDTO = await this.getImageDTO(imageName);
if (!imageDTO) {
this.onError();
return;
}
const imageEl = new Image();
imageEl.onload = () => {
this.onLoad(imageName, imageEl);
};
imageEl.onerror = () => {
this.onError();
};
imageEl.id = imageName;
imageEl.src = imageDTO.image_url;
assert(imageDTO !== null, 'imageDTO is null');
const imageEl = await loadImage(imageDTO.image_url);
this.onLoad(imageName, imageEl);
} catch {
this.onError();
}

View File

@ -41,30 +41,19 @@ export class CanvasInitialImage {
return;
}
const imageObject = this.initialImageState.imageObject;
if (!imageObject) {
if (this.image) {
this.image.konvaImageGroup.visible(false);
}
} else if (!this.image) {
this.image = await new CanvasImage(imageObject, {
onLoad: () => {
this.updateGroup();
},
});
if (!this.image) {
this.image = await new CanvasImage(this.initialImageState.imageObject, {});
this.objectsGroup.add(this.image.konvaImageGroup);
await this.image.updateImageSource(imageObject.image.name);
await this.image.update(this.initialImageState.imageObject, true);
} else if (!this.image.isLoading && !this.image.isError) {
await this.image.update(imageObject);
await this.image.update(this.initialImageState.imageObject);
}
this.updateGroup();
}
updateGroup() {
const visible = this.initialImageState ? this.initialImageState.isEnabled : false;
this.layer.visible(visible);
if (this.initialImageState && this.initialImageState.isEnabled && !this.image?.isLoading && !this.image?.isError) {
this.layer.visible(true);
} else {
this.layer.visible(false);
}
}
destroy(): void {

View File

@ -180,14 +180,11 @@ export class CanvasLayer {
assert(image instanceof CanvasImage || image === undefined);
if (!image) {
image = await new CanvasImage(obj, {
onLoad: () => {
this.updateGroup(true);
},
});
image = await new CanvasImage(obj, {});
this.objects.set(image.id, image);
this.objectsGroup.add(image.konvaImageGroup);
await image.updateImageSource(obj.image.name);
this.updateGroup(true);
} else {
if (await image.update(obj, force)) {
return true;

View File

@ -2,6 +2,7 @@ import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { CanvasInitialImage } from 'features/controlLayers/konva/CanvasInitialImage';
import { CanvasProgressPreview } from 'features/controlLayers/konva/CanvasProgressPreview';
import {
getCompositeLayerImage,
getControlAdapterImage,
@ -92,7 +93,8 @@ export class CanvasManager {
new CanvasBbox(this),
new CanvasTool(this),
new CanvasDocumentSizeOverlay(this),
new CanvasStagingArea(this)
new CanvasStagingArea(this),
new CanvasProgressPreview(this)
);
this.stage.add(this.preview.layer);
@ -111,7 +113,7 @@ export class CanvasManager {
}
async renderInitialImage() {
this.initialImage.render(this.stateApi.getInitialImageState());
await this.initialImage.render(this.stateApi.getInitialImageState());
}
async renderLayers() {
@ -135,7 +137,7 @@ export class CanvasManager {
}
}
renderRegions() {
async renderRegions() {
const { entities } = this.stateApi.getRegionsState();
// Destroy the konva nodes for nonexistent entities
@ -153,16 +155,20 @@ export class CanvasManager {
this.regions.set(adapter.id, adapter);
this.stage.add(adapter.layer);
}
adapter.render(entity);
await adapter.render(entity);
}
}
renderInpaintMask() {
const inpaintMaskState = this.stateApi.getInpaintMaskState();
this.inpaintMask.render(inpaintMaskState);
async renderProgressPreview() {
await this.preview.progressPreview.render(this.stateApi.getLastProgressEvent());
}
renderControlAdapters() {
async renderInpaintMask() {
const inpaintMaskState = this.stateApi.getInpaintMaskState();
await this.inpaintMask.render(inpaintMaskState);
}
async renderControlAdapters() {
const { entities } = this.stateApi.getControlAdaptersState();
for (const canvasControlAdapter of this.controlAdapters.values()) {
@ -179,7 +185,7 @@ export class CanvasManager {
this.controlAdapters.set(adapter.id, adapter);
this.stage.add(adapter.layer);
}
adapter.render(entity);
await adapter.render(entity);
}
}
@ -222,7 +228,7 @@ export class CanvasManager {
const state = this.stateApi.getState();
if (this.prevState === state && !this.isFirstRender) {
log.debug('No changes detected, skipping render');
log.trace('No changes detected, skipping render');
return;
}
@ -233,7 +239,7 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this.prevState.selectedEntityIdentifier?.id
) {
log.debug('Rendering layers');
this.renderLayers();
await this.renderLayers();
}
if (
@ -243,8 +249,8 @@ export class CanvasManager {
state.tool.selected !== this.prevState.tool.selected ||
state.selectedEntityIdentifier?.id !== this.prevState.selectedEntityIdentifier?.id
) {
log.debug('Rendering intial image');
this.renderInitialImage();
log.debug('Rendering initial image');
await this.renderInitialImage();
}
if (
@ -255,7 +261,7 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this.prevState.selectedEntityIdentifier?.id
) {
log.debug('Rendering regions');
this.renderRegions();
await this.renderRegions();
}
if (
@ -266,7 +272,7 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this.prevState.selectedEntityIdentifier?.id
) {
log.debug('Rendering inpaint mask');
this.renderInpaintMask();
await this.renderInpaintMask();
}
if (
@ -276,12 +282,12 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this.prevState.selectedEntityIdentifier?.id
) {
log.debug('Rendering control adapters');
this.renderControlAdapters();
await this.renderControlAdapters();
}
if (this.isFirstRender || state.document !== this.prevState.document) {
log.debug('Rendering document bounds overlay');
this.preview.documentSizeOverlay.render();
await this.preview.documentSizeOverlay.render();
}
if (
@ -291,7 +297,7 @@ export class CanvasManager {
state.session.isActive !== this.prevState.session.isActive
) {
log.debug('Rendering generation bbox');
this.preview.bbox.render();
await this.preview.bbox.render();
}
if (
@ -306,7 +312,7 @@ export class CanvasManager {
if (this.isFirstRender || state.session !== this.prevState.session) {
log.debug('Rendering staging area');
this.preview.stagingArea.render();
await this.preview.stagingArea.render();
}
if (
@ -318,7 +324,7 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this.prevState.selectedEntityIdentifier?.id
) {
log.debug('Arranging entities');
this.arrangeEntities();
await this.arrangeEntities();
}
this.prevState = state;
@ -343,16 +349,21 @@ export class CanvasManager {
const unsubscribeRenderer = this.store.subscribe(this.render);
// When we this flag, we need to render the staging area
$shouldShowStagedImage.subscribe((shouldShowStagedImage, prevShouldShowStagedImage) => {
log.debug('Rendering staging area');
$shouldShowStagedImage.subscribe(async (shouldShowStagedImage, prevShouldShowStagedImage) => {
if (shouldShowStagedImage !== prevShouldShowStagedImage) {
this.preview.stagingArea.render();
log.debug('Rendering staging area');
await this.preview.stagingArea.render();
}
});
$lastProgressEvent.subscribe(() => {
log.debug('Rendering staging area');
this.preview.stagingArea.render();
$lastProgressEvent.subscribe(async (lastProgressEvent, prevLastProgressEvent) => {
if (lastProgressEvent !== prevLastProgressEvent) {
log.debug('Rendering progress image');
await this.preview.progressPreview.render(lastProgressEvent);
if (this.stateApi.getSession().isActive) {
this.preview.stagingArea.render();
}
}
});
log.debug('First render of konva stage');

View File

@ -1,3 +1,4 @@
import type { CanvasProgressPreview } from 'features/controlLayers/konva/CanvasProgressPreview';
import Konva from 'konva';
import type { CanvasBbox } from './CanvasBbox';
@ -11,12 +12,14 @@ export class CanvasPreview {
bbox: CanvasBbox;
documentSizeOverlay: CanvasDocumentSizeOverlay;
stagingArea: CanvasStagingArea;
progressPreview: CanvasProgressPreview;
constructor(
bbox: CanvasBbox,
tool: CanvasTool,
documentSizeOverlay: CanvasDocumentSizeOverlay,
stagingArea: CanvasStagingArea
stagingArea: CanvasStagingArea,
progressPreview: CanvasProgressPreview
) {
this.layer = new Konva.Layer({ listening: true, imageSmoothingEnabled: false });
@ -31,5 +34,8 @@ export class CanvasPreview {
this.tool = tool;
this.layer.add(this.tool.group);
this.progressPreview = progressPreview;
this.layer.add(this.progressPreview.group);
}
}

View File

@ -1,3 +1,4 @@
import { loadImage } from 'features/controlLayers/konva/util';
import Konva from 'konva';
export class CanvasProgressImage {
@ -11,7 +12,6 @@ export class CanvasProgressImage {
constructor(arg: { id: string }) {
const { id } = arg;
this.konvaImageGroup = new Konva.Group({ id, listening: false });
this.id = id;
this.progressImageId = null;
this.konvaImage = null;
@ -27,8 +27,12 @@ export class CanvasProgressImage {
width: number,
height: number
) {
const imageEl = new Image();
imageEl.onload = () => {
if (this.isLoading) {
return;
}
this.isLoading = true;
try {
const imageEl = await loadImage(dataURL);
if (this.konvaImage) {
this.konvaImage.setAttrs({
image: imageEl,
@ -49,9 +53,11 @@ export class CanvasProgressImage {
});
this.konvaImageGroup.add(this.konvaImage);
}
};
imageEl.id = progressImageId;
imageEl.src = dataURL;
this.isLoading = false;
this.id = progressImageId;
} catch {
this.isError = true;
}
}
destroy() {

View File

@ -0,0 +1,38 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasProgressImage } from 'features/controlLayers/konva/CanvasProgressImage';
import Konva from 'konva';
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
export class CanvasProgressPreview {
group: Konva.Group;
progressImage: CanvasProgressImage;
manager: CanvasManager;
constructor(manager: CanvasManager) {
this.manager = manager;
this.group = new Konva.Group({ listening: false });
this.progressImage = new CanvasProgressImage({ id: 'progress-image' });
this.group.add(this.progressImage.konvaImageGroup);
}
async render(lastProgressEvent: InvocationDenoiseProgressEvent | null) {
const bboxRect = this.manager.stateApi.getBbox().rect;
if (lastProgressEvent) {
const { invocation, step, progress_image } = lastProgressEvent;
const { dataURL } = progress_image;
const { x, y, width, height } = bboxRect;
const progressImageId = `${invocation.id}_${step}`;
if (
!this.progressImage.isLoading &&
!this.progressImage.isError &&
this.progressImage.progressImageId !== progressImageId
) {
await this.progressImage.updateImageSource(progressImageId, dataURL, x, y, width, height);
this.progressImage.konvaImageGroup.visible(true);
}
} else {
this.progressImage.konvaImageGroup.visible(false);
}
}
}

View File

@ -1,13 +1,11 @@
import { CanvasImage } from 'features/controlLayers/konva/CanvasImage';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasProgressImage } from 'features/controlLayers/konva/CanvasProgressImage';
import Konva from 'konva';
import type { ImageDTO } from 'services/api/types';
export class CanvasStagingArea {
group: Konva.Group;
image: CanvasImage | null;
progressImage: CanvasProgressImage | null;
imageDTO: ImageDTO | null;
manager: CanvasManager;
@ -15,35 +13,32 @@ export class CanvasStagingArea {
this.manager = manager;
this.group = new Konva.Group({ listening: false });
this.image = null;
this.progressImage = null;
this.imageDTO = null;
}
async render() {
const stagingArea = this.manager.stateApi.getSession();
const bbox = this.manager.stateApi.getBbox();
const session = this.manager.stateApi.getSession();
const bboxRect = this.manager.stateApi.getBbox().rect;
const shouldShowStagedImage = this.manager.stateApi.getShouldShowStagedImage();
const lastProgressEvent = this.manager.stateApi.getLastProgressEvent();
this.imageDTO = stagingArea.stagedImages[stagingArea.selectedStagedImageIndex] ?? null;
this.imageDTO = session.stagedImages[session.selectedStagedImageIndex] ?? null;
if (this.imageDTO) {
if (this.image) {
if (!this.image.isLoading && !this.image.isError && this.image.imageName !== this.imageDTO.image_name) {
await this.image.updateImageSource(this.imageDTO.image_name);
}
this.image.konvaImageGroup.x(bbox.x);
this.image.konvaImageGroup.y(bbox.y);
this.image.konvaImageGroup.x(bboxRect.x);
this.image.konvaImageGroup.y(bboxRect.y);
this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
} else {
const { image_name, width, height } = this.imageDTO;
this.image = new CanvasImage(
{
id: 'staging-area-image',
type: 'image',
x: bbox.x,
y: bbox.y,
x: bboxRect.x,
y: bboxRect.y,
width,
height,
filters: [],
@ -60,48 +55,16 @@ export class CanvasStagingArea {
konvaImage.height(this.imageDTO.height);
}
this.manager.stateApi.resetLastProgressEvent();
this.image?.konvaImageGroup.visible(shouldShowStagedImage);
},
}
);
this.group.add(this.image.konvaImageGroup);
await this.image.updateImageSource(this.imageDTO.image_name);
this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
}
}
if (stagingArea.isStaging && lastProgressEvent) {
const { invocation, step, progress_image } = lastProgressEvent;
const { dataURL } = progress_image;
const { x, y, width, height } = bbox;
const progressImageId = `${invocation.id}_${step}`;
if (this.progressImage) {
if (
!this.progressImage.isLoading &&
!this.progressImage.isError &&
this.progressImage.progressImageId !== progressImageId
) {
await this.progressImage.updateImageSource(progressImageId, dataURL, x, y, width, height);
this.image?.konvaImageGroup.visible(false);
this.progressImage.konvaImageGroup.visible(true);
}
} else {
this.progressImage = new CanvasProgressImage({ id: 'progress-image' });
this.group.add(this.progressImage.konvaImageGroup);
await this.progressImage.updateImageSource(progressImageId, dataURL, x, y, width, height);
this.image?.konvaImageGroup.visible(false);
this.progressImage.konvaImageGroup.visible(true);
}
}
if (!this.imageDTO && !lastProgressEvent) {
if (this.image) {
this.image.konvaImageGroup.visible(false);
}
if (this.progressImage) {
this.progressImage.konvaImageGroup.visible(false);
}
this.manager.stateApi.resetLastProgressEvent();
} else {
this.image?.konvaImageGroup.visible(false);
}
}
}

View File

@ -538,3 +538,12 @@ export async function getCompositeLayerImage(arg: {
manager.stateApi.onLayerImageCached(imageDTO);
return imageDTO;
}
export function loadImage(src: string, imageEl?: HTMLImageElement): Promise<HTMLImageElement> {
return new Promise((resolve, reject) => {
const _imageEl = imageEl ?? new Image();
_imageEl.onload = () => resolve(_imageEl);
_imageEl.onerror = (error) => reject(error);
_imageEl.src = src;
});
}

View File

@ -22,7 +22,7 @@ export const addOutpaint = async (
): Promise<Invocation<'canvas_paste_back'>> => {
denoise.denoising_start = denoising_start;
const cropBbox = pick(bbox, ['x', 'y', 'width', 'height']);
const cropBbox = pick(bbox.rect, ['x', 'y', 'width', 'height']);
const initialImage = await manager.getInitialImage({ bbox: cropBbox });
const maskImage = await manager.getInpaintMaskImage({ bbox: cropBbox });
const infill = getInfill(g, compositing);