tidy(ui): cleanup after events change

This commit is contained in:
psychedelicious 2024-08-17 15:41:34 +10:00
parent b630dbdf20
commit c18fb980a2
18 changed files with 12 additions and 865 deletions

View File

@ -1,9 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
@ -24,13 +22,5 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}
return sanitized;
}
return action;
};

View File

@ -26,15 +26,7 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueEventsListeners } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
@ -90,15 +82,9 @@ addBatchEnqueuedListener(startAppListening);
addStagingListeners(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSocketQueueEventsListeners(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards

View File

@ -1,21 +1,15 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
} from 'services/events/actions';
const log = logger('images');
export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
effect: async (action) => {
effect: (action) => {
log.debug(action.payload, 'Bulk download requested');
// If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@ -33,7 +27,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
effect: async () => {
effect: () => {
log.debug('Bulk download request failed');
// There isn't any toast to update if we get this event.
@ -44,55 +38,4 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
},
});
startAppListening({
actionCreator: socketBulkDownloadStarted,
effect: async (action) => {
// This should always happen immediately after the bulk download request, so we don't need to show a toast here.
log.debug(action.payload.data, 'Bulk download preparation started');
},
});
startAppListening({
actionCreator: socketBulkDownloadComplete,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
const { bulk_download_item_name } = action.payload.data;
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
description: (
<ExternalLink
label={t('gallery.clickToDownload', 'Click here to download')}
href={url}
download={bulk_download_item_name}
/>
),
duration: null,
});
},
});
startAppListening({
actionCreator: socketBulkDownloadError,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');
const { bulk_download_item_name } = action.payload.data;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: action.payload.data.error,
duration: null,
});
},
});
};

View File

@ -15,7 +15,7 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/actions';
import { socketConnected } from 'services/events/setEventListeners';
const matcher = isAnyOf(
positivePromptChanged,

View File

@ -6,7 +6,7 @@ import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/actions';
import { socketConnected } from 'services/events/setEventListeners';
const log = logger('socketio');

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketDisconnected } from 'services/events/actions';
const log = logger('socketio');
export const addSocketDisconnectedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketDisconnected,
effect: () => {
log.debug('Disconnected');
},
});
};

View File

@ -1,34 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketGeneratorProgress,
effect: (action) => {
const { invocation_source_id, invocation, step, total_steps, progress_image, origin } = action.payload.data;
log.trace(parseify(action.payload), `Generator progress (${invocation.type}, ${invocation_source_id})`);
if (origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
}
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(action.payload.data);
}
},
});
};

View File

@ -1,132 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice';
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image'];
const log = logger('socketio');
export const addInvocationCompleteEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const { data } = action.payload;
log.debug(
{ data: parseify(data) },
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
);
const { result, invocation_source_id } = data;
if (data.origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
}
// This complete event has an associated image output
if (
(data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') &&
!nodeTypeDenylist.includes(data.invocation.type)
) {
const { image_name } = data.result.image;
const { gallery, canvasV2 } = getState();
// This populates the `getImageDTO` cache
const imageDTORequest = dispatch(
imagesApi.endpoints.getImageDTO.initiate(image_name, {
forceRefetch: true,
})
);
const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe();
// handle tab-specific logic
if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') {
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
const { offset_x, offset_y } = data.result;
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } }));
}
} else if (data.result.type === 'image_output') {
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
}
}
}
if (!imageDTO.is_intermediate) {
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { shouldAutoSwitch } = gallery;
// If auto-switch is enabled, select the new image
if (shouldAutoSwitch) {
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
if (gallery.galleryView !== 'images') {
dispatch(galleryViewChanged('images'));
}
if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) {
dispatch(
boardIdSelected({
boardId: imageDTO.board_id,
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({
boardId: 'none',
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(imageSelected(imageDTO));
}
}
}
},
});
};

View File

