mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): handle new progress event
Minor changes to use the new progress event. Only additional feature is if the progress has a message, it is displayed as a tooltip on the progress bar.
This commit is contained in:
parent
4e0e9041e2
commit
c2e9bdc6c5
@ -3,7 +3,7 @@ import { deepClone } from 'common/util/deepClone';
|
|||||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||||
import type { Graph } from 'services/api/types';
|
import type { Graph } from 'services/api/types';
|
||||||
import { socketGeneratorProgress } from 'services/events/actions';
|
import { socketInvocationProgress } from 'services/events/actions';
|
||||||
|
|
||||||
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
|
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
|
||||||
if (isAnyGraphBuilt(action)) {
|
if (isAnyGraphBuilt(action)) {
|
||||||
@ -24,10 +24,10 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (socketGeneratorProgress.match(action)) {
|
if (socketInvocationProgress.match(action)) {
|
||||||
const sanitized = deepClone(action);
|
const sanitized = deepClone(action);
|
||||||
if (sanitized.payload.data.progress_image) {
|
if (sanitized.payload.data.image) {
|
||||||
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
|
sanitized.payload.data.image.dataURL = '<Progress image omitted>';
|
||||||
}
|
}
|
||||||
return sanitized;
|
return sanitized;
|
||||||
}
|
}
|
||||||
|
@ -39,9 +39,9 @@ import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddlewa
|
|||||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
|
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
|
||||||
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
|
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
|
||||||
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
|
|
||||||
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
|
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
|
||||||
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
|
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
|
||||||
|
import { addInvocationProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationProgress';
|
||||||
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
|
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
|
||||||
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
|
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
|
||||||
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
|
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
|
||||||
@ -102,7 +102,7 @@ addStagingAreaImageSavedListener(startAppListening);
|
|||||||
addCommitStagingAreaImageListener(startAppListening);
|
addCommitStagingAreaImageListener(startAppListening);
|
||||||
|
|
||||||
// Socket.IO
|
// Socket.IO
|
||||||
addGeneratorProgressEventListener(startAppListening);
|
addInvocationProgressEventListener(startAppListening);
|
||||||
addInvocationCompleteEventListener(startAppListening);
|
addInvocationCompleteEventListener(startAppListening);
|
||||||
addInvocationErrorEventListener(startAppListening);
|
addInvocationErrorEventListener(startAppListening);
|
||||||
addInvocationStartedEventListener(startAppListening);
|
addInvocationStartedEventListener(startAppListening);
|
||||||
|
@ -4,21 +4,21 @@ import { deepClone } from 'common/util/deepClone';
|
|||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { socketGeneratorProgress } from 'services/events/actions';
|
import { socketInvocationProgress } from 'services/events/actions';
|
||||||
|
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
|
|
||||||
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => {
|
export const addInvocationProgressEventListener = (startAppListening: AppStartListening) => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketGeneratorProgress,
|
actionCreator: socketInvocationProgress,
|
||||||
effect: (action) => {
|
effect: (action) => {
|
||||||
log.trace(parseify(action.payload), `Generator progress`);
|
log.trace(parseify(action.payload), `Generator progress`);
|
||||||
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
|
const { invocation_source_id, percentage, image } = action.payload.data;
|
||||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||||
if (nes) {
|
if (nes) {
|
||||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||||
nes.progress = (step + 1) / total_steps;
|
nes.progress = percentage;
|
||||||
nes.progressImage = progress_image ?? null;
|
nes.progressImage = image ?? null;
|
||||||
upsertExecutionState(nes.nodeId, nes);
|
upsertExecutionState(nes.nodeId, nes);
|
||||||
}
|
}
|
||||||
},
|
},
|
@ -10,8 +10,7 @@ const progressImageSelector = createMemoizedSelector([selectSystemSlice, selectC
|
|||||||
const { batchIds } = canvas;
|
const { batchIds } = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
progressImage:
|
progressImage: denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.image : undefined,
|
||||||
denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined,
|
|
||||||
boundingBox: canvas.layerState.stagingArea.boundingBox,
|
boundingBox: canvas.layerState.stagingArea.boundingBox,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
@ -40,7 +40,7 @@ const selectShouldDisableToolbarButtons = createSelector(
|
|||||||
selectGallerySlice,
|
selectGallerySlice,
|
||||||
selectLastSelectedImage,
|
selectLastSelectedImage,
|
||||||
(system, gallery, lastSelectedImage) => {
|
(system, gallery, lastSelectedImage) => {
|
||||||
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
|
const hasProgressImage = Boolean(system.denoiseProgress?.image);
|
||||||
return hasProgressImage || !lastSelectedImage;
|
return hasProgressImage || !lastSelectedImage;
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
|
|
||||||
const CurrentImagePreview = () => {
|
const CurrentImagePreview = () => {
|
||||||
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image);
|
const image = useAppSelector((s) => s.system.denoiseProgress?.image);
|
||||||
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
|
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
|
||||||
|
|
||||||
const sx = useMemo<SystemStyleObject>(
|
const sx = useMemo<SystemStyleObject>(
|
||||||
@ -14,15 +14,15 @@ const CurrentImagePreview = () => {
|
|||||||
[shouldAntialiasProgressImage]
|
[shouldAntialiasProgressImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!progress_image) {
|
if (!image) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Image
|
<Image
|
||||||
src={progress_image.dataURL}
|
src={image.dataURL}
|
||||||
width={progress_image.width}
|
width={image.width}
|
||||||
height={progress_image.height}
|
height={image.height}
|
||||||
draggable={false}
|
draggable={false}
|
||||||
data-testid="progress-image"
|
data-testid="progress-image"
|
||||||
objectFit="contain"
|
objectFit="contain"
|
||||||
|
@ -20,7 +20,7 @@ const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
imageDTO,
|
imageDTO,
|
||||||
progressImage: system.denoiseProgress?.progress_image,
|
progressImage: system.denoiseProgress?.image,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Progress } from '@invoke-ai/ui-library';
|
import { Progress, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||||
@ -15,10 +15,12 @@ const ProgressBar = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||||
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress));
|
const message = useAppSelector((s) => s.system.denoiseProgress?.message);
|
||||||
|
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress?.percentage !== undefined));
|
||||||
const value = useAppSelector(selectProgressValue);
|
const value = useAppSelector(selectProgressValue);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<Tooltip label={message} placement="end">
|
||||||
<Progress
|
<Progress
|
||||||
value={value}
|
value={value}
|
||||||
aria-label={t('accessibility.invokeProgressBar')}
|
aria-label={t('accessibility.invokeProgressBar')}
|
||||||
@ -27,6 +29,7 @@ const ProgressBar = () => {
|
|||||||
w="full"
|
w="full"
|
||||||
colorScheme="invokeBlue"
|
colorScheme="invokeBlue"
|
||||||
/>
|
/>
|
||||||
|
</Tooltip>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@ import type { LogLevelName } from 'roarr';
|
|||||||
import {
|
import {
|
||||||
socketConnected,
|
socketConnected,
|
||||||
socketDisconnected,
|
socketDisconnected,
|
||||||
socketGeneratorProgress,
|
|
||||||
socketInvocationComplete,
|
socketInvocationComplete,
|
||||||
|
socketInvocationProgress,
|
||||||
socketInvocationStarted,
|
socketInvocationStarted,
|
||||||
socketModelLoadComplete,
|
socketModelLoadComplete,
|
||||||
socketModelLoadStarted,
|
socketModelLoadStarted,
|
||||||
@ -95,8 +95,8 @@ export const systemSlice = createSlice({
|
|||||||
/**
|
/**
|
||||||
* Generator Progress
|
* Generator Progress
|
||||||
*/
|
*/
|
||||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
builder.addCase(socketInvocationProgress, (state, action) => {
|
||||||
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data;
|
const { image, session_id, batch_id, percentage, message } = action.payload.data;
|
||||||
|
|
||||||
if (state.cancellations.includes(session_id)) {
|
if (state.cancellations.includes(session_id)) {
|
||||||
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
||||||
@ -105,10 +105,9 @@ export const systemSlice = createSlice({
|
|||||||
}
|
}
|
||||||
|
|
||||||
state.denoiseProgress = {
|
state.denoiseProgress = {
|
||||||
step,
|
message,
|
||||||
total_steps,
|
|
||||||
percentage,
|
percentage,
|
||||||
progress_image,
|
image,
|
||||||
session_id,
|
session_id,
|
||||||
batch_id,
|
batch_id,
|
||||||
};
|
};
|
||||||
|
@ -1,18 +1,9 @@
|
|||||||
import type { LogLevel } from 'app/logging/logger';
|
import type { LogLevel } from 'app/logging/logger';
|
||||||
import type { ProgressImage } from 'services/events/types';
|
import type { InvocationProgressEvent } from 'services/events/types';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
|
type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
|
||||||
|
|
||||||
type DenoiseProgress = {
|
|
||||||
session_id: string;
|
|
||||||
batch_id: string;
|
|
||||||
progress_image: ProgressImage | null | undefined;
|
|
||||||
step: number;
|
|
||||||
total_steps: number;
|
|
||||||
percentage: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
const zLanguage = z.enum([
|
const zLanguage = z.enum([
|
||||||
'ar',
|
'ar',
|
||||||
'az',
|
'az',
|
||||||
@ -45,7 +36,7 @@ export interface SystemState {
|
|||||||
isConnected: boolean;
|
isConnected: boolean;
|
||||||
shouldConfirmOnDelete: boolean;
|
shouldConfirmOnDelete: boolean;
|
||||||
enableImageDebugging: boolean;
|
enableImageDebugging: boolean;
|
||||||
denoiseProgress: DenoiseProgress | null;
|
denoiseProgress: Pick<InvocationProgressEvent, 'session_id' | 'batch_id' | 'image' | 'percentage' | 'message'> | null;
|
||||||
consoleLogLevel: LogLevel;
|
consoleLogLevel: LogLevel;
|
||||||
shouldLogToConsole: boolean;
|
shouldLogToConsole: boolean;
|
||||||
shouldAntialiasProgressImage: boolean;
|
shouldAntialiasProgressImage: boolean;
|
||||||
|
@ -9,8 +9,8 @@ import type {
|
|||||||
DownloadProgressEvent,
|
DownloadProgressEvent,
|
||||||
DownloadStartedEvent,
|
DownloadStartedEvent,
|
||||||
InvocationCompleteEvent,
|
InvocationCompleteEvent,
|
||||||
InvocationDenoiseProgressEvent,
|
|
||||||
InvocationErrorEvent,
|
InvocationErrorEvent,
|
||||||
|
InvocationProgressEvent,
|
||||||
InvocationStartedEvent,
|
InvocationStartedEvent,
|
||||||
ModelInstallCancelledEvent,
|
ModelInstallCancelledEvent,
|
||||||
ModelInstallCompleteEvent,
|
ModelInstallCompleteEvent,
|
||||||
@ -32,9 +32,7 @@ export const socketDisconnected = createSocketAction('Disconnected');
|
|||||||
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
|
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
|
||||||
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
|
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
|
||||||
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
|
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
|
||||||
export const socketGeneratorProgress = createSocketAction<InvocationDenoiseProgressEvent>(
|
export const socketInvocationProgress = createSocketAction<InvocationProgressEvent>('InvocationProgressEvent');
|
||||||
'InvocationDenoiseProgressEvent'
|
|
||||||
);
|
|
||||||
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
|
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
|
||||||
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
|
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
|
||||||
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');
|
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');
|
||||||
|
@ -14,9 +14,9 @@ import {
|
|||||||
socketDownloadError,
|
socketDownloadError,
|
||||||
socketDownloadProgress,
|
socketDownloadProgress,
|
||||||
socketDownloadStarted,
|
socketDownloadStarted,
|
||||||
socketGeneratorProgress,
|
|
||||||
socketInvocationComplete,
|
socketInvocationComplete,
|
||||||
socketInvocationError,
|
socketInvocationError,
|
||||||
|
socketInvocationProgress,
|
||||||
socketInvocationStarted,
|
socketInvocationStarted,
|
||||||
socketModelInstallCancelled,
|
socketModelInstallCancelled,
|
||||||
socketModelInstallComplete,
|
socketModelInstallComplete,
|
||||||
@ -65,8 +65,8 @@ export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) =>
|
|||||||
socket.on('invocation_started', (data) => {
|
socket.on('invocation_started', (data) => {
|
||||||
dispatch(socketInvocationStarted({ data }));
|
dispatch(socketInvocationStarted({ data }));
|
||||||
});
|
});
|
||||||
socket.on('invocation_denoise_progress', (data) => {
|
socket.on('invocation_progress', (data) => {
|
||||||
dispatch(socketGeneratorProgress({ data }));
|
dispatch(socketInvocationProgress({ data }));
|
||||||
});
|
});
|
||||||
socket.on('invocation_error', (data) => {
|
socket.on('invocation_error', (data) => {
|
||||||
dispatch(socketInvocationError({ data }));
|
dispatch(socketInvocationError({ data }));
|
||||||
|
@ -4,10 +4,9 @@ export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
|
|||||||
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
||||||
|
|
||||||
export type InvocationStartedEvent = S['InvocationStartedEvent'];
|
export type InvocationStartedEvent = S['InvocationStartedEvent'];
|
||||||
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
|
export type InvocationProgressEvent = S['InvocationProgressEvent'];
|
||||||
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
|
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
|
||||||
export type InvocationErrorEvent = S['InvocationErrorEvent'];
|
export type InvocationErrorEvent = S['InvocationErrorEvent'];
|
||||||
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
|
||||||
|
|
||||||
export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
|
export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
|
||||||
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
||||||
@ -39,7 +38,7 @@ type ClientEmitSubscribeBulkDownload = {
|
|||||||
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
|
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
|
||||||
|
|
||||||
export type ServerToClientEvents = {
|
export type ServerToClientEvents = {
|
||||||
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void;
|
invocation_progress: (payload: InvocationProgressEvent) => void;
|
||||||
invocation_complete: (payload: InvocationCompleteEvent) => void;
|
invocation_complete: (payload: InvocationCompleteEvent) => void;
|
||||||
invocation_error: (payload: InvocationErrorEvent) => void;
|
invocation_error: (payload: InvocationErrorEvent) => void;
|
||||||
invocation_started: (payload: InvocationStartedEvent) => void;
|
invocation_started: (payload: InvocationStartedEvent) => void;
|
||||||
|
Loading…
Reference in New Issue
Block a user