feat(ui): canvas staging area works

This commit is contained in:
psychedelicious 2024-06-28 21:04:15 +10:00
parent ac524153a7
commit 6f1d238d0a
8 changed files with 142 additions and 14 deletions

View File

@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketGeneratorProgress } from 'services/events/actions';
@ -11,9 +12,9 @@ const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => { export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketGeneratorProgress, actionCreator: socketGeneratorProgress,
effect: (action) => { effect: (action, { getState }) => {
log.trace(parseify(action.payload), `Generator progress`); log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data; const { invocation_source_id, step, total_steps, progress_image, batch_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
@ -21,6 +22,11 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
nes.progressImage = progress_image ?? null; nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);
} }
const isCanvasQueueItem = getState().canvasV2.stagingArea?.batchIds.includes(batch_id);
if (isCanvasQueueItem) {
$lastProgressEvent.set(action.payload.data);
}
}, },
}); });
}; };

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
@ -28,10 +29,13 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
error_type, error_type,
error_message, error_message,
error_traceback, error_traceback,
batch_id,
} = action.payload.data; } = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`); log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
const isCanvasQueueItem = getState().canvasV2.stagingArea?.batchIds.includes(batch_id);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch( dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
@ -92,6 +96,9 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
} else if (status === 'failed' && error_type) { } else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true; const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id; const sessionId = session_id;
if (isCanvasQueueItem) {
$lastProgressEvent.set(null);
}
toast({ toast({
id: `INVOCATION_ERROR_${error_type}`, id: `INVOCATION_ERROR_${error_type}`,
@ -108,6 +115,10 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
/> />
), ),
}); });
} else if (status === 'completed' && isCanvasQueueItem) {
$lastProgressEvent.set(null);
} else if (status === 'canceled' && isCanvasQueueItem) {
$lastProgressEvent.set(null);
} }
}, },
}); });

View File

@ -22,6 +22,7 @@ import type { Vector2d } from 'konva/lib/types';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images'; import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types'; import type { ImageCategory, ImageDTO } from 'services/api/types';
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { CanvasBbox } from './renderers/bbox'; import { CanvasBbox } from './renderers/bbox';
@ -74,6 +75,7 @@ export type StateApi = {
getRegionsState: () => CanvasV2State['regions']; getRegionsState: () => CanvasV2State['regions'];
getInpaintMaskState: () => CanvasV2State['inpaintMask']; getInpaintMaskState: () => CanvasV2State['inpaintMask'];
getStagingAreaState: () => CanvasV2State['stagingArea']; getStagingAreaState: () => CanvasV2State['stagingArea'];
getLastProgressEvent: () => InvocationDenoiseProgressEvent | null;
onInpaintMaskImageCached: (imageDTO: ImageDTO) => void; onInpaintMaskImageCached: (imageDTO: ImageDTO) => void;
onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void; onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void;
onLayerImageCached: (imageDTO: ImageDTO) => void; onLayerImageCached: (imageDTO: ImageDTO) => void;
@ -276,7 +278,11 @@ export class KonvaNodeManager {
} }
renderStagingArea() { renderStagingArea() {
this.preview.stagingArea.render(this.stateApi.getStagingAreaState(), this.stateApi.getShouldShowStagedImage()); this.preview.stagingArea.render(
this.stateApi.getStagingAreaState(),
this.stateApi.getShouldShowStagedImage(),
this.stateApi.getLastProgressEvent()
);
} }
fitDocument() { fitDocument() {

View File

@ -33,7 +33,7 @@ export class CanvasControlAdapter {
const imageObject = entity.processedImageObject ?? entity.imageObject; const imageObject = entity.processedImageObject ?? entity.imageObject;
if (!imageObject) { if (!imageObject) {
if (this.image) { if (this.image) {
this.image.destroy(); this.image.konvaImageGroup.visible(false);
} }
return; return;
} }

View File

@ -173,6 +173,7 @@ export class KonvaImage {
this.konvaPlaceholderGroup.visible(true); this.konvaPlaceholderGroup.visible(true);
this.konvaPlaceholderText.text(t('common.loadingImage', 'Loading Image')); this.konvaPlaceholderText.text(t('common.loadingImage', 'Loading Image'));
} }
this.konvaImageGroup.visible(true);
if (onLoading) { if (onLoading) {
onLoading(); onLoading();
} }
@ -194,6 +195,8 @@ export class KonvaImage {
this.isLoading = false; this.isLoading = false;
this.isError = false; this.isError = false;
this.konvaPlaceholderGroup.visible(false); this.konvaPlaceholderGroup.visible(false);
this.konvaImageGroup.visible(true);
if (onLoad) { if (onLoad) {
onLoad(this.konvaImage); onLoad(this.konvaImage);
} }
@ -204,6 +207,8 @@ export class KonvaImage {
this.isError = true; this.isError = true;
this.konvaPlaceholderGroup.visible(true); this.konvaPlaceholderGroup.visible(true);
this.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load')); this.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
this.konvaImageGroup.visible(true);
if (onError) { if (onError) {
onError(); onError();
} }
@ -237,3 +242,62 @@ export class KonvaImage {
this.konvaImageGroup.destroy(); this.konvaImageGroup.destroy();
} }
} }
export class KonvaProgressImage {
id: string;
progressImageId: string | null;
konvaImageGroup: Konva.Group;
konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
isLoading: boolean;
isError: boolean;
constructor(arg: { id: string }) {
const { id } = arg;
this.konvaImageGroup = new Konva.Group({ id, listening: false });
this.id = id;
this.progressImageId = null;
this.konvaImage = null;
this.isLoading = false;
this.isError = false;
}
async updateImageSource(
progressImageId: string,
dataURL: string,
x: number,
y: number,
width: number,
height: number
) {
const imageEl = new Image();
imageEl.onload = () => {
if (this.konvaImage) {
this.konvaImage.setAttrs({
image: imageEl,
x,
y,
width,
height,
});
} else {
this.konvaImage = new Konva.Image({
id: this.id,
listening: false,
image: imageEl,
x,
y,
width,
height,
});
this.konvaImageGroup.add(this.konvaImage);
}
};
imageEl.id = progressImageId;
imageEl.src = dataURL;
}
destroy() {
this.konvaImageGroup.destroy();
}
}

View File

@ -7,6 +7,7 @@ import { setStageEventHandlers } from 'features/controlLayers/konva/events';
import { KonvaNodeManager, setNodeManager } from 'features/controlLayers/konva/nodeManager'; import { KonvaNodeManager, setNodeManager } from 'features/controlLayers/konva/nodeManager';
import { updateBboxes } from 'features/controlLayers/konva/renderers/entityBbox'; import { updateBboxes } from 'features/controlLayers/konva/renderers/entityBbox';
import { import {
$lastProgressEvent,
$shouldShowStagedImage, $shouldShowStagedImage,
$stageAttrs, $stageAttrs,
bboxChanged, bboxChanged,
@ -305,6 +306,7 @@ export const initializeRenderer = (
getInpaintMaskState, getInpaintMaskState,
getStagingAreaState, getStagingAreaState,
getShouldShowStagedImage: $shouldShowStagedImage.get, getShouldShowStagedImage: $shouldShowStagedImage.get,
getLastProgressEvent: $lastProgressEvent.get,
// Read-write state // Read-write state
setTool, setTool,
@ -453,6 +455,11 @@ export const initializeRenderer = (
} }
}); });
$lastProgressEvent.subscribe(() => {
logIfDebugging('Rendering staging area');
manager.renderStagingArea();
});
logIfDebugging('First render of konva stage'); logIfDebugging('First render of konva stage');
// On first render, the document should be fit to the stage. // On first render, the document should be fit to the stage.
manager.renderDocumentSizeOverlay(); manager.renderDocumentSizeOverlay();

View File

@ -1,34 +1,58 @@
import { KonvaImage } from 'features/controlLayers/konva/renderers/objects'; import { KonvaImage, KonvaProgressImage } from 'features/controlLayers/konva/renderers/objects';
import type { CanvasV2State } from 'features/controlLayers/store/types'; import type { CanvasV2State } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export class CanvasStagingArea { export class CanvasStagingArea {
group: Konva.Group; group: Konva.Group;
image: KonvaImage | null; image: KonvaImage | null;
progressImage: KonvaProgressImage | null;
constructor() { constructor() {
this.group = new Konva.Group({ listening: false }); this.group = new Konva.Group({ listening: false });
this.image = null; this.image = null;
this.progressImage = null;
} }
async render(stagingArea: CanvasV2State['stagingArea'], shouldShowStagedImage: boolean) { async render(
if (!stagingArea || stagingArea.selectedImageIndex === null) { stagingArea: CanvasV2State['stagingArea'],
if (this.image) { shouldShowStagedImage: boolean,
this.image.destroy(); lastProgressEvent: InvocationDenoiseProgressEvent | null
this.image = null; ) {
if (stagingArea && lastProgressEvent) {
const { invocation, step, progress_image } = lastProgressEvent;
const { dataURL } = progress_image;
const { x, y, width, height } = stagingArea.bbox;
const progressImageId = `${invocation.id}_${step}`;
if (this.progressImage) {
if (
!this.progressImage.isLoading &&
!this.progressImage.isError &&
this.progressImage.progressImageId !== progressImageId
) {
await this.progressImage.updateImageSource(progressImageId, dataURL, x, y, width, height);
this.image?.konvaImageGroup.visible(false);
this.progressImage.konvaImageGroup.visible(true);
} }
return; } else {
this.progressImage = new KonvaProgressImage({ id: 'progress-image' });
this.group.add(this.progressImage.konvaImageGroup);
await this.progressImage.updateImageSource(progressImageId, dataURL, x, y, width, height);
this.image?.konvaImageGroup.visible(false);
this.progressImage.konvaImageGroup.visible(true);
} }
} else if (stagingArea && stagingArea.selectedImageIndex !== null) {
if (stagingArea.selectedImageIndex !== null) {
const imageDTO = stagingArea.images[stagingArea.selectedImageIndex]; const imageDTO = stagingArea.images[stagingArea.selectedImageIndex];
assert(imageDTO, 'Image must exist'); assert(imageDTO, 'Image must exist');
if (this.image) { if (this.image) {
if (!this.image.isLoading && !this.image.isError && this.image.imageName !== imageDTO.image_name) { if (!this.image.isLoading && !this.image.isError && this.image.imageName !== imageDTO.image_name) {
await this.image.updateImageSource(imageDTO.image_name); await this.image.updateImageSource(imageDTO.image_name);
} }
this.image.konvaImageGroup.x(stagingArea.bbox.x);
this.image.konvaImageGroup.y(stagingArea.bbox.y);
this.image.konvaImageGroup.visible(shouldShowStagedImage); this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
} else { } else {
const { image_name, width, height } = imageDTO; const { image_name, width, height } = imageDTO;
this.image = new KonvaImage({ this.image = new KonvaImage({
@ -50,6 +74,14 @@ export class CanvasStagingArea {
this.group.add(this.image.konvaImageGroup); this.group.add(this.image.konvaImageGroup);
await this.image.updateImageSource(imageDTO.image_name); await this.image.updateImageSource(imageDTO.image_name);
this.image.konvaImageGroup.visible(shouldShowStagedImage); this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
}
} else {
if (this.image) {
this.image.konvaImageGroup.visible(false);
}
if (this.progressImage) {
this.progressImage.konvaImageGroup.visible(false);
} }
} }
} }

View File

@ -18,6 +18,7 @@ import { toolReducers } from 'features/controlLayers/store/toolReducers';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants'; import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types'; import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types';
import { RGBA_RED } from './types'; import { RGBA_RED } from './types';
@ -358,6 +359,7 @@ export const $stageAttrs = atom<StageAttrs>({
scale: 0, scale: 0,
}); });
export const $shouldShowStagedImage = atom(true); export const $shouldShowStagedImage = atom(true);
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const canvasV2PersistConfig: PersistConfig<CanvasV2State> = { export const canvasV2PersistConfig: PersistConfig<CanvasV2State> = {
name: canvasV2Slice.name, name: canvasV2Slice.name,