feat(ui): handle new progress event

Minor changes to use the new progress event. Only additional feature is if the progress has a message, it is displayed as a tooltip on the progress bar.
This commit is contained in:
psychedelicious 2024-08-04 19:02:45 +10:00
parent 4e0e9041e2
commit c2e9bdc6c5
13 changed files with 47 additions and 58 deletions

View File

@ -3,7 +3,7 @@ import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo'; import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types'; import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketInvocationProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => { export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) { if (isAnyGraphBuilt(action)) {
@ -24,10 +24,10 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
}; };
} }
if (socketGeneratorProgress.match(action)) { if (socketInvocationProgress.match(action)) {
const sanitized = deepClone(action); const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) { if (sanitized.payload.data.image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>'; sanitized.payload.data.image.dataURL = '<Progress image omitted>';
} }
return sanitized; return sanitized;
} }

View File

@ -39,9 +39,9 @@ import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddlewa
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings'; import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationProgress';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
@ -102,7 +102,7 @@ addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening); addCommitStagingAreaImageListener(startAppListening);
// Socket.IO // Socket.IO
addGeneratorProgressEventListener(startAppListening); addInvocationProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening); addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening); addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening); addInvocationStartedEventListener(startAppListening);

View File

@ -4,21 +4,21 @@ import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketInvocationProgress } from 'services/events/actions';
const log = logger('socketio'); const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => { export const addInvocationProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketGeneratorProgress, actionCreator: socketInvocationProgress,
effect: (action) => { effect: (action) => {
log.trace(parseify(action.payload), `Generator progress`); log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data; const { invocation_source_id, percentage, image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps; nes.progress = percentage;
nes.progressImage = progress_image ?? null; nes.progressImage = image ?? null;
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);
} }
}, },

View File

@ -10,8 +10,7 @@ const progressImageSelector = createMemoizedSelector([selectSystemSlice, selectC
const { batchIds } = canvas; const { batchIds } = canvas;
return { return {
progressImage: progressImage: denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.image : undefined,
denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined,
boundingBox: canvas.layerState.stagingArea.boundingBox, boundingBox: canvas.layerState.stagingArea.boundingBox,
}; };
}); });

View File

@ -40,7 +40,7 @@ const selectShouldDisableToolbarButtons = createSelector(
selectGallerySlice, selectGallerySlice,
selectLastSelectedImage, selectLastSelectedImage,
(system, gallery, lastSelectedImage) => { (system, gallery, lastSelectedImage) => {
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image); const hasProgressImage = Boolean(system.denoiseProgress?.image);
return hasProgressImage || !lastSelectedImage; return hasProgressImage || !lastSelectedImage;
} }
); );

View File

@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image); const image = useAppSelector((s) => s.system.denoiseProgress?.image);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage); const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const sx = useMemo<SystemStyleObject>( const sx = useMemo<SystemStyleObject>(
@ -14,15 +14,15 @@ const CurrentImagePreview = () => {
[shouldAntialiasProgressImage] [shouldAntialiasProgressImage]
); );
if (!progress_image) { if (!image) {
return null; return null;
} }
return ( return (
<Image <Image
src={progress_image.dataURL} src={image.dataURL}
width={progress_image.width} width={image.width}
height={progress_image.height} height={image.height}
draggable={false} draggable={false}
data-testid="progress-image" data-testid="progress-image"
objectFit="contain" objectFit="contain"

View File

@ -20,7 +20,7 @@ const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (
return { return {
imageDTO, imageDTO,
progressImage: system.denoiseProgress?.progress_image, progressImage: system.denoiseProgress?.image,
}; };
}); });

View File

@ -1,4 +1,4 @@
import { Progress } from '@invoke-ai/ui-library'; import { Progress, Tooltip } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice } from 'features/system/store/systemSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice';
@ -15,10 +15,12 @@ const ProgressBar = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useAppSelector((s) => s.system.isConnected);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress)); const message = useAppSelector((s) => s.system.denoiseProgress?.message);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress?.percentage !== undefined));
const value = useAppSelector(selectProgressValue); const value = useAppSelector(selectProgressValue);
return ( return (
<Tooltip label={message} placement="end">
<Progress <Progress
value={value} value={value}
aria-label={t('accessibility.invokeProgressBar')} aria-label={t('accessibility.invokeProgressBar')}
@ -27,6 +29,7 @@ const ProgressBar = () => {
w="full" w="full"
colorScheme="invokeBlue" colorScheme="invokeBlue"
/> />
</Tooltip>
); );
}; };

View File

