feat(ui): move socket event handling out of redux

Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware.

All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow.
This commit is contained in:
psychedelicious 2024-08-17 15:30:55 +10:00
parent 29ac1b5e01
commit b630dbdf20
31 changed files with 809 additions and 301 deletions

View File

@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
@ -28,13 +28,15 @@ export const getSocket = () => {
return socket;
};
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);
/**
* Initializes the socket.io connection and sets up event listeners.
*/
export const useSocketIO = () => {
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
@ -72,7 +74,7 @@ export const useSocketIO = () => {
const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ dispatch, socket });
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
socket.connect();
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@ -94,5 +96,5 @@ export const useSocketIO = () => {
socket.disconnect();
$isSocketInitialized.set(false);
};
}, [dispatch, socketOptions, socketUrl]);
}, [dispatch, getState, socketOptions, socketUrl]);
};

View File

@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
$lastProgressEvent,
rasterLayerAdded,
sessionStagingAreaImageAccepted,
sessionStagingAreaReset,
@ -11,6 +10,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';
export const addStagingListeners = (startAppListening: AppStartListening) => {
@ -29,7 +29,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
const { canceled } = await req.unwrap();
req.reset();
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);

View File

@ -2,10 +2,10 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
const log = logger('socketio');
@ -27,7 +27,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
}
if (origin === 'canvas') {
$lastProgressEvent.set(action.payload.data);
$lastCanvasProgressEvent.set(action.payload.data);
}
},
});

View File

@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
@ -9,6 +8,7 @@ import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
const log = logger('socketio');
@ -17,13 +17,13 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
startAppListening({
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
effect: () => {
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
},
});
startAppListening({
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
// we've got new status for the queue item, batch and queue
const {
item_id,
@ -103,7 +103,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
if (origin === 'canvas') {
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
}
toast({
@ -122,7 +122,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
),
});
} else if (status === 'canceled' && origin === 'canvas') {
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
}
},
});

View File