@ -1,31 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationError,
effect: (action) => {
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
log.error(parseify(action.payload), `Invocation error (${invocation.type}, ${invocation_source_id})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.progress = null;
nes.progressImage = null;
nes.error = {
error_type,
error_message,
error_traceback,
};
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -1,24 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationStartedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action) => {
const { invocation_source_id, invocation } = action.payload.data;
log.debug(parseify(action.payload), `Invocation started (${invocation.type}, ${invocation_source_id})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -1,196 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallDownloadStarted,
socketModelInstallError,
socketModelInstallStarted,
} from 'services/events/actions';
/**
* A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_`
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
* downloaded and is being "physically" installed.
*
* Note: the download events are only fired for remote model installs, not local.
*
* Here's the expected flow:
* - API receives install request, model manager preps the install
* - `model_install_download_started` fired when the download starts
* - `model_install_download_progress` fired continually until the download is complete
* - `model_install_download_complete` fired when the download is complete
* - `model_install_started` fired when the "physical" installation starts
* - `model_install_complete` fired when the installation is complete
* - `model_install_cancelled` fired if the installation is cancelled
* - `model_install_error` fired if the installation has an error
*/
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelInstallDownloadStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch, getState }) => {
const { bytes, total_bytes, id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
modelImport.total_bytes = total_bytes;
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
}
return draft;
})
);
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
},
});
startAppListening({
actionCreator: socketModelInstallError,
effect: (action, { dispatch, getState }) => {
const { id, error, error_type } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';
modelImport.error_reason = error_type;
modelImport.error = error;
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'cancelled';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadsComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
},
});
};

View File

@ -1,42 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio');
export const addModelLoadEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});
startAppListening({
actionCreator: socketModelLoadComplete,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});
};

View File

@ -1,129 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
const log = logger('socketio');
export const addSocketQueueEventsListeners = (startAppListening: AppStartListening) => {
// When the queue is cleared or canvas batch is canceled, we should clear the last canvas progress event
startAppListening({
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
effect: () => {
$lastCanvasProgressEvent.set(null);
},
});
startAppListening({
actionCreator: socketQueueItemStatusChanged,
effect: (action, { dispatch, getState }) => {
// we've got new status for the queue item, batch and queue
const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
origin,
} = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
});
})
);
// Update the queue status (we do not get the processor status here)
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
Object.assign(draft.queue, queue_status);
})
);
// Update the batch status
dispatch(
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
);
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
);
if (status === 'in_progress') {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
} else if (status === 'canceled' && origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
},
});
};

View File

