feat(ui): move socket event handling out of redux

Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware.

All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow.
This commit is contained in:
psychedelicious 2024-08-17 15:30:55 +10:00
parent 29ac1b5e01
commit b630dbdf20
31 changed files with 809 additions and 301 deletions

View File

@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken'; import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging'; import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppStore } from 'app/store/nanostores/store';
import type { MapStore } from 'nanostores'; import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores'; import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react'; import { useEffect, useMemo } from 'react';
@ -28,13 +28,15 @@ export const getSocket = () => {
return socket; return socket;
}; };
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({}); export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false); const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);
/** /**
* Initializes the socket.io connection and sets up event listeners. * Initializes the socket.io connection and sets up event listeners.
*/ */
export const useSocketIO = () => { export const useSocketIO = () => {
const dispatch = useAppDispatch(); const { dispatch, getState } = useAppStore();
const baseUrl = useStore($baseUrl); const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken); const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions); const addlSocketOptions = useStore($socketOptions);
@ -72,7 +74,7 @@ export const useSocketIO = () => {
const socket: AppSocket = io(socketUrl, socketOptions); const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket); $socket.set(socket);
setEventListeners({ dispatch, socket }); setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
socket.connect(); socket.connect();
if ($isDebugging.get() || import.meta.env.MODE === 'development') { if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@ -94,5 +96,5 @@ export const useSocketIO = () => {
socket.disconnect(); socket.disconnect();
$isSocketInitialized.set(false); $isSocketInitialized.set(false);
}; };
}, [dispatch, socketOptions, socketUrl]); }, [dispatch, getState, socketOptions, socketUrl]);
}; };

View File