@ -1,4 +1,5 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
@ -25,7 +26,7 @@ const LAYER_TYPE_TO_TKEY = {
control_layer: 'controlLayers.globalControlAdapter',
} as const;
const createSelector = (templates: Templates) =>
const createSelector = (templates: Templates, isConnected: boolean) =>
createMemoizedSelector(
[
selectSystemSlice,
@ -41,8 +42,6 @@ const createSelector = (templates: Templates) =>
const { bbox } = canvasV2;
const { model, positivePrompt } = canvasV2.params;
const { isConnected } = system;
const reasons: { prefix?: string; content: string }[] = [];
// Cannot generate if not connected
@ -240,7 +239,8 @@ const createSelector = (templates: Templates) =>
export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const selector = useMemo(() => createSelector(templates), [templates]);
const isConnected = useStore($isConnected)
const selector = useMemo(() => createSelector(templates, isConnected), [templates, isConnected]);
const value = useAppSelector(selector);
return value;
};

View File

@ -3,3 +3,8 @@ type JSONValue = string | number | boolean | null | JSONValue[] | { [key: string
export interface JSONObject {
[k: string]: JSONValue;
}
type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject;
export type SerializableObject = {
[k: string | number]: SerializableValue;
};

View File

@ -1,5 +1,7 @@
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
@ -22,14 +24,17 @@ type Props = {
postUploadAction: PostUploadAction;
};
export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
export const IPAdapterImagePreview = memo(
({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier();
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
@ -89,12 +94,15 @@ export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId,
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
);
});
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';

View File

@ -48,7 +48,7 @@ export class CanvasProgressImage {
image: null,
};
this.manager.stateApi.$lastProgressEvent.listen((event) => {
this.manager.stateApi.$lastCanvasProgressEvent.listen((event) => {
this.lastProgressEvent = event;
this.render();
});

View File

@ -76,7 +76,7 @@ export class CanvasStagingArea {
if (!this.image.isLoading && !this.image.isError) {
await this.image.updateImageSource(imageDTO.image_name);
this.manager.stateApi.$lastProgressEvent.set(null);
this.manager.stateApi.$lastCanvasProgressEvent.set(null);
}
this.image.konva.group.visible(shouldShowStagedImage);
} else {

View File

@ -11,7 +11,6 @@ import {
$lastAddedPoint,
$lastCursorPos,
$lastMouseDownPos,
$lastProgressEvent,
$shouldShowStagedImage,
$spaceKey,
$stageAttrs,
@ -51,6 +50,7 @@ import type {
import { RGBA_RED } from 'features/controlLayers/store/types';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
type EntityStateAndAdapter =
| {
@ -263,7 +263,7 @@ export class CanvasStateApi {
$lastAddedPoint = $lastAddedPoint;
$lastMouseDownPos = $lastMouseDownPos;
$lastCursorPos = $lastCursorPos;
$lastProgressEvent = $lastProgressEvent;
$lastCanvasProgressEvent = $lastCanvasProgressEvent;
$spaceKey = $spaceKey;
$altKey = $alt;
$ctrlKey = $ctrl;

View File

@ -20,7 +20,6 @@ import { initialAspectRatioState } from 'features/parameters/components/Document
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { isEqual, pick } from 'lodash-es';
import { atom } from 'nanostores';
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe';
import type {
@ -622,7 +621,6 @@ export const $stageAttrs = atom<StageAttrs>({
scale: 0,
});
export const $shouldShowStagedImage = atom(true);
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const $isDrawing = atom<boolean>(false);
export const $isMouseDown = atom<boolean>(false);
export const $lastAddedPoint = atom<Coordinate | null>(null);

View File

@ -1,5 +1,7 @@
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@ -12,7 +14,7 @@ type DeleteImageButtonProps = Omit<IconButtonProps, 'aria-label'> & {
export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
const { onClick, isDisabled } = props;
const { t } = useTranslation();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length);
const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`;

View File

@ -1,7 +1,7 @@
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
@ -10,17 +10,15 @@ import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMe
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToImg2Img } from 'features/gallery/store/actions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
import { $templates } from 'features/nodes/store/nodesSlice';
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@ -33,23 +31,17 @@ import {
PiRulerBold,
} from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const selectShouldDisableToolbarButtons = createSelector(
selectSystemSlice,
selectGallerySlice,
selectLastSelectedImage,
(system, gallery, lastSelectedImage) => {
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
return hasProgressImage || !lastSelectedImage;
}
);
import { $progressImage } from 'services/events/setEventListeners';
const CurrentImageButtons = () => {
const dispatch = useAppDispatch();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const progressImage = useStore($progressImage);
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
const shouldDisableToolbarButtons = useMemo(() => {
return Boolean(progressImage) || !lastSelectedImage;
}, [lastSelectedImage, progressImage]);
const templates = useStore($templates);
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();

View File

@ -1,4 +1,5 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
@ -14,6 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { $hasProgress } from 'services/events/setEventListeners';
import ProgressImage from './ProgressImage';
@ -26,7 +28,7 @@ const CurrentImagePreview = () => {
const { t } = useTranslation();
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
const imageName = useAppSelector(selectLastSelectedImageName);
const hasDenoiseProgress = useAppSelector((s) => Boolean(s.system.denoiseProgress));
const hasDenoiseProgress = useStore($hasProgress);
const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer);
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);

View File

@ -1,10 +1,12 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Image } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react';
import { $progressImage } from 'services/events/setEventListeners';
const CurrentImagePreview = () => {
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image);
const progressImage = useStore($progressImage);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const sx = useMemo<SystemStyleObject>(
@ -14,15 +16,15 @@ const CurrentImagePreview = () => {
[shouldAntialiasProgressImage]
);
if (!progress_image) {
if (!progressImage) {
return null;
}
return (
<Image
src={progress_image.dataURL}
width={progress_image.width}
height={progress_image.height}
src={progressImage.dataURL}
width={progressImage.width}
height={progressImage.height}
draggable={false}
data-testid="progress-image"
objectFit="contain"

View File

@ -1,36 +1,33 @@
import { Flex, Image, Text } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import type { AnimationProps } from 'framer-motion';
import { motion } from 'framer-motion';
import type { CSSProperties, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { NodeProps } from 'reactflow';
const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (system, gallery) => {
const imageDTO = gallery.selection[gallery.selection.length - 1];
return {
imageDTO,
progressImage: system.denoiseProgress?.progress_image,
};
});
import { $lastProgressEvent } from 'services/events/setEventListeners';
const CurrentImageNode = (props: NodeProps) => {
const { progressImage, imageDTO } = useAppSelector(selector);
const imageDTO = useAppSelector((s) => s.gallery.selection[s.gallery.selection.length - 1]);
const lastProgressEvent = useStore($lastProgressEvent);
if (progressImage) {
if (lastProgressEvent?.progress_image) {
return (
<Wrapper nodeProps={props}>
<Image src={progressImage.dataURL} w="full" h="full" objectFit="contain" borderRadius="base" />
<Image
src={lastProgressEvent?.progress_image.dataURL}
w="full"
h="full"
objectFit="contain"
borderRadius="base"
/>
</Wrapper>
);
}

View File

@ -1,6 +1,8 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
@ -17,7 +19,7 @@ import type { FieldComponentProps } from './types';
const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const { currentData: imageDTO, isError } = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
const handleReset = useCallback(() => {

View File

@ -1,11 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue';
export const useCancelBatch = (batch_id: string) => {
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const { isCanceled } = useGetBatchStatusQuery(
{ batch_id },
{

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@ -6,7 +7,7 @@ import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
export const useCancelCurrentQueueItem = () => {
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const { data: queueStatus } = useGetQueueStatusQuery();
const [trigger, { isLoading }] = useCancelQueueItemMutation();
const { t } = useTranslation();

View File

@ -1,11 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
export const useCancelQueueItem = (item_id: number) => {
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useCancelQueueItemMutation();
const { t } = useTranslation();
const cancelQueueItem = useCallback(async () => {

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -7,7 +8,7 @@ import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } fro
export const useClearInvocationCache = () => {
const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useClearInvocationCacheMutation({
fixedCacheKey: 'clearInvocationCache',
});

View File

@ -1,4 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
@ -9,7 +11,7 @@ export const useClearQueue = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useClearQueueMutation({
fixedCacheKey: 'clearQueue',
});

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -7,7 +8,7 @@ import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } f
export const useDisableInvocationCache = () => {
const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useDisableInvocationCacheMutation({
fixedCacheKey: 'disableInvocationCache',
});

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -7,7 +8,7 @@ import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } fr
export const useEnableInvocationCache = () => {
const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useEnableInvocationCacheMutation({
fixedCacheKey: 'enableInvocationCache',
});

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -6,7 +7,7 @@ import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/
export const usePauseProcessor = () => {
const { t } = useTranslation();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const { data: queueStatus } = useGetQueueStatusQuery();
const [trigger, { isLoading }] = usePauseProcessorMutation({
fixedCacheKey: 'pauseProcessor',

View File

@ -1,4 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
@ -8,7 +10,7 @@ import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endp
export const usePruneQueue = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = usePruneQueueMutation({
fixedCacheKey: 'pruneQueue',
});

View File

@ -1,11 +1,12 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue';
export const useResumeProcessor = () => {
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const { data: queueStatus } = useGetQueueStatusQuery();
const { t } = useTranslation();
const [trigger, { isLoading }] = useResumeProcessorMutation({

View File

@ -1,28 +1,28 @@
import { Progress } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
const selectProgressValue = createSelector(
selectSystemSlice,
(system) => (system.denoiseProgress?.percentage ?? 0) * 100
);
import { $lastProgressEvent } from 'services/events/setEventListeners';
const ProgressBar = () => {
const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress));
const value = useAppSelector(selectProgressValue);
const isConnected = useStore($isConnected);
const lastProgressEvent = useStore($lastProgressEvent);
const value = useMemo(() => {
if (!lastProgressEvent) {
return 0;
}
return (lastProgressEvent.percentage ?? 0) * 100;
}, [lastProgressEvent]);
return (
<Progress
value={value}
aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !lastProgressEvent}
h={2}
w="full"
colorScheme="invokeBlue"

View File

@ -1,11 +1,12 @@
import { Icon, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
const StatusIndicator = () => {
const isConnected = useAppSelector((s) => s.system.isConnected);
const isConnected = useStore($isConnected);
const { t } = useTranslation();
if (!isConnected) {

View File

@ -1,136 +0,0 @@
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppDispatch } from 'app/store/store';
import { toast } from 'features/toast/toast';
import {
socketBatchEnqueued,
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
socketConnected,
socketDisconnected,
socketDownloadCancelled,
socketDownloadComplete,
socketDownloadError,
socketDownloadProgress,
socketDownloadStarted,
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationStarted,
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallError,
socketModelInstallStarted,
socketModelLoadComplete,
socketModelLoadStarted,
socketQueueCleared,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
dispatch: AppDispatch;
};
export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => {
socket.on('connect', () => {
dispatch(socketConnected());
const queue_id = $queueId.get();
socket.emit('subscribe_queue', { queue_id });
if (!$baseUrl.get()) {
const bulk_download_id = $bulkDownloadId.get();
socket.emit('subscribe_bulk_download', { bulk_download_id });
}
});
socket.on('connect_error', (error) => {
if (error && error.message) {
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
if (data === 'ERR_UNAUTHENTICATED') {
toast({
id: `connect-error-${error.message}`,
title: error.message,
status: 'error',
duration: 10000,
});
}
}
});
socket.on('disconnect', () => {
dispatch(socketDisconnected());
});
socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data }));
});
socket.on('invocation_denoise_progress', (data) => {
dispatch(socketGeneratorProgress({ data }));
});
socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data }));
});
socket.on('invocation_complete', (data) => {
dispatch(socketInvocationComplete({ data }));
});
socket.on('model_load_started', (data) => {
dispatch(socketModelLoadStarted({ data }));
});
socket.on('model_load_complete', (data) => {
dispatch(socketModelLoadComplete({ data }));
});
socket.on('download_started', (data) => {
dispatch(socketDownloadStarted({ data }));
});
socket.on('download_progress', (data) => {
dispatch(socketDownloadProgress({ data }));
});
socket.on('download_complete', (data) => {
dispatch(socketDownloadComplete({ data }));
});
socket.on('download_cancelled', (data) => {
dispatch(socketDownloadCancelled({ data }));
});
socket.on('download_error', (data) => {
dispatch(socketDownloadError({ data }));
});
socket.on('model_install_started', (data) => {
dispatch(socketModelInstallStarted({ data }));
});
socket.on('model_install_download_progress', (data) => {
dispatch(socketModelInstallDownloadProgress({ data }));
});
socket.on('model_install_downloads_complete', (data) => {
dispatch(socketModelInstallDownloadsComplete({ data }));
});
socket.on('model_install_complete', (data) => {
dispatch(socketModelInstallComplete({ data }));
});
socket.on('model_install_error', (data) => {
dispatch(socketModelInstallError({ data }));
});
socket.on('model_install_cancelled', (data) => {
dispatch(socketModelInstallCancelled({ data }));
});
socket.on('queue_item_status_changed', (data) => {
dispatch(socketQueueItemStatusChanged({ data }));
});
socket.on('queue_cleared', (data) => {
dispatch(socketQueueCleared({ data }));
});
socket.on('batch_enqueued', (data) => {
dispatch(socketBatchEnqueued({ data }));
});
socket.on('bulk_download_started', (data) => {
dispatch(socketBulkDownloadStarted({ data }));
});
socket.on('bulk_download_complete', (data) => {
dispatch(socketBulkDownloadComplete({ data }));
});
socket.on('bulk_download_error', (data) => {
dispatch(socketBulkDownloadError({ data }));
});
};

View File

@ -0,0 +1,621 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice';
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { atom, computed } from 'nanostores';
import { api, LIST_TAG } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketConnected } from 'services/events/actions';
import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
const log = logger('socketio');
type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
dispatch: AppDispatch;
getState: () => RootState;
setIsConnected: (isConnected: boolean) => void;
};
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
const nodeTypeDenylist = ['load_image', 'image'];
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const $lastCanvasProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
const cancellations = new Set<string>();
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
socket.on('connect', () => {
log.debug('Connected');
setIsConnected(true);
dispatch(socketConnected());
const queue_id = $queueId.get();
socket.emit('subscribe_queue', { queue_id });
if (!$baseUrl.get()) {
const bulk_download_id = $bulkDownloadId.get();
socket.emit('subscribe_bulk_download', { bulk_download_id });
}
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
cancellations.clear();
});
socket.on('connect_error', (error) => {
log.debug('Connect error');
setIsConnected(false);
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
if (error && error.message) {
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
if (data === 'ERR_UNAUTHENTICATED') {
toast({
id: `connect-error-${error.message}`,
title: error.message,
status: 'error',
duration: 10000,
});
}
}
cancellations.clear();
});
socket.on('disconnect', () => {
log.debug('Disconnected');
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
setIsConnected(false);
cancellations.clear();
});
socket.on('invocation_started', (data) => {
const { invocation_source_id, invocation } = data;
log.debug({ data } as SerializableObject, `Invocation started (${invocation.type}, ${invocation_source_id})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
cancellations.clear();
});
socket.on('invocation_denoise_progress', (data) => {
const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage, session_id } =
data;
if (cancellations.has(session_id)) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
// progress update after the session has been cancelled.
return;
}
log.trace(
{ data } as SerializableObject,
`Denoise ${Math.round(percentage * 100)}% (${invocation.type}, ${invocation_source_id})`
);
$lastProgressEvent.set(data);
if (origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
}
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(data);
}
});
socket.on('invocation_error', (data) => {
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data;
log.error({ data } as SerializableObject, `Invocation error (${invocation.type}, ${invocation_source_id})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.progress = null;
nes.progressImage = null;
nes.error = {
error_type,
error_message,
error_traceback,
};
upsertExecutionState(nes.nodeId, nes);
}
});
socket.on('invocation_complete', async (data) => {
log.debug(
{ data } as SerializableObject,
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
);
const { result, invocation_source_id } = data;
if (data.origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
}
// This complete event has an associated image output
if (
(data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') &&
!nodeTypeDenylist.includes(data.invocation.type)
) {
const { image_name } = data.result.image;
const { gallery, canvasV2 } = getState();
// This populates the `getImageDTO` cache
const imageDTORequest = dispatch(
imagesApi.endpoints.getImageDTO.initiate(image_name, {
forceRefetch: true,
})
);
const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe();
// handle tab-specific logic
if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') {
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
const { offset_x, offset_y } = data.result;
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } }));
}
} else if (data.result.type === 'image_output') {
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
}
}
}
if (!imageDTO.is_intermediate) {
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { shouldAutoSwitch } = gallery;
// If auto-switch is enabled, select the new image
if (shouldAutoSwitch) {
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
if (gallery.galleryView !== 'images') {
dispatch(galleryViewChanged('images'));
}
if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) {
dispatch(
boardIdSelected({
boardId: imageDTO.board_id,
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({
boardId: 'none',
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(imageSelected(imageDTO));
}
}
}
$lastProgressEvent.set(null);
});
socket.on('model_load_started', (data) => {
const { config, submodel_type } = data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug({ data }, message);
});
socket.on('model_load_complete', (data) => {
const { config, submodel_type } = data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug({ data }, message);
});
socket.on('download_started', (data) => {
log.debug({ data }, 'Download started');
});
socket.on('download_progress', (data) => {
log.trace({ data }, 'Download progress');
});
socket.on('download_complete', (data) => {
log.debug({ data }, 'Download complete');
});
socket.on('download_cancelled', (data) => {
log.warn({ data }, 'Download cancelled');
});
socket.on('download_error', (data) => {
log.error({ data }, 'Download error');
});
socket.on('model_install_started', (data) => {
log.debug({ data }, 'Model install started');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
});
socket.on('model_install_download_started', (data) => {
log.debug({ data }, 'Model install download started');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
});
socket.on('model_install_download_progress', (data) => {
log.trace({ data }, 'Model install download progress');
const { bytes, total_bytes, id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
modelImport.total_bytes = total_bytes;
modelImport.status = 'downloading';
}
return draft;
})
);
}
});
socket.on('model_install_downloads_complete', (data) => {
log.debug({ data }, 'Model install downloads complete');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
});
socket.on('model_install_complete', (data) => {
log.debug({ data }, 'Model install complete');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
}
return draft;
})
);
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
});
socket.on('model_install_error', (data) => {
log.error({ data }, 'Model install error');
const { id, error, error_type } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';
modelImport.error_reason = error_type;
modelImport.error = error;
}
return draft;
})
);
}
});
socket.on('model_install_cancelled', (data) => {
log.warn({ data }, 'Model install cancelled');
const { id } = data;
const installs = selectModelInstalls(getState()).data;
if (!installs?.find((install) => install.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'cancelled';
}
return draft;
})
);
}
});
socket.on('queue_item_status_changed', (data) => {
// we've got new status for the queue item, batch and queue
const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
origin,
} = data;
log.debug({ data }, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
});
})
);
// Update the queue status (we do not get the processor status here)
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
Object.assign(draft.queue, queue_status);
})
);
// Update the batch status
dispatch(queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status));
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
);
if (status === 'in_progress') {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
cancellations.add(session_id);
} else if (status === 'canceled') {
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
cancellations.add(session_id);
} else if (status === 'completed') {
$lastProgressEvent.set(null);
cancellations.add(session_id);
}
});
socket.on('queue_cleared', (data) => {
log.debug({ data }, 'Queue cleared');
});
socket.on('batch_enqueued', (data) => {
log.debug({ data }, 'Batch enqueued');
});
socket.on('bulk_download_started', (data) => {
log.debug({ data }, 'Bulk gallery download preparation started');
});
socket.on('bulk_download_complete', (data) => {
log.debug({ data }, 'Bulk gallery download ready');
const { bulk_download_item_name } = data;
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
description: (
<ExternalLink
label={t('gallery.clickToDownload', 'Click here to download')}
href={url}
download={bulk_download_item_name}
/>
),
duration: null,
});
});
socket.on('bulk_download_error', (data) => {
log.error({ data }, 'Bulk gallery download error');
const { bulk_download_item_name, error } = data;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: error,
duration: null,
});
});
};