@ -21,7 +21,7 @@ export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_
some(node.data.inputs, (input) => isImageFieldInputInstance(input) && input.value?.image_name === image_name)
);
const isControlAdapterImage = canvasV2.controlAdapters.entities.some(
const isControlAdapterImage = canvasV2.controlLayers.entities.some(
(ca) => ca.imageObject?.image.image_name === image_name || ca.processedImageObject?.image.image_name === image_name
);

View File

@ -2,25 +2,13 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { LogLevelName } from 'roarr';
import {
socketConnected,
socketDisconnected,
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationStarted,
socketModelLoadComplete,
socketModelLoadStarted,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import type { Language, SystemState } from './types';
const initialSystemState: SystemState = {
_version: 1,
isConnected: false,
shouldConfirmOnDelete: true,
enableImageDebugging: false,
denoiseProgress: null,
shouldAntialiasProgressImage: false,
consoleLogLevel: 'debug',
shouldLogToConsole: true,
@ -28,8 +16,6 @@ const initialSystemState: SystemState = {
shouldUseNSFWChecker: false,
shouldUseWatermarker: false,
shouldEnableInformationalPopovers: true,
status: 'DISCONNECTED',
cancellations: [],
};
export const systemSlice = createSlice({
@ -64,82 +50,6 @@ export const systemSlice = createSlice({
state.shouldEnableInformationalPopovers = action.payload;
},
},
extraReducers(builder) {
/**
* Socket Connected
*/
builder.addCase(socketConnected, (state) => {
state.isConnected = true;
state.denoiseProgress = null;
state.status = 'CONNECTED';
});
/**
* Socket Disconnected
*/
builder.addCase(socketDisconnected, (state) => {
state.isConnected = false;
state.denoiseProgress = null;
state.status = 'DISCONNECTED';
});
/**
* Invocation Started
*/
builder.addCase(socketInvocationStarted, (state) => {
state.cancellations = [];
state.denoiseProgress = null;
state.status = 'PROCESSING';
});
/**
* Generator Progress
*/
builder.addCase(socketGeneratorProgress, (state, action) => {
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data;
if (state.cancellations.includes(session_id)) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
// progress update after the session has been cancelled.
return;
}
state.denoiseProgress = {
step,
total_steps,
percentage,
progress_image,
session_id,
batch_id,
};
state.status = 'PROCESSING';
});
/**
* Invocation Complete
*/
builder.addCase(socketInvocationComplete, (state) => {
state.denoiseProgress = null;
state.status = 'CONNECTED';
});
builder.addCase(socketModelLoadStarted, (state) => {
state.status = 'LOADING_MODEL';
});
builder.addCase(socketModelLoadComplete, (state) => {
state.status = 'CONNECTED';
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['completed', 'canceled', 'failed'].includes(action.payload.data.status)) {
state.status = 'CONNECTED';
state.denoiseProgress = null;
state.cancellations.push(action.payload.data.session_id);
}
});
},
});
export const {
@ -168,5 +78,5 @@ export const systemPersistConfig: PersistConfig<SystemState> = {
name: systemSlice.name,
initialState: initialSystemState,
migrate: migrateSystemState,
persistDenylist: ['isConnected', 'denoiseProgress', 'status', 'cancellations'],
persistDenylist: [],
};

View File

@ -1,18 +1,6 @@
import type { LogLevel } from 'app/logging/logger';
import type { ProgressImage } from 'services/events/types';
import { z } from 'zod';
type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
type DenoiseProgress = {
session_id: string;
batch_id: string;
progress_image: ProgressImage | null | undefined;
step: number;
total_steps: number;
percentage: number;
};
const zLanguage = z.enum([
'ar',
'az',
@ -42,17 +30,13 @@ export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).
export interface SystemState {
_version: 1;
isConnected: boolean;
shouldConfirmOnDelete: boolean;
enableImageDebugging: boolean;
denoiseProgress: DenoiseProgress | null;
consoleLogLevel: LogLevel;
shouldLogToConsole: boolean;
shouldAntialiasProgressImage: boolean;
language: Language;
shouldUseNSFWChecker: boolean;
shouldUseWatermarker: boolean;
status: SystemStatus;
shouldEnableInformationalPopovers: boolean;
cancellations: string[];
}

View File

@ -1,66 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import type {
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadFailedEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallDownloadStartedEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
} from 'services/events/types';
const createSocketAction = <T = undefined>(name: string) =>
createAction<T extends undefined ? void : { data: T }>(`socket/${name}`);
export const socketConnected = createSocketAction('Connected');
export const socketDisconnected = createSocketAction('Disconnected');
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
export const socketGeneratorProgress = createSocketAction<InvocationDenoiseProgressEvent>(
'InvocationDenoiseProgressEvent'
);
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');
export const socketDownloadProgress = createSocketAction<DownloadProgressEvent>('DownloadProgressEvent');
export const socketDownloadComplete = createSocketAction<DownloadCompleteEvent>('DownloadCompleteEvent');
export const socketDownloadCancelled = createSocketAction<DownloadCancelledEvent>('DownloadCancelledEvent');
export const socketDownloadError = createSocketAction<DownloadErrorEvent>('DownloadErrorEvent');
export const socketModelInstallStarted = createSocketAction<ModelInstallStartedEvent>('ModelInstallStartedEvent');
export const socketModelInstallDownloadProgress = createSocketAction<ModelInstallDownloadProgressEvent>(
'ModelInstallDownloadProgressEvent'
);
export const socketModelInstallDownloadStarted = createSocketAction<ModelInstallDownloadStartedEvent>(
'ModelInstallDownloadStartedEvent'
);
export const socketModelInstallDownloadsComplete = createSocketAction<ModelInstallDownloadsCompleteEvent>(
'ModelInstallDownloadsCompleteEvent'
);
export const socketModelInstallComplete = createSocketAction<ModelInstallCompleteEvent>('ModelInstallCompleteEvent');
export const socketModelInstallError = createSocketAction<ModelInstallErrorEvent>('ModelInstallErrorEvent');
export const socketModelInstallCancelled = createSocketAction<ModelInstallCancelledEvent>('ModelInstallCancelledEvent');
export const socketQueueItemStatusChanged =
createSocketAction<QueueItemStatusChangedEvent>('QueueItemStatusChangedEvent');
export const socketQueueCleared = createSocketAction<QueueClearedEvent>('QueueClearedEvent');
export const socketBatchEnqueued = createSocketAction<BatchEnqueuedEvent>('BatchEnqueuedEvent');
export const socketBulkDownloadStarted = createSocketAction<BulkDownloadStartedEvent>('BulkDownloadStartedEvent');
export const socketBulkDownloadComplete = createSocketAction<BulkDownloadCompleteEvent>('BulkDownloadCompleteEvent');
export const socketBulkDownloadError = createSocketAction<BulkDownloadFailedEvent>('BulkDownloadFailedEvent');

View File

@ -1,4 +1,5 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
@ -21,10 +22,11 @@ import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketConnected } from 'services/events/actions';
import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
export const socketConnected = createAction('socket/connected');
const log = logger('socketio');
type SetEventListenersArg = {