@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { import {
$lastProgressEvent,
rasterLayerAdded, rasterLayerAdded,
sessionStagingAreaImageAccepted, sessionStagingAreaImageAccepted,
sessionStagingAreaReset, sessionStagingAreaReset,
@ -11,6 +10,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addStagingListeners = (startAppListening: AppStartListening) => { export const addStagingListeners = (startAppListening: AppStartListening) => {
@ -29,7 +29,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
const { canceled } = await req.unwrap(); const { canceled } = await req.unwrap();
req.reset(); req.reset();
$lastProgressEvent.set(null); $lastCanvasProgressEvent.set(null);
if (canceled > 0) { if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`); log.debug(`Canceled ${canceled} canvas batches`);

View File

@ -2,10 +2,10 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketGeneratorProgress } from 'services/events/actions';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
const log = logger('socketio'); const log = logger('socketio');
@ -27,7 +27,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
} }
if (origin === 'canvas') { if (origin === 'canvas') {
$lastProgressEvent.set(action.payload.data); $lastCanvasProgressEvent.set(action.payload.data);
} }
}, },
}); });

View File

@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
@ -9,6 +8,7 @@ import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions'; import { socketQueueItemStatusChanged } from 'services/events/actions';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
const log = logger('socketio'); const log = logger('socketio');
@ -17,13 +17,13 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
startAppListening({ startAppListening({
matcher: queueApi.endpoints.clearQueue.matchFulfilled, matcher: queueApi.endpoints.clearQueue.matchFulfilled,
effect: () => { effect: () => {
$lastProgressEvent.set(null); $lastCanvasProgressEvent.set(null);
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketQueueItemStatusChanged, actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
// we've got new status for the queue item, batch and queue // we've got new status for the queue item, batch and queue
const { const {
item_id, item_id,
@ -103,7 +103,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
const isLocal = getState().config.isLocal ?? true; const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id; const sessionId = session_id;
if (origin === 'canvas') { if (origin === 'canvas') {
$lastProgressEvent.set(null); $lastCanvasProgressEvent.set(null);
} }
toast({ toast({
@ -122,7 +122,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
), ),
}); });
} else if (status === 'canceled' && origin === 'canvas') { } else if (status === 'canceled' && origin === 'canvas') {
$lastProgressEvent.set(null); $lastCanvasProgressEvent.set(null);
} }
}, },
}); });

View File

@ -1,4 +1,5 @@
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
@ -25,7 +26,7 @@ const LAYER_TYPE_TO_TKEY = {
control_layer: 'controlLayers.globalControlAdapter', control_layer: 'controlLayers.globalControlAdapter',
} as const; } as const;
const createSelector = (templates: Templates) => const createSelector = (templates: Templates, isConnected: boolean) =>
createMemoizedSelector( createMemoizedSelector(
[ [
selectSystemSlice, selectSystemSlice,
@ -41,8 +42,6 @@ const createSelector = (templates: Templates) =>
const { bbox } = canvasV2; const { bbox } = canvasV2;
const { model, positivePrompt } = canvasV2.params; const { model, positivePrompt } = canvasV2.params;
const { isConnected } = system;
const reasons: { prefix?: string; content: string }[] = []; const reasons: { prefix?: string; content: string }[] = [];
// Cannot generate if not connected // Cannot generate if not connected
@ -240,7 +239,8 @@ const createSelector = (templates: Templates) =>
export const useIsReadyToEnqueue = () => { export const useIsReadyToEnqueue = () => {
const templates = useStore($templates); const templates = useStore($templates);
const selector = useMemo(() => createSelector(templates), [templates]); const isConnected = useStore($isConnected)
const selector = useMemo(() => createSelector(templates, isConnected), [templates, isConnected]);
const value = useAppSelector(selector); const value = useAppSelector(selector);
return value; return value;
}; };

View File

@ -3,3 +3,8 @@ type JSONValue = string | number | boolean | null | JSONValue[] | { [key: string
export interface JSONObject { export interface JSONObject {
[k: string]: JSONValue; [k: string]: JSONValue;
} }
type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject;
export type SerializableObject = {
[k: string | number]: SerializableValue;
};

View File

@ -1,5 +1,7 @@
import { Flex, useShiftModifier } from '@invoke-ai/ui-library'; import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
@ -22,14 +24,17 @@ type Props = {
postUploadAction: PostUploadAction; postUploadAction: PostUploadAction;
}; };
export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => { export const IPAdapterImagePreview = memo(
({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const optimalDimension = useAppSelector(selectOptimalDimension); const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier(); const shift = useShiftModifier();
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.image_name ?? skipToken); const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => { const handleResetControlImage = useCallback(() => {
onChangeImage(null); onChangeImage(null);
}, [onChangeImage]); }, [onChangeImage]);
@ -89,12 +94,15 @@ export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId,
<IAIDndImageIcon <IAIDndImageIcon
onClick={handleSetControlImageToDimensions} onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />} icon={<PiRulerBold size={16} />}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')} tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/> />
</Flex> </Flex>
)} )}
</Flex> </Flex>
); );
}); }
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview'; IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';

View File

@ -48,7 +48,7 @@ export class CanvasProgressImage {
image: null, image: null,
}; };
this.manager.stateApi.$lastProgressEvent.listen((event) => { this.manager.stateApi.$lastCanvasProgressEvent.listen((event) => {
this.lastProgressEvent = event; this.lastProgressEvent = event;
this.render(); this.render();
}); });

View File

@ -76,7 +76,7 @@ export class CanvasStagingArea {
if (!this.image.isLoading && !this.image.isError) { if (!this.image.isLoading && !this.image.isError) {
await this.image.updateImageSource(imageDTO.image_name); await this.image.updateImageSource(imageDTO.image_name);
this.manager.stateApi.$lastProgressEvent.set(null); this.manager.stateApi.$lastCanvasProgressEvent.set(null);
} }
this.image.konva.group.visible(shouldShowStagedImage); this.image.konva.group.visible(shouldShowStagedImage);
} else { } else {

View File

@ -11,7 +11,6 @@ import {
$lastAddedPoint, $lastAddedPoint,
$lastCursorPos, $lastCursorPos,
$lastMouseDownPos, $lastMouseDownPos,
$lastProgressEvent,
$shouldShowStagedImage, $shouldShowStagedImage,
$spaceKey, $spaceKey,
$stageAttrs, $stageAttrs,
@ -51,6 +50,7 @@ import type {
import { RGBA_RED } from 'features/controlLayers/store/types'; import { RGBA_RED } from 'features/controlLayers/store/types';
import type { WritableAtom } from 'nanostores'; import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
type EntityStateAndAdapter = type EntityStateAndAdapter =
| { | {
@ -263,7 +263,7 @@ export class CanvasStateApi {
$lastAddedPoint = $lastAddedPoint; $lastAddedPoint = $lastAddedPoint;
$lastMouseDownPos = $lastMouseDownPos; $lastMouseDownPos = $lastMouseDownPos;
$lastCursorPos = $lastCursorPos; $lastCursorPos = $lastCursorPos;
$lastProgressEvent = $lastProgressEvent; $lastCanvasProgressEvent = $lastCanvasProgressEvent;
$spaceKey = $spaceKey; $spaceKey = $spaceKey;
$altKey = $alt; $altKey = $alt;
$ctrlKey = $ctrl; $ctrlKey = $ctrl;

View File

@ -20,7 +20,6 @@ import { initialAspectRatioState } from 'features/parameters/components/Document
import { getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { isEqual, pick } from 'lodash-es'; import { isEqual, pick } from 'lodash-es';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import type { import type {
@ -622,7 +621,6 @@ export const $stageAttrs = atom<StageAttrs>({
scale: 0, scale: 0,
}); });
export const $shouldShowStagedImage = atom(true); export const $shouldShowStagedImage = atom(true);
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const $isDrawing = atom<boolean>(false); export const $isDrawing = atom<boolean>(false);
export const $isMouseDown = atom<boolean>(false); export const $isMouseDown = atom<boolean>(false);
export const $lastAddedPoint = atom<Coordinate | null>(null); export const $lastAddedPoint = atom<Coordinate | null>(null);

View File

@ -1,5 +1,7 @@
import type { IconButtonProps } from '@invoke-ai/ui-library'; import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library'; import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -12,7 +14,7 @@ type DeleteImageButtonProps = Omit<IconButtonProps, 'aria-label'> & {
export const DeleteImageButton = memo((props: DeleteImageButtonProps) => { export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
const { onClick, isDisabled } = props; const { onClick, isDisabled } = props;
const { t } = useTranslation(); const { t } = useTranslation();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length); const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length);
const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`; const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`;

View File

@ -1,7 +1,7 @@
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener'; import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton'; import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
@ -10,17 +10,15 @@ import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMe
import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToImg2Img } from 'features/gallery/store/actions'; import { sentImageToImg2Img } from 'features/gallery/store/actions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers'; import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
import { $templates } from 'features/nodes/store/nodesSlice'; import { $templates } from 'features/nodes/store/nodesSlice';
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover'; import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress'; import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import {
@ -33,23 +31,17 @@ import {
PiRulerBold, PiRulerBold,
} from 'react-icons/pi'; } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { $progressImage } from 'services/events/setEventListeners';
const selectShouldDisableToolbarButtons = createSelector(
selectSystemSlice,
selectGallerySlice,
selectLastSelectedImage,
(system, gallery, lastSelectedImage) => {
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
return hasProgressImage || !lastSelectedImage;
}
);
const CurrentImageButtons = () => { const CurrentImageButtons = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const lastSelectedImage = useAppSelector(selectLastSelectedImage); const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const progressImage = useStore($progressImage);
const selection = useAppSelector((s) => s.gallery.selection); const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); const shouldDisableToolbarButtons = useMemo(() => {
return Boolean(progressImage) || !lastSelectedImage;
}, [lastSelectedImage, progressImage]);
const templates = useStore($templates); const templates = useStore($templates);
const isUpscalingEnabled = useFeatureStatus('upscaling'); const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress(); const isQueueMutationInProgress = useIsQueueMutationInProgress();

View File

@ -1,4 +1,5 @@
import { Box, Flex } from '@invoke-ai/ui-library'; import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
@ -14,6 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiImageBold } from 'react-icons/pi'; import { PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { $hasProgress } from 'services/events/setEventListeners';
import ProgressImage from './ProgressImage'; import ProgressImage from './ProgressImage';
@ -26,7 +28,7 @@ const CurrentImagePreview = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails); const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
const imageName = useAppSelector(selectLastSelectedImageName); const imageName = useAppSelector(selectLastSelectedImageName);
const hasDenoiseProgress = useAppSelector((s) => Boolean(s.system.denoiseProgress)); const hasDenoiseProgress = useStore($hasProgress);
const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer); const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer);
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);

View File

@ -1,10 +1,12 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library'; import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Image } from '@invoke-ai/ui-library'; import { Image } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { $progressImage } from 'services/events/setEventListeners';
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image); const progressImage = useStore($progressImage);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage); const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const sx = useMemo<SystemStyleObject>( const sx = useMemo<SystemStyleObject>(
@ -14,15 +16,15 @@ const CurrentImagePreview = () => {
[shouldAntialiasProgressImage] [shouldAntialiasProgressImage]
); );
if (!progress_image) { if (!progressImage) {
return null; return null;
} }
return ( return (
<Image <Image
src={progress_image.dataURL} src={progressImage.dataURL}
width={progress_image.width} width={progressImage.width}
height={progress_image.height} height={progressImage.height}
draggable={false} draggable={false}
data-testid="progress-image" data-testid="progress-image"
objectFit="contain" objectFit="contain"

View File

@ -1,36 +1,33 @@
import { Flex, Image, Text } from '@invoke-ai/ui-library'; import { Flex, Image, Text } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons'; import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper'; import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import type { AnimationProps } from 'framer-motion'; import type { AnimationProps } from 'framer-motion';
import { motion } from 'framer-motion'; import { motion } from 'framer-motion';
import type { CSSProperties, PropsWithChildren } from 'react'; import type { CSSProperties, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react'; import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { NodeProps } from 'reactflow'; import type { NodeProps } from 'reactflow';
import { $lastProgressEvent } from 'services/events/setEventListeners';
const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (system, gallery) => {
const imageDTO = gallery.selection[gallery.selection.length - 1];
return {
imageDTO,
progressImage: system.denoiseProgress?.progress_image,
};
});
const CurrentImageNode = (props: NodeProps) => { const CurrentImageNode = (props: NodeProps) => {
const { progressImage, imageDTO } = useAppSelector(selector); const imageDTO = useAppSelector((s) => s.gallery.selection[s.gallery.selection.length - 1]);
const lastProgressEvent = useStore($lastProgressEvent);
if (progressImage) { if (lastProgressEvent?.progress_image) {
return ( return (
<Wrapper nodeProps={props}> <Wrapper nodeProps={props}>
<Image src={progressImage.dataURL} w="full" h="full" objectFit="contain" borderRadius="base" /> <Image
src={lastProgressEvent?.progress_image.dataURL}
w="full"
h="full"
objectFit="contain"
borderRadius="base"
/>
</Wrapper> </Wrapper>
); );
} }

View File

@ -1,6 +1,8 @@
import { Flex, Text } from '@invoke-ai/ui-library'; import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
@ -17,7 +19,7 @@ import type { FieldComponentProps } from './types';
const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>) => { const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const { currentData: imageDTO, isError } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); const { currentData: imageDTO, isError } = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
const handleReset = useCallback(() => { const handleReset = useCallback(() => {

View File

@ -1,11 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue'; import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue';
export const useCancelBatch = (batch_id: string) => { export const useCancelBatch = (batch_id: string) => {
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const { isCanceled } = useGetBatchStatusQuery( const { isCanceled } = useGetBatchStatusQuery(
{ batch_id }, { batch_id },
{ {

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
@ -6,7 +7,7 @@ import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue'; import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
export const useCancelCurrentQueueItem = () => { export const useCancelCurrentQueueItem = () => {
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const [trigger, { isLoading }] = useCancelQueueItemMutation(); const [trigger, { isLoading }] = useCancelQueueItemMutation();
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,11 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue'; import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
export const useCancelQueueItem = (item_id: number) => { export const useCancelQueueItem = (item_id: number) => {
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useCancelQueueItemMutation(); const [trigger, { isLoading }] = useCancelQueueItemMutation();
const { t } = useTranslation(); const { t } = useTranslation();
const cancelQueueItem = useCallback(async () => { const cancelQueueItem = useCallback(async () => {

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -7,7 +8,7 @@ import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } fro
export const useClearInvocationCache = () => { export const useClearInvocationCache = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useClearInvocationCacheMutation({ const [trigger, { isLoading }] = useClearInvocationCacheMutation({
fixedCacheKey: 'clearInvocationCache', fixedCacheKey: 'clearInvocationCache',
}); });

View File

@ -1,4 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
@ -9,7 +11,7 @@ export const useClearQueue = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useClearQueueMutation({ const [trigger, { isLoading }] = useClearQueueMutation({
fixedCacheKey: 'clearQueue', fixedCacheKey: 'clearQueue',
}); });

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -7,7 +8,7 @@ import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } f
export const useDisableInvocationCache = () => { export const useDisableInvocationCache = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useDisableInvocationCacheMutation({ const [trigger, { isLoading }] = useDisableInvocationCacheMutation({
fixedCacheKey: 'disableInvocationCache', fixedCacheKey: 'disableInvocationCache',
}); });

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -7,7 +8,7 @@ import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } fr
export const useEnableInvocationCache = () => { export const useEnableInvocationCache = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useEnableInvocationCacheMutation({ const [trigger, { isLoading }] = useEnableInvocationCacheMutation({
fixedCacheKey: 'enableInvocationCache', fixedCacheKey: 'enableInvocationCache',
}); });

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -6,7 +7,7 @@ import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/
export const usePauseProcessor = () => { export const usePauseProcessor = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const [trigger, { isLoading }] = usePauseProcessorMutation({ const [trigger, { isLoading }] = usePauseProcessorMutation({
fixedCacheKey: 'pauseProcessor', fixedCacheKey: 'pauseProcessor',

View File

@ -1,4 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
@ -8,7 +10,7 @@ import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endp
export const usePruneQueue = () => { export const usePruneQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = usePruneQueueMutation({ const [trigger, { isLoading }] = usePruneQueueMutation({
fixedCacheKey: 'pruneQueue', fixedCacheKey: 'pruneQueue',
}); });

View File

@ -1,11 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue'; import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue';
export const useResumeProcessor = () => { export const useResumeProcessor = () => {
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const { t } = useTranslation(); const { t } = useTranslation();
const [trigger, { isLoading }] = useResumeProcessorMutation({ const [trigger, { isLoading }] = useResumeProcessorMutation({

View File

@ -1,28 +1,28 @@
import { Progress } from '@invoke-ai/ui-library'; import { Progress } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit'; import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { $isConnected } from 'app/hooks/useSocketIO';
import { selectSystemSlice } from 'features/system/store/systemSlice'; import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $lastProgressEvent } from 'services/events/setEventListeners';
const selectProgressValue = createSelector(
selectSystemSlice,
(system) => (system.denoiseProgress?.percentage ?? 0) * 100
);
const ProgressBar = () => { const ProgressBar = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress)); const lastProgressEvent = useStore($lastProgressEvent);
const value = useAppSelector(selectProgressValue); const value = useMemo(() => {
if (!lastProgressEvent) {
return 0;
}
return (lastProgressEvent.percentage ?? 0) * 100;
}, [lastProgressEvent]);
return ( return (
<Progress <Progress
value={value} value={value}
aria-label={t('accessibility.invokeProgressBar')} aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps} isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !lastProgressEvent}
h={2} h={2}
w="full" w="full"
colorScheme="invokeBlue" colorScheme="invokeBlue"

View File

@ -1,11 +1,12 @@
import { Icon, Tooltip } from '@invoke-ai/ui-library'; import { Icon, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks'; import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi'; import { PiWarningBold } from 'react-icons/pi';
const StatusIndicator = () => { const StatusIndicator = () => {
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useStore($isConnected);
const { t } = useTranslation(); const { t } = useTranslation();
if (!isConnected) { if (!isConnected) {

View File

@ -1,136 +0,0 @@
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppDispatch } from 'app/store/store';
import { toast } from 'features/toast/toast';
import {
socketBatchEnqueued,
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
socketConnected,
socketDisconnected,
socketDownloadCancelled,
socketDownloadComplete,
socketDownloadError,
socketDownloadProgress,
socketDownloadStarted,
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationStarted,
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallError,
socketModelInstallStarted,
socketModelLoadComplete,
socketModelLoadStarted,
socketQueueCleared,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
dispatch: AppDispatch;
};
export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => {
socket.on('connect', () => {
dispatch(socketConnected());
const queue_id = $queueId.get();
socket.emit('subscribe_queue', { queue_id });
if (!$baseUrl.get()) {
const bulk_download_id = $bulkDownloadId.get();
socket.emit('subscribe_bulk_download', { bulk_download_id });
}
});
socket.on('connect_error', (error) => {
if (error && error.message) {
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
if (data === 'ERR_UNAUTHENTICATED') {
toast({
id: `connect-error-${error.message}`,
title: error.message,
status: 'error',
duration: 10000,
});
}
}
});
socket.on('disconnect', () => {
dispatch(socketDisconnected());
});
socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data }));
});
socket.on('invocation_denoise_progress', (data) => {
dispatch(socketGeneratorProgress({ data }));
});
socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data }));
});
socket.on('invocation_complete', (data) => {
dispatch(socketInvocationComplete({ data }));
});
socket.on('model_load_started', (data) => {
dispatch(socketModelLoadStarted({ data }));
});
socket.on('model_load_complete', (data) => {
dispatch(socketModelLoadComplete({ data }));
});
socket.on('download_started', (data) => {
dispatch(socketDownloadStarted({ data }));
});
socket.on('download_progress', (data) => {
dispatch(socketDownloadProgress({ data }));
});
socket.on('download_complete', (data) => {
dispatch(socketDownloadComplete({ data }));
});
socket.on('download_cancelled', (data) => {
dispatch(socketDownloadCancelled({ data }));
});
socket.on('download_error', (data) => {
dispatch(socketDownloadError({ data }));
});
socket.on('model_install_started', (data) => {
dispatch(socketModelInstallStarted({ data }));
});
socket.on('model_install_download_progress', (data) => {
dispatch(socketModelInstallDownloadProgress({ data }));
});
socket.on('model_install_downloads_complete', (data) => {
dispatch(socketModelInstallDownloadsComplete({ data }));
});
socket.on('model_install_complete', (data) => {
dispatch(socketModelInstallComplete({ data }));
});
socket.on('model_install_error', (data) => {
dispatch(socketModelInstallError({ data }));
});
socket.on('model_install_cancelled', (data) => {
dispatch(socketModelInstallCancelled({ data }));
});
socket.on('queue_item_status_changed', (data) => {
dispatch(socketQueueItemStatusChanged({ data }));
});
socket.on('queue_cleared', (data) => {
dispatch(socketQueueCleared({ data }));
});
socket.on('batch_enqueued', (data) => {
dispatch(socketBatchEnqueued({ data }));
});
socket.on('bulk_download_started', (data) => {
dispatch(socketBulkDownloadStarted({ data }));
});
socket.on('bulk_download_complete', (data) => {
dispatch(socketBulkDownloadComplete({ data }));
});
socket.on('bulk_download_error', (data) => {
dispatch(socketBulkDownloadError({ data }));
});
};

View File

@ -0,0 +1,621 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice';
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { atom, computed } from 'nanostores';
import { api, LIST_TAG } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketConnected } from 'services/events/actions';
import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
const log = logger('socketio');
type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
dispatch: AppDispatch;
getState: () => RootState;
setIsConnected: (isConnected: boolean) => void;
};
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
const nodeTypeDenylist = ['load_image', 'image'];
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const $lastCanvasProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
const cancellations = new Set<string>();
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
socket.on('connect', () => {
log.debug('Connected');
setIsConnected(true);
dispatch(socketConnected());
const queue_id = $queueId.get();
socket.emit('subscribe_queue', { queue_id });
if (!$baseUrl.get()) {
const bulk_download_id = $bulkDownloadId.get();
socket.emit('subscribe_bulk_download', { bulk_download_id });
}
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
cancellations.clear();
});
socket.on('connect_error', (error) => {
log.debug('Connect error');
setIsConnected(false);
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
if (error && error.message) {
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
if (data === 'ERR_UNAUTHENTICATED') {
toast({
id: `connect-error-${error.message}`,
title: error.message,
status: 'error',
duration: 10000,
});
}
}
cancellations.clear();
});
socket.on('disconnect', () => {
log.debug('Disconnected');
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
setIsConnected(false);
cancellations.clear();
});
socket.on('invocation_started', (data) => {
const { invocation_source_id, invocation } = data;
log.debug({ data } as SerializableObject, `Invocation started (${invocation.type}, ${invocation_source_id})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
cancellations.clear();
});
socket.on('invocation_denoise_progress', (data) => {
const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage, session_id } =
data;
if (cancellations.has(session_id)) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
// progress update after the session has been cancelled.
return;
}
log.trace(
{ data } as SerializableObject,
`Denoise ${Math.round(percentage * 100)}% (${invocation.type}, ${invocation_source_id})`
);
$lastProgressEvent.set(data);
if (origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
}
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(data);
}
});
socket.on('invocation_error', (data) => {
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data;
log.error({ data } as SerializableObject, `Invocation error (${invocation.type}, ${invocation_source_id})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.progress = null;
nes.progressImage = null;
nes.error = {
error_type,
error_message,
error_traceback,
};
upsertExecutionState(nes.nodeId, nes);
}
});
socket.on('invocation_complete', async (data) => {
log.debug(
{ data } as SerializableObject,
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
);
const { result, invocation_source_id } = data;
if (data.origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
}
// This complete event has an associated image output
if (
(data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') &&
!nodeTypeDenylist.includes(data.invocation.type)
) {
const { image_name } = data.result.image;
const { gallery, canvasV2 } = getState();
// This populates the `getImageDTO` cache
const imageDTORequest = dispatch(
imagesApi.endpoints.getImageDTO.initiate(image_name, {
forceRefetch: true,
})
);
const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe();
// handle tab-specific logic
if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') {
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
const { offset_x, offset_y } = data.result;
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } }));
}
} else if (data.result.type === 'image_output') {
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
}
}
}
if (!imageDTO.is_intermediate) {
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { shouldAutoSwitch } = gallery;
// If auto-switch is enabled, select the new image
if (shouldAutoSwitch) {
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
if (gallery.galleryView !== 'images') {
dispatch(galleryViewChanged('images'));
}
if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) {
dispatch(
boardIdSelected({
boardId: imageDTO.board_id,
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({
boardId: 'none',
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(imageSelected(imageDTO));
}
}
}
$lastProgressEvent.set(null);
});
socket.on('model_load_started', (data) => {
const { config, submodel_type } = data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug({ data }, message);
});
socket.on('model_load_complete', (data) => {
const { config, submodel_type } = data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug({ data }, message);
});
socket.on('download_started', (data) => {
log.debug({ data }, 'Download started');
});
socket.on('download_progress', (data) => {
log.trace({ data }, 'Download progress');
});
socket.on('download_complete', (data) => {
log.debug({ data }, 'Download complete');
});
socket.on('download_cancelled', (data) => {
log.warn({ data }, 'Download cancelled');
});
socket.on('download_error', (data) => {
log.error({ data }, 'Download error');
});
socket.on('model_install_started', (data) => {
log.debug({ data }, 'Model install started');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
});
socket.on('model_install_download_started', (data) => {
log.debug({ data }, 'Model install download started');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
});
socket.on('model_install_download_progress', (data) => {
log.trace({ data }, 'Model install download progress');
const { bytes, total_bytes, id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
modelImport.total_bytes = total_bytes;
modelImport.status = 'downloading';
}
return draft;
})
);
}
});
socket.on('model_install_downloads_complete', (data) => {
log.debug({ data }, 'Model install downloads complete');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
});
socket.on('model_install_complete', (data) => {
log.debug({ data }, 'Model install complete');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
}
return draft;
})
);
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
});
socket.on('model_install_error', (data) => {
log.error({ data }, 'Model install error');
const { id, error, error_type } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';
modelImport.error_reason = error_type;
modelImport.error = error;
}
return draft;
})
);
}
});
socket.on('model_install_cancelled', (data) => {
log.warn({ data }, 'Model install cancelled');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'cancelled';
}
return draft;
})
);
}
});
socket.on('queue_item_status_changed', (data) => {
// we've got new status for the queue item, batch and queue
const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
origin,
} = data;
log.debug({ data }, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
});
})
);
// Update the queue status (we do not get the processor status here)
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
Object.assign(draft.queue, queue_status);
})
);
// Update the batch status
dispatch(queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status));
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
);
if (status === 'in_progress') {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
cancellations.add(session_id);
} else if (status === 'canceled') {
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
cancellations.add(session_id);
} else if (status === 'completed') {
$lastProgressEvent.set(null);
cancellations.add(session_id);
}
});
socket.on('queue_cleared', (data) => {
log.debug({ data }, 'Queue cleared');
});
socket.on('batch_enqueued', (data) => {
log.debug({ data }, 'Batch enqueued');
});
socket.on('bulk_download_started', (data) => {
log.debug({ data }, 'Bulk gallery download preparation started');
});
socket.on('bulk_download_complete', (data) => {
log.debug({ data }, 'Bulk gallery download ready');
const { bulk_download_item_name } = data;
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
description: (
<ExternalLink
label={t('gallery.clickToDownload', 'Click here to download')}
href={url}
download={bulk_download_item_name}
/>
),
duration: null,
});
});
socket.on('bulk_download_error', (data) => {
log.error({ data }, 'Bulk gallery download error');
const { bulk_download_item_name, error } = data;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: error,
duration: null,
});
});
};