feat(ui): update UI to use new events

- Use OpenAPI schema for event payload types
- Update all event listeners
- Add missing events / remove old nonexistent events
This commit is contained in:
psychedelicious 2024-03-14 19:05:40 +11:00
parent e25b39aca2
commit a1c4ef55d7
16 changed files with 145 additions and 475 deletions

View File

@ -35,18 +35,17 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected'; import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded'; import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged'; import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress'; import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete'; import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged'; import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed'; import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed'; import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved'; import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
@ -55,8 +54,6 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>; export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@ -114,8 +111,6 @@ addSocketSubscribedEventListener(startAppListening);
addSocketUnsubscribedEventListener(startAppListening); addSocketUnsubscribedEventListener(startAppListening);
addModelLoadEventListener(startAppListening); addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening); addModelInstallEventListener(startAppListening);
addSessionRetrievalErrorEventListener(startAppListening);
addInvocationRetrievalErrorEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening); addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening); addBulkDownloadListeners(startAppListening);

View File

@ -6,8 +6,8 @@ import { toast } from 'common/util/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { import {
socketBulkDownloadCompleted, socketBulkDownloadComplete,
socketBulkDownloadFailed, socketBulkDownloadError,
socketBulkDownloadStarted, socketBulkDownloadStarted,
} from 'services/events/actions'; } from 'services/events/actions';
@ -56,7 +56,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
}); });
startAppListening({ startAppListening({
actionCreator: socketBulkDownloadCompleted, actionCreator: socketBulkDownloadComplete,
effect: async (action) => { effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed'); log.debug(action.payload.data, 'Bulk download preparation completed');
@ -89,7 +89,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
}); });
startAppListening({ startAppListening({
actionCreator: socketBulkDownloadFailed, actionCreator: socketBulkDownloadError,
effect: async (action) => { effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed'); log.debug(action.payload.data, 'Bulk download preparation failed');

View File

@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === nodeId action.payload.data.invocation_source_id === nodeId
); );
// We still have to check the output type // We still have to check the output type

View File

@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
actionCreator: socketInvocationComplete, actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
const { data } = action.payload; const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation_type})`);
const { result, node, queue_batch_id, source_node_id } = data; const { result, invocation_source_id } = data;
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation_type)) {
const { image_name } = result.image; const { image_name } = data.result.image;
const { canvas, gallery } = getState(); const { canvas, gallery } = getState();
// This populates the `getImageDTO` cache // This populates the `getImageDTO` cache
@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe(); imageDTORequest.unsubscribe();
// Add canvas images to the staging area // Add canvas images to the staging area
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO)); dispatch(addImageToStagingArea(imageDTO));
} }
@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
} }
} }
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.COMPLETED; nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) { if (nes.progress !== null) {

View File

@ -11,9 +11,9 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
startAppListening({ startAppListening({
actionCreator: socketInvocationError, actionCreator: socketInvocationError,
effect: (action) => { effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`); log.error(action.payload, `Invocation error (${action.payload.data.invocation_type})`);
const { source_node_id } = action.payload.data; const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.FAILED; nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error; nes.error = action.payload.data.error;

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketInvocationRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action) => {
log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@ -11,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
startAppListening({ startAppListening({
actionCreator: socketInvocationStarted, actionCreator: socketInvocationStarted,
effect: (action) => { effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`); log.debug(action.payload, `Invocation started (${action.payload.data.invocation_type})`);
const { source_node_id } = action.payload.data; const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);

View File

@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { import {
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallCompleted, socketModelInstallComplete,
socketModelInstallDownloading, socketModelInstallDownloadProgress,
socketModelInstallError, socketModelInstallError,
} from 'services/events/actions'; } from 'services/events/actions';
export const addModelInstallEventListener = (startAppListening: AppStartListening) => { export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketModelInstallDownloading, actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch }) => {
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallCompleted, actionCreator: socketModelInstallComplete,
effect: (action, { dispatch }) => { effect: (action, { dispatch }) => {
const { id } = action.payload.data; const { id } = action.payload.data;

View File

@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions'; import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio'); const log = logger('socketio');
@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({ startAppListening({
actionCreator: socketModelLoadStarted, actionCreator: socketModelLoadStarted,
effect: (action) => { effect: (action) => {
const { model_config, submodel_type } = action.payload.data; const { config, submodel_type } = action.payload.data;
const { name, base, type } = model_config; const { name, base, type } = config;
const extras: string[] = [base, type]; const extras: string[] = [base, type];
if (submodel_type) { if (submodel_type) {
extras.push(submodel_type); extras.push(submodel_type);
} }
@ -23,16 +24,15 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
}); });
startAppListening({ startAppListening({
actionCreator: socketModelLoadCompleted, actionCreator: socketModelLoadComplete,
effect: (action) => { effect: (action) => {
const { model_config, submodel_type } = action.payload.data; const { config, submodel_type } = action.payload.data;
const { name, base, type } = model_config; const { name, base, type } = config;
const extras: string[] = [base, type]; const extras: string[] = [base, type];
if (submodel_type) { if (submodel_type) {
extras.push(submodel_type); extras.push(submodel_type);
} }
const message = `Model load complete: ${name} (${extras.join(', ')})`; const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message); log.debug(action.payload, message);

View File

@ -14,16 +14,23 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
actionCreator: socketQueueItemStatusChanged, actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch }) => {
// we've got new status for the queue item, batch and queue // we've got new status for the queue item, batch and queue
const { queue_item, batch_status, queue_status } = action.payload.data; const { item_id, status, started_at, updated_at, error, completed_at, created_at, batch_status, queue_status } =
action.payload.data;
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`); log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch( dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, { queueItemsAdapter.updateOne(draft, {
id: String(queue_item.item_id), id: String(item_id),
changes: queue_item, changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
error,
completed_at: completed_at ?? undefined,
},
}); });
}) })
); );
@ -45,11 +52,18 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
// Update the queue item status (this is the full queue item, including the session) // Update the queue item status (this is the full queue item, including the session)
dispatch( dispatch(
queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => { queueApi.util.updateQueryData('getQueueItem', item_id, (draft) => {
if (!draft) { if (!draft) {
return; return;
} }
Object.assign(draft, queue_item); Object.assign(draft, {
status,
started_at,
updated_at,
error,
completed_at,
created_at,
});
}) })
); );

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSessionRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action) => {
log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@ -1,8 +1,7 @@
import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { UseToastOptions } from '@invoke-ai/ui-library';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { startCase } from 'lodash-es'; import { startCase } from 'lodash-es';
@ -14,12 +13,10 @@ import {
socketGraphExecutionStateComplete, socketGraphExecutionStateComplete,
socketInvocationComplete, socketInvocationComplete,
socketInvocationError, socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted, socketInvocationStarted,
socketModelLoadCompleted, socketModelLoadComplete,
socketModelLoadStarted, socketModelLoadStarted,
socketQueueItemStatusChanged, socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from 'services/events/actions'; } from 'services/events/actions';
import type { Language, SystemState } from './types'; import type { Language, SystemState } from './types';
@ -110,20 +107,12 @@ export const systemSlice = createSlice({
* Generator Progress * Generator Progress
*/ */
builder.addCase(socketGeneratorProgress, (state, action) => { builder.addCase(socketGeneratorProgress, (state, action) => {
const { const { step, total_steps, progress_image, session_id, batch_id } = action.payload.data;
step,
total_steps,
order,
progress_image,
graph_execution_state_id: session_id,
queue_batch_id: batch_id,
} = action.payload.data;
state.denoiseProgress = { state.denoiseProgress = {
step, step,
total_steps, total_steps,
order, percentage: step / total_steps,
percentage: calculateStepPercentage(step, total_steps, order),
progress_image, progress_image,
session_id, session_id,
batch_id, batch_id,
@ -152,12 +141,12 @@ export const systemSlice = createSlice({
state.status = 'LOADING_MODEL'; state.status = 'LOADING_MODEL';
}); });
builder.addCase(socketModelLoadCompleted, (state) => { builder.addCase(socketModelLoadComplete, (state) => {
state.status = 'CONNECTED'; state.status = 'CONNECTED';
}); });
builder.addCase(socketQueueItemStatusChanged, (state, action) => { builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) { if (['completed', 'canceled', 'failed'].includes(action.payload.data.status)) {
state.status = 'CONNECTED'; state.status = 'CONNECTED';
state.denoiseProgress = null; state.denoiseProgress = null;
} }
@ -168,7 +157,7 @@ export const systemSlice = createSlice({
/** /**
* Any server error * Any server error
*/ */
builder.addMatcher(isAnyServerError, (state, action) => { builder.addCase(socketInvocationError, (state, action) => {
state.toastQueue.push( state.toastQueue.push(
makeToast({ makeToast({
title: t('toast.serverError'), title: t('toast.serverError'),
@ -194,8 +183,6 @@ export const {
setShouldEnableInformationalPopovers, setShouldEnableInformationalPopovers,
} = systemSlice.actions; } = systemSlice.actions;
const isAnyServerError = isAnyOf(socketInvocationError, socketSessionRetrievalError, socketInvocationRetrievalError);
export const selectSystemSlice = (state: RootState) => state.system; export const selectSystemSlice = (state: RootState) => state.system;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */

View File

@ -11,7 +11,6 @@ type DenoiseProgress = {
progress_image: ProgressImage | null | undefined; progress_image: ProgressImage | null | undefined;
step: number; step: number;
total_steps: number; total_steps: number;
order: number;
percentage: number; percentage: number;
}; };

View File

@ -1,22 +1,21 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import type { import type {
BulkDownloadCompletedEvent, BulkDownloadCompleteEvent,
BulkDownloadFailedEvent, BulkDownloadFailedEvent,
BulkDownloadStartedEvent, BulkDownloadStartedEvent,
GeneratorProgressEvent,
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationRetrievalErrorEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
ModelInstallCompletedEvent, ModelInstallCompleteEvent,
ModelInstallDownloadingEvent, ModelInstallDownloadProgressEvent,
ModelInstallErrorEvent, ModelInstallErrorEvent,
ModelLoadCompletedEvent, ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent, ModelLoadStartedEvent,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
SessionRetrievalErrorEvent, SessionCompleteEvent,
} from 'services/events/types'; } from 'services/events/types';
// Create actions for each socket // Create actions for each socket
@ -45,28 +44,32 @@ export const socketInvocationError = createAction<{
}>('socket/socketInvocationError'); }>('socket/socketInvocationError');
export const socketGraphExecutionStateComplete = createAction<{ export const socketGraphExecutionStateComplete = createAction<{
data: GraphExecutionStateCompleteEvent; data: SessionCompleteEvent;
}>('socket/socketGraphExecutionStateComplete'); }>('socket/socketGraphExecutionStateComplete');
export const socketGeneratorProgress = createAction<{ export const socketGeneratorProgress = createAction<{
data: GeneratorProgressEvent; data: InvocationDenoiseProgressEvent;
}>('socket/socketGeneratorProgress'); }>('socket/socketGeneratorProgress');
export const socketModelLoadStarted = createAction<{ export const socketModelLoadStarted = createAction<{
data: ModelLoadStartedEvent; data: ModelLoadStartedEvent;
}>('socket/socketModelLoadStarted'); }>('socket/socketModelLoadStarted');
export const socketModelLoadCompleted = createAction<{ export const socketModelLoadComplete = createAction<{
data: ModelLoadCompletedEvent; data: ModelLoadCompleteEvent;
}>('socket/socketModelLoadCompleted'); }>('socket/socketModelLoadComplete');
export const socketModelInstallDownloading = createAction<{ export const socketModelInstallStarted = createAction<{
data: ModelInstallDownloadingEvent; data: ModelInstallStartedEvent;
}>('socket/socketModelInstallDownloading'); }>('socket/socketModelInstallStarted');
export const socketModelInstallCompleted = createAction<{ export const socketModelInstallDownloadProgress = createAction<{
data: ModelInstallCompletedEvent; data: ModelInstallDownloadProgressEvent;
}>('socket/socketModelInstallCompleted'); }>('socket/socketModelInstallDownloadProgress');
export const socketModelInstallComplete = createAction<{
data: ModelInstallCompleteEvent;
}>('socket/socketModelInstallComplete');
export const socketModelInstallError = createAction<{ export const socketModelInstallError = createAction<{
data: ModelInstallErrorEvent; data: ModelInstallErrorEvent;
@ -76,14 +79,6 @@ export const socketModelInstallCancelled = createAction<{
data: ModelInstallCancelledEvent; data: ModelInstallCancelledEvent;
}>('socket/socketModelInstallCancelled'); }>('socket/socketModelInstallCancelled');
export const socketSessionRetrievalError = createAction<{
data: SessionRetrievalErrorEvent;
}>('socket/socketSessionRetrievalError');
export const socketInvocationRetrievalError = createAction<{
data: InvocationRetrievalErrorEvent;
}>('socket/socketInvocationRetrievalError');
export const socketQueueItemStatusChanged = createAction<{ export const socketQueueItemStatusChanged = createAction<{
data: QueueItemStatusChangedEvent; data: QueueItemStatusChangedEvent;
}>('socket/socketQueueItemStatusChanged'); }>('socket/socketQueueItemStatusChanged');
@ -92,10 +87,10 @@ export const socketBulkDownloadStarted = createAction<{
data: BulkDownloadStartedEvent; data: BulkDownloadStartedEvent;
}>('socket/socketBulkDownloadStarted'); }>('socket/socketBulkDownloadStarted');
export const socketBulkDownloadCompleted = createAction<{ export const socketBulkDownloadComplete = createAction<{
data: BulkDownloadCompletedEvent; data: BulkDownloadCompleteEvent;
}>('socket/socketBulkDownloadCompleted'); }>('socket/socketBulkDownloadComplete');
export const socketBulkDownloadFailed = createAction<{ export const socketBulkDownloadError = createAction<{
data: BulkDownloadFailedEvent; data: BulkDownloadFailedEvent;
}>('socket/socketBulkDownloadFailed'); }>('socket/socketBulkDownloadError');

View File

@ -1,275 +1,59 @@
import type { components } from 'services/api/schema'; import type { Graph, GraphExecutionState, S } from 'services/api/types';
import type { AnyModelConfig, Graph, GraphExecutionState, SubModelType } from 'services/api/types';
/**
* A progress image, we get one for each step in the generation
*/
export type ProgressImage = {
dataURL: string;
width: number;
height: number;
};
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>; export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>; export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
type BaseNode = { export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
id: string; export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
type: string;
[key: string]: AnyInvocation[keyof AnyInvocation];
};
export type ModelLoadStartedEvent = { export type InvocationStartedEvent = S['InvocationStartedEvent'];
queue_id: string; export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
queue_item_id: number; export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result'> & { result: AnyResult };
queue_batch_id: string; export type InvocationErrorEvent = S['InvocationErrorEvent'];
graph_execution_state_id: string; export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
model_config: AnyModelConfig;
submodel_type?: SubModelType | null;
};
export type ModelLoadCompletedEvent = { export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
queue_id: string; export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
queue_item_id: number; export type ModelInstallErrorEvent = S['ModelInstallErrorEvent'];
queue_batch_id: string; export type ModelInstallStartedEvent = S['ModelInstallStartedEvent'];
graph_execution_state_id: string; export type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent'];
model_config: AnyModelConfig;
submodel_type?: SubModelType | null;
};
export type ModelInstallDownloadingEvent = { export type SessionCompleteEvent = S['SessionCompleteEvent'];
bytes: number; export type SessionCanceledEvent = S['SessionCanceledEvent'];
local_path: string;
source: string;
timestamp: number;
total_bytes: number;
id: number;
};
export type ModelInstallCompletedEvent = { export type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent'];
key: number;
source: string;
timestamp: number;
id: number;
};
export type ModelInstallErrorEvent = { export type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent'];
error: string; export type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent'];
error_type: string; export type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent'];
source: string;
timestamp: number;
id: number;
};
export type ModelInstallCancelledEvent = { export type ClientEmitSubscribeQueue = {
source: string;
timestamp: number;
id: number;
};
/**
* A `generator_progress` socket.io event.
*
* @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... }
*/
export type GeneratorProgressEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node_id: string;
source_node_id: string;
progress_image?: ProgressImage;
step: number;
order: number;
total_steps: number;
};
/**
* A `invocation_complete` socket.io event.
*
* `result` is a discriminated union with a `type` property as the discriminant.
*
* @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... }
*/
export type InvocationCompleteEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
result: AnyResult;
};
/**
* A `invocation_error` socket.io event.
*
* @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... }
*/
export type InvocationErrorEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
error_type: string;
error: string;
};
/**
* A `invocation_started` socket.io event.
*
* @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... }
*/
export type InvocationStartedEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
};
/**
* A `graph_execution_state_complete` socket.io event.
*
* @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... }
*/
export type GraphExecutionStateCompleteEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
};
/**
* A `session_retrieval_error` socket.io event.
*
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
*/
export type SessionRetrievalErrorEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
error_type: string;
error: string;
};
/**
* A `invocation_retrieval_error` socket.io event.
*
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
*/
export type InvocationRetrievalErrorEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node_id: string;
error_type: string;
error: string;
};
/**
* A `queue_item_status_changed` socket.io event.
*
* @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
*/
export type QueueItemStatusChangedEvent = {
queue_id: string;
queue_item: {
queue_id: string;
item_id: number;
batch_id: string;
session_id: string;
status: components['schemas']['SessionQueueItemDTO']['status'];
error: string | undefined;
created_at: string;
updated_at: string;
started_at: string | undefined;
completed_at: string | undefined;
};
batch_status: {
queue_id: string;
batch_id: string;
pending: number;
in_progress: number;
completed: number;
failed: number;
canceled: number;
total: number;
};
queue_status: {
queue_id: string;
item_id?: number;
batch_id?: string;
session_id?: string;
pending: number;
in_progress: number;
completed: number;
failed: number;
canceled: number;
total: number;
};
};
type ClientEmitSubscribeQueue = {
queue_id: string; queue_id: string;
}; };
export type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
type ClientEmitUnsubscribeQueue = { export type ClientEmitSubscribeBulkDownload = {
queue_id: string;
};
export type BulkDownloadStartedEvent = {
bulk_download_id: string;
bulk_download_item_id: string;
bulk_download_item_name: string;
};
export type BulkDownloadCompletedEvent = {
bulk_download_id: string;
bulk_download_item_id: string;
bulk_download_item_name: string;
};
export type BulkDownloadFailedEvent = {
bulk_download_id: string;
bulk_download_item_id: string;
bulk_download_item_name: string;
error: string;
};
type ClientEmitSubscribeBulkDownload = {
bulk_download_id: string;
};
type ClientEmitUnsubscribeBulkDownload = {
bulk_download_id: string; bulk_download_id: string;
}; };
export type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
export type ServerToClientEvents = { export type ServerToClientEvents = {
generator_progress: (payload: GeneratorProgressEvent) => void; invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void;
invocation_complete: (payload: InvocationCompleteEvent) => void; invocation_complete: (payload: InvocationCompleteEvent) => void;
invocation_error: (payload: InvocationErrorEvent) => void; invocation_error: (payload: InvocationErrorEvent) => void;
invocation_started: (payload: InvocationStartedEvent) => void; invocation_started: (payload: InvocationStartedEvent) => void;
graph_execution_state_complete: (payload: GraphExecutionStateCompleteEvent) => void; session_complete: (payload: SessionCompleteEvent) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void; model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void; model_install_started: (payload: ModelInstallStartedEvent) => void;
model_install_downloading: (payload: ModelInstallDownloadingEvent) => void; model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
model_install_completed: (payload: ModelInstallCompletedEvent) => void; model_install_complete: (payload: ModelInstallCompleteEvent) => void;
model_install_error: (payload: ModelInstallErrorEvent) => void; model_install_error: (payload: ModelInstallErrorEvent) => void;
model_install_canceled: (payload: ModelInstallCancelledEvent) => void; model_install_cancelled: (payload: ModelInstallCancelledEvent) => void;
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; model_load_complete: (payload: ModelLoadCompleteEvent) => void;
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void; queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void;
bulk_download_started: (payload: BulkDownloadStartedEvent) => void; bulk_download_started: (payload: BulkDownloadStartedEvent) => void;
bulk_download_completed: (payload: BulkDownloadCompletedEvent) => void; bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void;
bulk_download_failed: (payload: BulkDownloadFailedEvent) => void; bulk_download_error: (payload: BulkDownloadFailedEvent) => void;
}; };
export type ClientToServerEvents = { export type ClientToServerEvents = {

View File

@ -5,8 +5,8 @@ import type { AppDispatch } from 'app/store/store';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { import {
socketBulkDownloadCompleted, socketBulkDownloadComplete,
socketBulkDownloadFailed, socketBulkDownloadError,
socketBulkDownloadStarted, socketBulkDownloadStarted,
socketConnected, socketConnected,
socketDisconnected, socketDisconnected,
@ -14,15 +14,15 @@ import {
socketGraphExecutionStateComplete, socketGraphExecutionStateComplete,
socketInvocationComplete, socketInvocationComplete,
socketInvocationError, socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted, socketInvocationStarted,
socketModelInstallCompleted, socketModelInstallCancelled,
socketModelInstallDownloading, socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallError, socketModelInstallError,
socketModelLoadCompleted, socketModelInstallStarted,
socketModelLoadComplete,
socketModelLoadStarted, socketModelLoadStarted,
socketQueueItemStatusChanged, socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from 'services/events/actions'; } from 'services/events/actions';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client'; import type { Socket } from 'socket.io-client';
@ -65,131 +65,55 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
} }
}); });
/**
* Disconnect
*/
socket.on('disconnect', () => { socket.on('disconnect', () => {
dispatch(socketDisconnected()); dispatch(socketDisconnected());
}); });
/**
* Invocation started
*/
socket.on('invocation_started', (data) => { socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data })); dispatch(socketInvocationStarted({ data }));
}); });
/** socket.on('invocation_denoise_progress', (data) => {
* Generator progress
*/
socket.on('generator_progress', (data) => {
dispatch(socketGeneratorProgress({ data })); dispatch(socketGeneratorProgress({ data }));
}); });
/**
* Invocation error
*/
socket.on('invocation_error', (data) => { socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data })); dispatch(socketInvocationError({ data }));
}); });
/**
* Invocation complete
*/
socket.on('invocation_complete', (data) => { socket.on('invocation_complete', (data) => {
dispatch( dispatch(socketInvocationComplete({ data }));
socketInvocationComplete({
data,
})
);
}); });
/** socket.on('session_complete', (data) => {
* Graph complete dispatch(socketGraphExecutionStateComplete({ data }));
*/
socket.on('graph_execution_state_complete', (data) => {
dispatch(
socketGraphExecutionStateComplete({
data,
})
);
}); });
/**
* Model load started
*/
socket.on('model_load_started', (data) => { socket.on('model_load_started', (data) => {
dispatch( dispatch(socketModelLoadStarted({ data }));
socketModelLoadStarted({
data,
})
);
}); });
/** socket.on('model_load_complete', (data) => {
* Model load completed dispatch(socketModelLoadComplete({ data }));
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
})
);
}); });
/** socket.on('model_install_started', (data) => {
* Model Install Downloading dispatch(socketModelInstallStarted({ data }));
*/
socket.on('model_install_downloading', (data) => {
dispatch(
socketModelInstallDownloading({
data,
})
);
}); });
/** socket.on('model_install_download_progress', (data) => {
* Model Install Completed dispatch(socketModelInstallDownloadProgress({ data }));
*/ });
socket.on('model_install_completed', (data) => {
dispatch( socket.on('model_install_complete', (data) => {
socketModelInstallCompleted({ dispatch(socketModelInstallComplete({ data }));
data,
})
);
}); });
/**
* Model Install Error
*/
socket.on('model_install_error', (data) => { socket.on('model_install_error', (data) => {
dispatch( dispatch(socketModelInstallError({ data }));
socketModelInstallError({
data,
})
);
}); });
/** socket.on('model_install_cancelled', (data) => {
* Session retrieval error dispatch(socketModelInstallCancelled({ data }));
*/
socket.on('session_retrieval_error', (data) => {
dispatch(
socketSessionRetrievalError({
data,
})
);
});
/**
* Invocation retrieval error
*/
socket.on('invocation_retrieval_error', (data) => {
dispatch(
socketInvocationRetrievalError({
data,
})
);
}); });
socket.on('queue_item_status_changed', (data) => { socket.on('queue_item_status_changed', (data) => {
@ -200,11 +124,11 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
dispatch(socketBulkDownloadStarted({ data })); dispatch(socketBulkDownloadStarted({ data }));
}); });
socket.on('bulk_download_completed', (data) => { socket.on('bulk_download_complete', (data) => {
dispatch(socketBulkDownloadCompleted({ data })); dispatch(socketBulkDownloadComplete({ data }));
}); });
socket.on('bulk_download_failed', (data) => { socket.on('bulk_download_error', (data) => {
dispatch(socketBulkDownloadFailed({ data })); dispatch(socketBulkDownloadError({ data }));
}); });
}; };