mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
166 lines
5.3 KiB
TypeScript
166 lines
5.3 KiB
TypeScript
import { logger } from 'app/logging/logger';
|
|
import type { AppDispatch, RootState } from 'app/store/store';
|
|
import type { SerializableObject } from 'common/types';
|
|
import { deepClone } from 'common/util/deepClone';
|
|
import { sessionImageStaged } from 'features/controlLayers/store/canvasSessionSlice';
|
|
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
|
|
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
|
import { boardsApi } from 'services/api/endpoints/boards';
|
|
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
|
import type { ImageDTO, S } from 'services/api/types';
|
|
import { getCategories, getListImagesUrl } from 'services/api/util';
|
|
|
|
// Module-level logger created via `logger('events')`; used for the invocation-complete debug message.
const log = logger('events');
|
|
|
|
/**
 * Reports whether the portion of `invocation_source_id` preceding its first `:` (or the whole id, when it
 * contains no `:`) is exactly `canvas_output`.
 */
const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
  const sourceId = data.invocation_source_id;
  const separatorIndex = sourceId.indexOf(':');
  const prefix = separatorIndex === -1 ? sourceId : sourceId.slice(0, separatorIndex);
  return prefix === 'canvas_output';
};
|
|
|
|
/**
 * Builds the handler for `InvocationCompleteEvent`s.
 *
 * The returned async function logs the event, routes it by `data.origin` (`'workflows'`, `'generation'`, or
 * anything else) and then calls `setLastProgressEvent(null)`.
 *
 * @param getState Returns the current root state; its `gallery` slice is read when an image is added.
 * @param dispatch Store dispatch used for the board, image, gallery and canvas session updates.
 * @param nodeTypeDenylist Not referenced anywhere in this function body.
 * @param setLastProgressEvent Called with `null` once an event has been handled.
 * @param setLastCanvasProgressEvent Called with `null` for `'generation'` events that resolve to an image and
 * whose destination is not `'canvas'`.
 * @returns An async handler that takes the completed-invocation event.
 */
export const buildOnInvocationComplete = (
  getState: () => RootState,
  dispatch: AppDispatch,
  nodeTypeDenylist: string[],
  setLastProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void,
  setLastCanvasProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void
) => {
  /**
   * Registers a non-intermediate image with the gallery caches and, when auto-switch is enabled, moves the
   * gallery to that image.
   */
  const addImageToGallery = (imageDTO: ImageDTO) => {
    if (imageDTO.is_intermediate) {
      return;
    }

    // Images without a board are tracked under the 'none' board id.
    const boardId = imageDTO.board_id ?? 'none';

    // Bump the cached image total for the board.
    dispatch(
      boardsApi.util.updateQueryData('getBoardImagesTotal', boardId, (draft) => {
        draft.total += 1;
      })
    );

    // Invalidate the board tag and the image-list tag that matches this image's board and categories.
    const imageListTagId = getListImagesUrl({
      board_id: boardId,
      categories: getCategories(imageDTO),
    });
    dispatch(
      imagesApi.util.invalidateTags([
        { type: 'Board', id: boardId },
        { type: 'ImageList', id: imageListTagId },
      ])
    );

    const {
      shouldAutoSwitch: autoSwitch,
      galleryView: currentView,
      selectedBoardId: currentBoardId,
    } = getState().gallery;

    if (!autoSwitch) {
      return;
    }

    // Auto-switch: show the images view, select the image's board if it differs, reset paging, select the image.
    if (currentView !== 'images') {
      dispatch(galleryViewChanged('images'));
    }

    if (imageDTO.board_id && imageDTO.board_id !== currentBoardId) {
      dispatch(boardIdSelected({ boardId: imageDTO.board_id, selectedImageName: imageDTO.image_name }));
    }

    dispatch(offsetChanged({ offset: 0 }));

    if (!imageDTO.board_id && currentBoardId !== 'none') {
      dispatch(boardIdSelected({ boardId: 'none', selectedImageName: imageDTO.image_name }));
    }

    dispatch(imageSelected(imageDTO));
  };

  /**
   * Looks up the image DTO named by the event's result for the two image-bearing result types; returns `null`
   * for every other result type.
   */
  const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
    const { result } = data;
    switch (result.type) {
      case 'image_output':
      case 'canvas_v2_mask_and_crop_output':
        return getImageDTO(result.image.image_name);
      default:
        return null;
    }
  };

  /**
   * 'workflows' origin: marks the node's execution state as completed (if one is tracked), records the result,
   * then adds any resulting image to the gallery.
   */
  const handleOriginWorkflows = async (data: S['InvocationCompleteEvent']) => {
    // Work on a clone so the stored execution state is only replaced through upsertExecutionState.
    const executionState = deepClone($nodeExecutionStates.get()[data.invocation_source_id]);

    if (executionState) {
      executionState.status = zNodeStatus.enum.COMPLETED;
      if (executionState.progress !== null) {
        executionState.progress = 1;
      }
      executionState.outputs.push(data.result);
      upsertExecutionState(executionState.nodeId, executionState);
    }

    const resultImage = await getResultImageDTO(data);
    if (resultImage) {
      addImageToGallery(resultImage);
    }
  };

  /**
   * 'generation' origin: stages canvas output images for the canvas destination, clears the canvas progress
   * event for any other destination, then adds the image to the gallery. Does nothing without an image.
   */
  const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
    const resultImage = await getResultImageDTO(data);
    if (!resultImage) {
      return;
    }

    const stageImage = (offsetX: number, offsetY: number) => {
      dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO: resultImage, offsetX, offsetY } }));
    };

    if (data.destination !== 'canvas') {
      // The original author's note associates this branch with session.mode === 'generate'.
      setLastCanvasProgressEvent(null);
    } else if (isCanvasOutputNode(data)) {
      const { result } = data;
      if (result.type === 'canvas_v2_mask_and_crop_output') {
        stageImage(result.offset_x, result.offset_y);
      } else if (result.type === 'image_output') {
        stageImage(0, 0);
      }
    }

    addImageToGallery(resultImage);
  };

  /** Any other origin: adds the resulting image, if there is one, to the gallery. */
  const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
    const resultImage = await getResultImageDTO(data);
    if (resultImage) {
      addImageToGallery(resultImage);
    }
  };

  return async (data: S['InvocationCompleteEvent']) => {
    log.debug(
      { data } as SerializableObject,
      `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
    );

    // Route by origin; each handler deals with both its own bookkeeping and the image output.
    switch (data.origin) {
      case 'workflows':
        await handleOriginWorkflows(data);
        break;
      case 'generation':
        await handleOriginCanvas(data);
        break;
      default:
        await handleOriginOther(data);
        break;
    }

    setLastProgressEvent(null);
  };
};
|