@ -5,8 +5,8 @@ import type { LogLevelName } from 'roarr';
import { import {
socketConnected, socketConnected,
socketDisconnected, socketDisconnected,
socketGeneratorProgress,
socketInvocationComplete, socketInvocationComplete,
socketInvocationProgress,
socketInvocationStarted, socketInvocationStarted,
socketModelLoadComplete, socketModelLoadComplete,
socketModelLoadStarted, socketModelLoadStarted,
@ -95,8 +95,8 @@ export const systemSlice = createSlice({
/** /**
* Generator Progress * Generator Progress
*/ */
builder.addCase(socketGeneratorProgress, (state, action) => { builder.addCase(socketInvocationProgress, (state, action) => {
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data; const { image, session_id, batch_id, percentage, message } = action.payload.data;
if (state.cancellations.includes(session_id)) { if (state.cancellations.includes(session_id)) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a // Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
@ -105,10 +105,9 @@ export const systemSlice = createSlice({
} }
state.denoiseProgress = { state.denoiseProgress = {
step, message,
total_steps,
percentage, percentage,
progress_image, image,
session_id, session_id,
batch_id, batch_id,
}; };

View File

@ -1,18 +1,9 @@
import type { LogLevel } from 'app/logging/logger'; import type { LogLevel } from 'app/logging/logger';
import type { ProgressImage } from 'services/events/types'; import type { InvocationProgressEvent } from 'services/events/types';
import { z } from 'zod'; import { z } from 'zod';
type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL'; type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
type DenoiseProgress = {
session_id: string;
batch_id: string;
progress_image: ProgressImage | null | undefined;
step: number;
total_steps: number;
percentage: number;
};
const zLanguage = z.enum([ const zLanguage = z.enum([
'ar', 'ar',
'az', 'az',
@ -45,7 +36,7 @@ export interface SystemState {
isConnected: boolean; isConnected: boolean;
shouldConfirmOnDelete: boolean; shouldConfirmOnDelete: boolean;
enableImageDebugging: boolean; enableImageDebugging: boolean;
denoiseProgress: DenoiseProgress | null; denoiseProgress: Pick<InvocationProgressEvent, 'session_id' | 'batch_id' | 'image' | 'percentage' | 'message'> | null;
consoleLogLevel: LogLevel; consoleLogLevel: LogLevel;
shouldLogToConsole: boolean; shouldLogToConsole: boolean;
shouldAntialiasProgressImage: boolean; shouldAntialiasProgressImage: boolean;

View File

@ -9,8 +9,8 @@ import type {
DownloadProgressEvent, DownloadProgressEvent,
DownloadStartedEvent, DownloadStartedEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
ModelInstallCompleteEvent, ModelInstallCompleteEvent,
@ -32,9 +32,7 @@ export const socketDisconnected = createSocketAction('Disconnected');
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent'); export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent'); export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent'); export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
export const socketGeneratorProgress = createSocketAction<InvocationDenoiseProgressEvent>( export const socketInvocationProgress = createSocketAction<InvocationProgressEvent>('InvocationProgressEvent');
'InvocationDenoiseProgressEvent'
);
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent'); export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent'); export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent'); export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');

View File

@ -14,9 +14,9 @@ import {
socketDownloadError, socketDownloadError,
socketDownloadProgress, socketDownloadProgress,
socketDownloadStarted, socketDownloadStarted,
socketGeneratorProgress,
socketInvocationComplete, socketInvocationComplete,
socketInvocationError, socketInvocationError,
socketInvocationProgress,
socketInvocationStarted, socketInvocationStarted,
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallComplete, socketModelInstallComplete,
@ -65,8 +65,8 @@ export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) =>
socket.on('invocation_started', (data) => { socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data })); dispatch(socketInvocationStarted({ data }));
}); });
socket.on('invocation_denoise_progress', (data) => { socket.on('invocation_progress', (data) => {
dispatch(socketGeneratorProgress({ data })); dispatch(socketInvocationProgress({ data }));
}); });
socket.on('invocation_error', (data) => { socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data })); dispatch(socketInvocationError({ data }));

View File

@ -4,10 +4,9 @@ export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent']; export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
export type InvocationStartedEvent = S['InvocationStartedEvent']; export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent']; export type InvocationProgressEvent = S['InvocationProgressEvent'];
export type InvocationCompleteEvent = S['InvocationCompleteEvent']; export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
export type InvocationErrorEvent = S['InvocationErrorEvent']; export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent']; export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
@ -39,7 +38,7 @@ type ClientEmitSubscribeBulkDownload = {
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload; type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
export type ServerToClientEvents = { export type ServerToClientEvents = {
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void; invocation_progress: (payload: InvocationProgressEvent) => void;
invocation_complete: (payload: InvocationCompleteEvent) => void; invocation_complete: (payload: InvocationCompleteEvent) => void;
invocation_error: (payload: InvocationErrorEvent) => void; invocation_error: (payload: InvocationErrorEvent) => void;
invocation_started: (payload: InvocationStartedEvent) => void; invocation_started: (payload: InvocationStartedEvent) => void;