feat(ui): handle new progress event

Minor changes to use the new progress event. Only additional feature is if the progress has a message, it is displayed as a tooltip on the progress bar.
This commit is contained in:
psychedelicious 2024-08-04 19:02:45 +10:00
parent 4e0e9041e2
commit c2e9bdc6c5
13 changed files with 47 additions and 58 deletions

View File

@ -3,7 +3,7 @@ import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
import { socketInvocationProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
@ -24,10 +24,10 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (socketGeneratorProgress.match(action)) {
if (socketInvocationProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
if (sanitized.payload.data.image) {
sanitized.payload.data.image.dataURL = '<Progress image omitted>';
}
return sanitized;
}

View File

@ -39,9 +39,9 @@ import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddlewa
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationProgress';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
@ -102,7 +102,7 @@ addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addInvocationProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);

View File

@ -4,21 +4,21 @@ import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
import { socketInvocationProgress } from 'services/events/actions';
const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => {
export const addInvocationProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketGeneratorProgress,
actionCreator: socketInvocationProgress,
effect: (action) => {
log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const { invocation_source_id, percentage, image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
nes.progress = percentage;
nes.progressImage = image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
},

View File

@ -10,8 +10,7 @@ const progressImageSelector = createMemoizedSelector([selectSystemSlice, selectC
const { batchIds } = canvas;
return {
progressImage:
denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined,
progressImage: denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.image : undefined,
boundingBox: canvas.layerState.stagingArea.boundingBox,
};
});

View File

@ -40,7 +40,7 @@ const selectShouldDisableToolbarButtons = createSelector(
selectGallerySlice,
selectLastSelectedImage,
(system, gallery, lastSelectedImage) => {
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
const hasProgressImage = Boolean(system.denoiseProgress?.image);
return hasProgressImage || !lastSelectedImage;
}
);

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react';
const CurrentImagePreview = () => {
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image);
const image = useAppSelector((s) => s.system.denoiseProgress?.image);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const sx = useMemo<SystemStyleObject>(
@ -14,15 +14,15 @@ const CurrentImagePreview = () => {
[shouldAntialiasProgressImage]
);
if (!progress_image) {
if (!image) {
return null;
}
return (
<Image
src={progress_image.dataURL}
width={progress_image.width}
height={progress_image.height}
src={image.dataURL}
width={image.width}
height={image.height}
draggable={false}
data-testid="progress-image"
objectFit="contain"

View File

@ -20,7 +20,7 @@ const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (
return {
imageDTO,
progressImage: system.denoiseProgress?.progress_image,
progressImage: system.denoiseProgress?.image,
};
});

View File

@ -1,4 +1,4 @@
import { Progress } from '@invoke-ai/ui-library';
import { Progress, Tooltip } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice } from 'features/system/store/systemSlice';
@ -15,18 +15,21 @@ const ProgressBar = () => {
const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress));
const message = useAppSelector((s) => s.system.denoiseProgress?.message);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress?.percentage !== undefined));
const value = useAppSelector(selectProgressValue);
return (
<Progress
value={value}
aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
h={2}
w="full"
colorScheme="invokeBlue"
/>
<Tooltip label={message} placement="end">
<Progress
value={value}
aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
h={2}
w="full"
colorScheme="invokeBlue"
/>
</Tooltip>
);
};

View File

@ -5,8 +5,8 @@ import type { LogLevelName } from 'roarr';
import {
socketConnected,
socketDisconnected,
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationProgress,
socketInvocationStarted,
socketModelLoadComplete,
socketModelLoadStarted,
@ -95,8 +95,8 @@ export const systemSlice = createSlice({
/**
* Generator Progress
*/
builder.addCase(socketGeneratorProgress, (state, action) => {
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data;
builder.addCase(socketInvocationProgress, (state, action) => {
const { image, session_id, batch_id, percentage, message } = action.payload.data;
if (state.cancellations.includes(session_id)) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
@ -105,10 +105,9 @@ export const systemSlice = createSlice({
}
state.denoiseProgress = {
step,
total_steps,
message,
percentage,
progress_image,
image,
session_id,
batch_id,
};

View File

@ -1,18 +1,9 @@
import type { LogLevel } from 'app/logging/logger';
import type { ProgressImage } from 'services/events/types';
import type { InvocationProgressEvent } from 'services/events/types';
import { z } from 'zod';
type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
type DenoiseProgress = {
session_id: string;
batch_id: string;
progress_image: ProgressImage | null | undefined;
step: number;
total_steps: number;
percentage: number;
};
const zLanguage = z.enum([
'ar',
'az',
@ -45,7 +36,7 @@ export interface SystemState {
isConnected: boolean;
shouldConfirmOnDelete: boolean;
enableImageDebugging: boolean;
denoiseProgress: DenoiseProgress | null;
denoiseProgress: Pick<InvocationProgressEvent, 'session_id' | 'batch_id' | 'image' | 'percentage' | 'message'> | null;
consoleLogLevel: LogLevel;
shouldLogToConsole: boolean;
shouldAntialiasProgressImage: boolean;

View File

@ -9,8 +9,8 @@ import type {
DownloadProgressEvent,
DownloadStartedEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
@ -32,9 +32,7 @@ export const socketDisconnected = createSocketAction('Disconnected');
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
export const socketGeneratorProgress = createSocketAction<InvocationDenoiseProgressEvent>(
'InvocationDenoiseProgressEvent'
);
export const socketInvocationProgress = createSocketAction<InvocationProgressEvent>('InvocationProgressEvent');
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');

View File

@ -14,9 +14,9 @@ import {
socketDownloadError,
socketDownloadProgress,
socketDownloadStarted,
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationProgress,
socketInvocationStarted,
socketModelInstallCancelled,
socketModelInstallComplete,
@ -65,8 +65,8 @@ export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) =>
socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data }));
});
socket.on('invocation_denoise_progress', (data) => {
dispatch(socketGeneratorProgress({ data }));
socket.on('invocation_progress', (data) => {
dispatch(socketInvocationProgress({ data }));
});
socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data }));

View File

@ -4,10 +4,9 @@ export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
export type InvocationProgressEvent = S['InvocationProgressEvent'];
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
@ -39,7 +38,7 @@ type ClientEmitSubscribeBulkDownload = {
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
export type ServerToClientEvents = {
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void;
invocation_progress: (payload: InvocationProgressEvent) => void;
invocation_complete: (payload: InvocationCompleteEvent) => void;
invocation_error: (payload: InvocationErrorEvent) => void;
invocation_started: (payload: InvocationStartedEvent) => void;