mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): move socket event handling out of redux
Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware. All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow.
This commit is contained in:
parent
29ac1b5e01
commit
b630dbdf20
@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $isDebugging } from 'app/store/nanostores/isDebugging';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import type { MapStore } from 'nanostores';
|
||||
import { atom, map } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
@ -28,13 +28,15 @@ export const getSocket = () => {
|
||||
return socket;
|
||||
};
|
||||
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
|
||||
|
||||
const $isSocketInitialized = atom<boolean>(false);
|
||||
export const $isConnected = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* Initializes the socket.io connection and sets up event listeners.
|
||||
*/
|
||||
export const useSocketIO = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { dispatch, getState } = useAppStore();
|
||||
const baseUrl = useStore($baseUrl);
|
||||
const authToken = useStore($authToken);
|
||||
const addlSocketOptions = useStore($socketOptions);
|
||||
@ -72,7 +74,7 @@ export const useSocketIO = () => {
|
||||
|
||||
const socket: AppSocket = io(socketUrl, socketOptions);
|
||||
$socket.set(socket);
|
||||
setEventListeners({ dispatch, socket });
|
||||
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
|
||||
socket.connect();
|
||||
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
@ -94,5 +96,5 @@ export const useSocketIO = () => {
|
||||
socket.disconnect();
|
||||
$isSocketInitialized.set(false);
|
||||
};
|
||||
}, [dispatch, socketOptions, socketUrl]);
|
||||
}, [dispatch, getState, socketOptions, socketUrl]);
|
||||
};
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import {
|
||||
$lastProgressEvent,
|
||||
rasterLayerAdded,
|
||||
sessionStagingAreaImageAccepted,
|
||||
sessionStagingAreaReset,
|
||||
@ -11,6 +10,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
@ -29,7 +29,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
const { canceled } = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
|
||||
if (canceled > 0) {
|
||||
log.debug(`Canceled ${canceled} canvas batches`);
|
||||
|
@ -2,10 +2,10 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
@ -27,7 +27,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
}
|
||||
|
||||
if (origin === 'canvas') {
|
||||
$lastProgressEvent.set(action.payload.data);
|
||||
$lastCanvasProgressEvent.set(action.payload.data);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
|
||||
@ -9,6 +8,7 @@ import { toast } from 'features/toast/toast';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
@ -17,13 +17,13 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
|
||||
effect: () => {
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
// we've got new status for the queue item, batch and queue
|
||||
const {
|
||||
item_id,
|
||||
@ -103,7 +103,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
|
||||
const isLocal = getState().config.isLocal ?? true;
|
||||
const sessionId = session_id;
|
||||
if (origin === 'canvas') {
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
|
||||
toast({
|
||||
@ -122,7 +122,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni
|
||||
),
|
||||
});
|
||||
} else if (status === 'canceled' && origin === 'canvas') {
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
|
||||
@ -25,7 +26,7 @@ const LAYER_TYPE_TO_TKEY = {
|
||||
control_layer: 'controlLayers.globalControlAdapter',
|
||||
} as const;
|
||||
|
||||
const createSelector = (templates: Templates) =>
|
||||
const createSelector = (templates: Templates, isConnected: boolean) =>
|
||||
createMemoizedSelector(
|
||||
[
|
||||
selectSystemSlice,
|
||||
@ -41,8 +42,6 @@ const createSelector = (templates: Templates) =>
|
||||
const { bbox } = canvasV2;
|
||||
const { model, positivePrompt } = canvasV2.params;
|
||||
|
||||
const { isConnected } = system;
|
||||
|
||||
const reasons: { prefix?: string; content: string }[] = [];
|
||||
|
||||
// Cannot generate if not connected
|
||||
@ -240,7 +239,8 @@ const createSelector = (templates: Templates) =>
|
||||
|
||||
export const useIsReadyToEnqueue = () => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(() => createSelector(templates), [templates]);
|
||||
const isConnected = useStore($isConnected)
|
||||
const selector = useMemo(() => createSelector(templates, isConnected), [templates, isConnected]);
|
||||
const value = useAppSelector(selector);
|
||||
return value;
|
||||
};
|
||||
|
@ -3,3 +3,8 @@ type JSONValue = string | number | boolean | null | JSONValue[] | { [key: string
|
||||
export interface JSONObject {
|
||||
[k: string]: JSONValue;
|
||||
}
|
||||
|
||||
type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject;
|
||||
export type SerializableObject = {
|
||||
[k: string | number]: SerializableValue;
|
||||
};
|
||||
|
@ -1,5 +1,7 @@
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
@ -22,14 +24,17 @@ type Props = {
|
||||
postUploadAction: PostUploadAction;
|
||||
};
|
||||
|
||||
export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
|
||||
export const IPAdapterImagePreview = memo(
|
||||
({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const shift = useShiftModifier();
|
||||
|
||||
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.image_name ?? skipToken);
|
||||
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
|
||||
image?.image_name ?? skipToken
|
||||
);
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
onChangeImage(null);
|
||||
}, [onChangeImage]);
|
||||
@ -89,12 +94,15 @@ export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId,
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={<PiRulerBold size={16} />}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
tooltip={
|
||||
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
}
|
||||
);
|
||||
|
||||
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
|
||||
|
@ -48,7 +48,7 @@ export class CanvasProgressImage {
|
||||
image: null,
|
||||
};
|
||||
|
||||
this.manager.stateApi.$lastProgressEvent.listen((event) => {
|
||||
this.manager.stateApi.$lastCanvasProgressEvent.listen((event) => {
|
||||
this.lastProgressEvent = event;
|
||||
this.render();
|
||||
});
|
||||
|
@ -76,7 +76,7 @@ export class CanvasStagingArea {
|
||||
|
||||
if (!this.image.isLoading && !this.image.isError) {
|
||||
await this.image.updateImageSource(imageDTO.image_name);
|
||||
this.manager.stateApi.$lastProgressEvent.set(null);
|
||||
this.manager.stateApi.$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
this.image.konva.group.visible(shouldShowStagedImage);
|
||||
} else {
|
||||
|
@ -11,7 +11,6 @@ import {
|
||||
$lastAddedPoint,
|
||||
$lastCursorPos,
|
||||
$lastMouseDownPos,
|
||||
$lastProgressEvent,
|
||||
$shouldShowStagedImage,
|
||||
$spaceKey,
|
||||
$stageAttrs,
|
||||
@ -51,6 +50,7 @@ import type {
|
||||
import { RGBA_RED } from 'features/controlLayers/store/types';
|
||||
import type { WritableAtom } from 'nanostores';
|
||||
import { atom } from 'nanostores';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
|
||||
type EntityStateAndAdapter =
|
||||
| {
|
||||
@ -263,7 +263,7 @@ export class CanvasStateApi {
|
||||
$lastAddedPoint = $lastAddedPoint;
|
||||
$lastMouseDownPos = $lastMouseDownPos;
|
||||
$lastCursorPos = $lastCursorPos;
|
||||
$lastProgressEvent = $lastProgressEvent;
|
||||
$lastCanvasProgressEvent = $lastCanvasProgressEvent;
|
||||
$spaceKey = $spaceKey;
|
||||
$altKey = $alt;
|
||||
$ctrlKey = $ctrl;
|
||||
|
@ -20,7 +20,6 @@ import { initialAspectRatioState } from 'features/parameters/components/Document
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { isEqual, pick } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import type { InvocationDenoiseProgressEvent } from 'services/events/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type {
|
||||
@ -622,7 +621,6 @@ export const $stageAttrs = atom<StageAttrs>({
|
||||
scale: 0,
|
||||
});
|
||||
export const $shouldShowStagedImage = atom(true);
|
||||
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
|
||||
export const $isDrawing = atom<boolean>(false);
|
||||
export const $isMouseDown = atom<boolean>(false);
|
||||
export const $lastAddedPoint = atom<Coordinate | null>(null);
|
||||
|
@ -1,5 +1,7 @@
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -12,7 +14,7 @@ type DeleteImageButtonProps = Omit<IconButtonProps, 'aria-label'> & {
|
||||
export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
|
||||
const { onClick, isDisabled } = props;
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length);
|
||||
const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`;
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||
@ -10,17 +10,15 @@ import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMe
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
|
||||
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
|
||||
import { size } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@ -33,23 +31,17 @@ import {
|
||||
PiRulerBold,
|
||||
} from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
const selectShouldDisableToolbarButtons = createSelector(
|
||||
selectSystemSlice,
|
||||
selectGallerySlice,
|
||||
selectLastSelectedImage,
|
||||
(system, gallery, lastSelectedImage) => {
|
||||
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
|
||||
return hasProgressImage || !lastSelectedImage;
|
||||
}
|
||||
);
|
||||
import { $progressImage } from 'services/events/setEventListeners';
|
||||
|
||||
const CurrentImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const progressImage = useStore($progressImage);
|
||||
const selection = useAppSelector((s) => s.gallery.selection);
|
||||
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
|
||||
const shouldDisableToolbarButtons = useMemo(() => {
|
||||
return Boolean(progressImage) || !lastSelectedImage;
|
||||
}, [lastSelectedImage, progressImage]);
|
||||
const templates = useStore($templates);
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@ -14,6 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImageBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { $hasProgress } from 'services/events/setEventListeners';
|
||||
|
||||
import ProgressImage from './ProgressImage';
|
||||
|
||||
@ -26,7 +28,7 @@ const CurrentImagePreview = () => {
|
||||
const { t } = useTranslation();
|
||||
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
|
||||
const imageName = useAppSelector(selectLastSelectedImageName);
|
||||
const hasDenoiseProgress = useAppSelector((s) => Boolean(s.system.denoiseProgress));
|
||||
const hasDenoiseProgress = useStore($hasProgress);
|
||||
const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer);
|
||||
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
|
||||
|
@ -1,10 +1,12 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Image } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { $progressImage } from 'services/events/setEventListeners';
|
||||
|
||||
const CurrentImagePreview = () => {
|
||||
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image);
|
||||
const progressImage = useStore($progressImage);
|
||||
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
|
||||
|
||||
const sx = useMemo<SystemStyleObject>(
|
||||
@ -14,15 +16,15 @@ const CurrentImagePreview = () => {
|
||||
[shouldAntialiasProgressImage]
|
||||
);
|
||||
|
||||
if (!progress_image) {
|
||||
if (!progressImage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Image
|
||||
src={progress_image.dataURL}
|
||||
width={progress_image.width}
|
||||
height={progress_image.height}
|
||||
src={progressImage.dataURL}
|
||||
width={progressImage.width}
|
||||
height={progressImage.height}
|
||||
draggable={false}
|
||||
data-testid="progress-image"
|
||||
objectFit="contain"
|
||||
|
@ -1,36 +1,33 @@
|
||||
import { Flex, Image, Text } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import type { AnimationProps } from 'framer-motion';
|
||||
import { motion } from 'framer-motion';
|
||||
import type { CSSProperties, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
|
||||
const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (system, gallery) => {
|
||||
const imageDTO = gallery.selection[gallery.selection.length - 1];
|
||||
|
||||
return {
|
||||
imageDTO,
|
||||
progressImage: system.denoiseProgress?.progress_image,
|
||||
};
|
||||
});
|
||||
import { $lastProgressEvent } from 'services/events/setEventListeners';
|
||||
|
||||
const CurrentImageNode = (props: NodeProps) => {
|
||||
const { progressImage, imageDTO } = useAppSelector(selector);
|
||||
const imageDTO = useAppSelector((s) => s.gallery.selection[s.gallery.selection.length - 1]);
|
||||
const lastProgressEvent = useStore($lastProgressEvent);
|
||||
|
||||
if (progressImage) {
|
||||
if (lastProgressEvent?.progress_image) {
|
||||
return (
|
||||
<Wrapper nodeProps={props}>
|
||||
<Image src={progressImage.dataURL} w="full" h="full" objectFit="contain" borderRadius="base" />
|
||||
<Image
|
||||
src={lastProgressEvent?.progress_image.dataURL}
|
||||
w="full"
|
||||
h="full"
|
||||
objectFit="contain"
|
||||
borderRadius="base"
|
||||
/>
|
||||
</Wrapper>
|
||||
);
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
@ -17,7 +19,7 @@ import type { FieldComponentProps } from './types';
|
||||
const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInstance, ImageFieldInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
export const useCancelBatch = (batch_id: string) => {
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const { isCanceled } = useGetBatchStatusQuery(
|
||||
{ batch_id },
|
||||
{
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@ -6,7 +7,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
export const useCancelCurrentQueueItem = () => {
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const [trigger, { isLoading }] = useCancelQueueItemMutation();
|
||||
const { t } = useTranslation();
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
|
||||
|
||||
export const useCancelQueueItem = (item_id: number) => {
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useCancelQueueItemMutation();
|
||||
const { t } = useTranslation();
|
||||
const cancelQueueItem = useCallback(async () => {
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -7,7 +8,7 @@ import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } fro
|
||||
export const useClearInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useClearInvocationCacheMutation({
|
||||
fixedCacheKey: 'clearInvocationCache',
|
||||
});
|
||||
|
@ -1,4 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@ -9,7 +11,7 @@ export const useClearQueue = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useClearQueueMutation({
|
||||
fixedCacheKey: 'clearQueue',
|
||||
});
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -7,7 +8,7 @@ import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } f
|
||||
export const useDisableInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useDisableInvocationCacheMutation({
|
||||
fixedCacheKey: 'disableInvocationCache',
|
||||
});
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -7,7 +8,7 @@ import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } fr
|
||||
export const useEnableInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useEnableInvocationCacheMutation({
|
||||
fixedCacheKey: 'enableInvocationCache',
|
||||
});
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -6,7 +7,7 @@ import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/
|
||||
|
||||
export const usePauseProcessor = () => {
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const [trigger, { isLoading }] = usePauseProcessorMutation({
|
||||
fixedCacheKey: 'pauseProcessor',
|
||||
|
@ -1,4 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@ -8,7 +10,7 @@ import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endp
|
||||
export const usePruneQueue = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = usePruneQueueMutation({
|
||||
fixedCacheKey: 'pruneQueue',
|
||||
});
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue';
|
||||
|
||||
export const useResumeProcessor = () => {
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const { t } = useTranslation();
|
||||
const [trigger, { isLoading }] = useResumeProcessorMutation({
|
||||
|
@ -1,28 +1,28 @@
|
||||
import { Progress } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { memo } from 'react';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
const selectProgressValue = createSelector(
|
||||
selectSystemSlice,
|
||||
(system) => (system.denoiseProgress?.percentage ?? 0) * 100
|
||||
);
|
||||
import { $lastProgressEvent } from 'services/events/setEventListeners';
|
||||
|
||||
const ProgressBar = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress));
|
||||
const value = useAppSelector(selectProgressValue);
|
||||
const isConnected = useStore($isConnected);
|
||||
const lastProgressEvent = useStore($lastProgressEvent);
|
||||
const value = useMemo(() => {
|
||||
if (!lastProgressEvent) {
|
||||
return 0;
|
||||
}
|
||||
return (lastProgressEvent.percentage ?? 0) * 100;
|
||||
}, [lastProgressEvent]);
|
||||
|
||||
return (
|
||||
<Progress
|
||||
value={value}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
|
||||
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !lastProgressEvent}
|
||||
h={2}
|
||||
w="full"
|
||||
colorScheme="invokeBlue"
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { Icon, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiWarningBold } from 'react-icons/pi';
|
||||
|
||||
const StatusIndicator = () => {
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const isConnected = useStore($isConnected);
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!isConnected) {
|
||||
|
@ -1,136 +0,0 @@
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppDispatch } from 'app/store/store';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import {
|
||||
socketBatchEnqueued,
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadStarted,
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
socketDownloadCancelled,
|
||||
socketDownloadComplete,
|
||||
socketDownloadError,
|
||||
socketDownloadProgress,
|
||||
socketDownloadStarted,
|
||||
socketGeneratorProgress,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationStarted,
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallDownloadsComplete,
|
||||
socketModelInstallError,
|
||||
socketModelInstallStarted,
|
||||
socketModelLoadComplete,
|
||||
socketModelLoadStarted,
|
||||
socketQueueCleared,
|
||||
socketQueueItemStatusChanged,
|
||||
} from 'services/events/actions';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import type { Socket } from 'socket.io-client';
|
||||
|
||||
type SetEventListenersArg = {
|
||||
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||
dispatch: AppDispatch;
|
||||
};
|
||||
|
||||
export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => {
|
||||
socket.on('connect', () => {
|
||||
dispatch(socketConnected());
|
||||
const queue_id = $queueId.get();
|
||||
socket.emit('subscribe_queue', { queue_id });
|
||||
if (!$baseUrl.get()) {
|
||||
const bulk_download_id = $bulkDownloadId.get();
|
||||
socket.emit('subscribe_bulk_download', { bulk_download_id });
|
||||
}
|
||||
});
|
||||
socket.on('connect_error', (error) => {
|
||||
if (error && error.message) {
|
||||
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
|
||||
if (data === 'ERR_UNAUTHENTICATED') {
|
||||
toast({
|
||||
id: `connect-error-${error.message}`,
|
||||
title: error.message,
|
||||
status: 'error',
|
||||
duration: 10000,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
socket.on('disconnect', () => {
|
||||
dispatch(socketDisconnected());
|
||||
});
|
||||
socket.on('invocation_started', (data) => {
|
||||
dispatch(socketInvocationStarted({ data }));
|
||||
});
|
||||
socket.on('invocation_denoise_progress', (data) => {
|
||||
dispatch(socketGeneratorProgress({ data }));
|
||||
});
|
||||
socket.on('invocation_error', (data) => {
|
||||
dispatch(socketInvocationError({ data }));
|
||||
});
|
||||
socket.on('invocation_complete', (data) => {
|
||||
dispatch(socketInvocationComplete({ data }));
|
||||
});
|
||||
socket.on('model_load_started', (data) => {
|
||||
dispatch(socketModelLoadStarted({ data }));
|
||||
});
|
||||
socket.on('model_load_complete', (data) => {
|
||||
dispatch(socketModelLoadComplete({ data }));
|
||||
});
|
||||
socket.on('download_started', (data) => {
|
||||
dispatch(socketDownloadStarted({ data }));
|
||||
});
|
||||
socket.on('download_progress', (data) => {
|
||||
dispatch(socketDownloadProgress({ data }));
|
||||
});
|
||||
socket.on('download_complete', (data) => {
|
||||
dispatch(socketDownloadComplete({ data }));
|
||||
});
|
||||
socket.on('download_cancelled', (data) => {
|
||||
dispatch(socketDownloadCancelled({ data }));
|
||||
});
|
||||
socket.on('download_error', (data) => {
|
||||
dispatch(socketDownloadError({ data }));
|
||||
});
|
||||
socket.on('model_install_started', (data) => {
|
||||
dispatch(socketModelInstallStarted({ data }));
|
||||
});
|
||||
socket.on('model_install_download_progress', (data) => {
|
||||
dispatch(socketModelInstallDownloadProgress({ data }));
|
||||
});
|
||||
socket.on('model_install_downloads_complete', (data) => {
|
||||
dispatch(socketModelInstallDownloadsComplete({ data }));
|
||||
});
|
||||
socket.on('model_install_complete', (data) => {
|
||||
dispatch(socketModelInstallComplete({ data }));
|
||||
});
|
||||
socket.on('model_install_error', (data) => {
|
||||
dispatch(socketModelInstallError({ data }));
|
||||
});
|
||||
socket.on('model_install_cancelled', (data) => {
|
||||
dispatch(socketModelInstallCancelled({ data }));
|
||||
});
|
||||
socket.on('queue_item_status_changed', (data) => {
|
||||
dispatch(socketQueueItemStatusChanged({ data }));
|
||||
});
|
||||
socket.on('queue_cleared', (data) => {
|
||||
dispatch(socketQueueCleared({ data }));
|
||||
});
|
||||
socket.on('batch_enqueued', (data) => {
|
||||
dispatch(socketBatchEnqueued({ data }));
|
||||
});
|
||||
socket.on('bulk_download_started', (data) => {
|
||||
dispatch(socketBulkDownloadStarted({ data }));
|
||||
});
|
||||
socket.on('bulk_download_complete', (data) => {
|
||||
dispatch(socketBulkDownloadComplete({ data }));
|
||||
});
|
||||
socket.on('bulk_download_error', (data) => {
|
||||
dispatch(socketBulkDownloadError({ data }));
|
||||
});
|
||||
};
|
621
invokeai/frontend/web/src/services/events/setEventListeners.tsx
Normal file
621
invokeai/frontend/web/src/services/events/setEventListeners.tsx
Normal file
@ -0,0 +1,621 @@
|
||||
import { ExternalLink } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { api, LIST_TAG } from 'services/api';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import { socketConnected } from 'services/events/actions';
|
||||
import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types';
|
||||
import type { Socket } from 'socket.io-client';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
type SetEventListenersArg = {
|
||||
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
setIsConnected: (isConnected: boolean) => void;
|
||||
};
|
||||
|
||||
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
|
||||
const nodeTypeDenylist = ['load_image', 'image'];
|
||||
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
|
||||
export const $lastCanvasProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null);
|
||||
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
|
||||
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
|
||||
const cancellations = new Set<string>();
|
||||
|
||||
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
|
||||
socket.on('connect', () => {
|
||||
log.debug('Connected');
|
||||
setIsConnected(true);
|
||||
dispatch(socketConnected());
|
||||
const queue_id = $queueId.get();
|
||||
socket.emit('subscribe_queue', { queue_id });
|
||||
if (!$baseUrl.get()) {
|
||||
const bulk_download_id = $bulkDownloadId.get();
|
||||
socket.emit('subscribe_bulk_download', { bulk_download_id });
|
||||
}
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('connect_error', (error) => {
|
||||
log.debug('Connect error');
|
||||
setIsConnected(false);
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
if (error && error.message) {
|
||||
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
|
||||
if (data === 'ERR_UNAUTHENTICATED') {
|
||||
toast({
|
||||
id: `connect-error-${error.message}`,
|
||||
title: error.message,
|
||||
status: 'error',
|
||||
duration: 10000,
|
||||
});
|
||||
}
|
||||
}
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('disconnect', () => {
|
||||
log.debug('Disconnected');
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
setIsConnected(false);
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('invocation_started', (data) => {
|
||||
const { invocation_source_id, invocation } = data;
|
||||
log.debug({ data } as SerializableObject, `Invocation started (${invocation.type}, ${invocation_source_id})`);
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('invocation_denoise_progress', (data) => {
|
||||
const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage, session_id } =
|
||||
data;
|
||||
|
||||
if (cancellations.has(session_id)) {
|
||||
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
||||
// progress update after the session has been cancelled.
|
||||
return;
|
||||
}
|
||||
|
||||
log.trace(
|
||||
{ data } as SerializableObject,
|
||||
`Denoise ${Math.round(percentage * 100)}% (${invocation.type}, ${invocation_source_id})`
|
||||
);
|
||||
|
||||
$lastProgressEvent.set(data);
|
||||
|
||||
if (origin === 'workflows') {
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
nes.progressImage = progress_image ?? null;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
}
|
||||
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(data);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('invocation_error', (data) => {
|
||||
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data;
|
||||
log.error({ data } as SerializableObject, `Invocation error (${invocation.type}, ${invocation_source_id})`);
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
nes.error = {
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
};
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('invocation_complete', async (data) => {
|
||||
log.debug(
|
||||
{ data } as SerializableObject,
|
||||
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
|
||||
);
|
||||
|
||||
const { result, invocation_source_id } = data;
|
||||
|
||||
if (data.origin === 'workflows') {
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
nes.outputs.push(result);
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
}
|
||||
|
||||
// This complete event has an associated image output
|
||||
if (
|
||||
(data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') &&
|
||||
!nodeTypeDenylist.includes(data.invocation.type)
|
||||
) {
|
||||
const { image_name } = data.result.image;
|
||||
const { gallery, canvasV2 } = getState();
|
||||
|
||||
// This populates the `getImageDTO` cache
|
||||
const imageDTORequest = dispatch(
|
||||
imagesApi.endpoints.getImageDTO.initiate(image_name, {
|
||||
forceRefetch: true,
|
||||
})
|
||||
);
|
||||
|
||||
const imageDTO = await imageDTORequest.unwrap();
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
// handle tab-specific logic
|
||||
if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') {
|
||||
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
|
||||
const { offset_x, offset_y } = data.result;
|
||||
if (canvasV2.session.isStaging) {
|
||||
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } }));
|
||||
}
|
||||
} else if (data.result.type === 'image_output') {
|
||||
if (canvasV2.session.isStaging) {
|
||||
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!imageDTO.is_intermediate) {
|
||||
// update the total images for the board
|
||||
dispatch(
|
||||
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
draft.total += 1;
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.invalidateTags([
|
||||
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||
{
|
||||
type: 'ImageList',
|
||||
id: getListImagesUrl({
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: getCategories(imageDTO),
|
||||
}),
|
||||
},
|
||||
])
|
||||
);
|
||||
|
||||
const { shouldAutoSwitch } = gallery;
|
||||
|
||||
// If auto-switch is enabled, select the new image
|
||||
if (shouldAutoSwitch) {
|
||||
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
|
||||
if (gallery.galleryView !== 'images') {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
|
||||
if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: imageDTO.board_id,
|
||||
selectedImageName: imageDTO.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(offsetChanged({ offset: 0 }));
|
||||
|
||||
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: 'none',
|
||||
selectedImageName: imageDTO.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(imageSelected(imageDTO));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$lastProgressEvent.set(null);
|
||||
});
|
||||
|
||||
socket.on('model_load_started', (data) => {
|
||||
const { config, submodel_type } = data;
|
||||
const { name, base, type } = config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
|
||||
if (submodel_type) {
|
||||
extras.push(submodel_type);
|
||||
}
|
||||
|
||||
const message = `Model load started: ${name} (${extras.join(', ')})`;
|
||||
|
||||
log.debug({ data }, message);
|
||||
});
|
||||
|
||||
socket.on('model_load_complete', (data) => {
|
||||
const { config, submodel_type } = data;
|
||||
const { name, base, type } = config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
if (submodel_type) {
|
||||
extras.push(submodel_type);
|
||||
}
|
||||
|
||||
const message = `Model load complete: ${name} (${extras.join(', ')})`;
|
||||
|
||||
log.debug({ data }, message);
|
||||
});
|
||||
|
||||
socket.on('download_started', (data) => {
|
||||
log.debug({ data }, 'Download started');
|
||||
});
|
||||
|
||||
socket.on('download_progress', (data) => {
|
||||
log.trace({ data }, 'Download progress');
|
||||
});
|
||||
|
||||
socket.on('download_complete', (data) => {
|
||||
log.debug({ data }, 'Download complete');
|
||||
});
|
||||
|
||||
socket.on('download_cancelled', (data) => {
|
||||
log.warn({ data }, 'Download cancelled');
|
||||
});
|
||||
|
||||
socket.on('download_error', (data) => {
|
||||
log.error({ data }, 'Download error');
|
||||
});
|
||||
|
||||
socket.on('model_install_started', (data) => {
|
||||
log.debug({ data }, 'Model install started');
|
||||
|
||||
const { id } = data;
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'running';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('model_install_download_started', (data) => {
|
||||
log.debug({ data }, 'Model install download started');
|
||||
|
||||
const { id } = data;
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('model_install_download_progress', (data) => {
|
||||
log.trace({ data }, 'Model install download progress');
|
||||
|
||||
const { bytes, total_bytes, id } = data;
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
modelImport.total_bytes = total_bytes;
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('model_install_downloads_complete', (data) => {
|
||||
log.debug({ data }, 'Model install downloads complete');
|
||||
|
||||
const { id } = data;
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'downloads_done';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('model_install_complete', (data) => {
|
||||
log.debug({ data }, 'Model install complete');
|
||||
|
||||
const { id } = data;
|
||||
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
|
||||
});
|
||||
|
||||
socket.on('model_install_error', (data) => {
|
||||
log.error({ data }, 'Model install error');
|
||||
|
||||
const { id, error, error_type } = data;
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
modelImport.error_reason = error_type;
|
||||
modelImport.error = error;
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('model_install_cancelled', (data) => {
|
||||
log.warn({ data }, 'Model install cancelled');
|
||||
|
||||
const { id } = data;
|
||||
const installs = selectModelInstalls(getState()).data;
|
||||
|
||||
if (!installs?.find((install) => install.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'cancelled';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('queue_item_status_changed', (data) => {
|
||||
// we've got new status for the queue item, batch and queue
|
||||
const {
|
||||
item_id,
|
||||
session_id,
|
||||
status,
|
||||
started_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
batch_status,
|
||||
queue_status,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
origin,
|
||||
} = data;
|
||||
|
||||
log.debug({ data }, `Queue item ${item_id} status updated: ${status}`);
|
||||
|
||||
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
|
||||
queueItemsAdapter.updateOne(draft, {
|
||||
id: String(item_id),
|
||||
changes: {
|
||||
status,
|
||||
started_at,
|
||||
updated_at: updated_at ?? undefined,
|
||||
completed_at: completed_at ?? undefined,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
},
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
// Update the queue status (we do not get the processor status here)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
|
||||
if (!draft) {
|
||||
return;
|
||||
}
|
||||
Object.assign(draft.queue, queue_status);
|
||||
})
|
||||
);
|
||||
|
||||
// Update the batch status
|
||||
dispatch(queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status));
|
||||
|
||||
// Invalidate caches for things we cannot update
|
||||
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
|
||||
dispatch(
|
||||
queueApi.util.invalidateTags([
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
])
|
||||
);
|
||||
|
||||
if (status === 'in_progress') {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
}
|
||||
const clone = deepClone(nes);
|
||||
clone.status = zNodeStatus.enum.PENDING;
|
||||
clone.error = null;
|
||||
clone.progress = null;
|
||||
clone.progressImage = null;
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
} else if (status === 'failed' && error_type) {
|
||||
const isLocal = getState().config.isLocal ?? true;
|
||||
const sessionId = session_id;
|
||||
$lastProgressEvent.set(null);
|
||||
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
|
||||
toast({
|
||||
id: `INVOCATION_ERROR_${error_type}`,
|
||||
title: getTitleFromErrorType(error_type),
|
||||
status: 'error',
|
||||
duration: null,
|
||||
updateDescription: isLocal,
|
||||
description: (
|
||||
<ErrorToastDescription
|
||||
errorType={error_type}
|
||||
errorMessage={error_message}
|
||||
sessionId={sessionId}
|
||||
isLocal={isLocal}
|
||||
/>
|
||||
),
|
||||
});
|
||||
cancellations.add(session_id);
|
||||
} else if (status === 'canceled') {
|
||||
$lastProgressEvent.set(null);
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
cancellations.add(session_id);
|
||||
} else if (status === 'completed') {
|
||||
$lastProgressEvent.set(null);
|
||||
cancellations.add(session_id);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('queue_cleared', (data) => {
|
||||
log.debug({ data }, 'Queue cleared');
|
||||
});
|
||||
|
||||
socket.on('batch_enqueued', (data) => {
|
||||
log.debug({ data }, 'Batch enqueued');
|
||||
});
|
||||
|
||||
socket.on('bulk_download_started', (data) => {
|
||||
log.debug({ data }, 'Bulk gallery download preparation started');
|
||||
});
|
||||
|
||||
socket.on('bulk_download_complete', (data) => {
|
||||
log.debug({ data }, 'Bulk gallery download ready');
|
||||
const { bulk_download_item_name } = data;
|
||||
|
||||
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
|
||||
const url = `/api/v1/images/download/${bulk_download_item_name}`;
|
||||
|
||||
toast({
|
||||
id: bulk_download_item_name,
|
||||
title: t('gallery.bulkDownloadReady', 'Download ready'),
|
||||
status: 'success',
|
||||
description: (
|
||||
<ExternalLink
|
||||
label={t('gallery.clickToDownload', 'Click here to download')}
|
||||
href={url}
|
||||
download={bulk_download_item_name}
|
||||
/>
|
||||
),
|
||||
duration: null,
|
||||
});
|
||||
});
|
||||
|
||||
socket.on('bulk_download_error', (data) => {
|
||||
log.error({ data }, 'Bulk gallery download error');
|
||||
|
||||
const { bulk_download_item_name, error } = data;
|
||||
|
||||
toast({
|
||||
id: bulk_download_item_name,
|
||||
title: t('gallery.bulkDownloadFailed'),
|
||||
status: 'error',
|
||||
description: error,
|
||||
duration: null,
|
||||
});
|
||||
});
|
||||
};
|
Loading…
Reference in New Issue
Block a user