diff --git a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts index 8a530b8229..89cb1ae172 100644 --- a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts +++ b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts @@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react'; import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $isDebugging } from 'app/store/nanostores/isDebugging'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { useAppStore } from 'app/store/nanostores/store'; import type { MapStore } from 'nanostores'; import { atom, map } from 'nanostores'; import { useEffect, useMemo } from 'react'; @@ -28,13 +28,15 @@ export const getSocket = () => { return socket; }; export const $socketOptions = map>({}); + const $isSocketInitialized = atom(false); +export const $isConnected = atom(false); /** * Initializes the socket.io connection and sets up event listeners. */ export const useSocketIO = () => { - const dispatch = useAppDispatch(); + const { dispatch, getState } = useAppStore(); const baseUrl = useStore($baseUrl); const authToken = useStore($authToken); const addlSocketOptions = useStore($socketOptions); @@ -72,7 +74,7 @@ export const useSocketIO = () => { const socket: AppSocket = io(socketUrl, socketOptions); $socket.set(socket); - setEventListeners({ dispatch, socket }); + setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set }); socket.connect(); if ($isDebugging.get() || import.meta.env.MODE === 'development') { @@ -94,5 +96,5 @@ export const useSocketIO = () => { socket.disconnect(); $isSocketInitialized.set(false); }; - }, [dispatch, socketOptions, socketUrl]); + }, [dispatch, getState, socketOptions, socketUrl]); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index d6cb10ff43..6b8d9782ca 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -1,7 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { - $lastProgressEvent, rasterLayerAdded, sessionStagingAreaImageAccepted, sessionStagingAreaReset, @@ -11,6 +10,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; import { assert } from 'tsafe'; export const addStagingListeners = (startAppListening: AppStartListening) => { @@ -29,7 +29,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => { const { canceled } = await req.unwrap(); req.reset(); - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); if (canceled > 0) { log.debug(`Canceled ${canceled} canvas batches`); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index e28235da59..1aff46d0a3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -2,10 +2,10 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; import { parseify } from 'common/util/serialize'; -import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketGeneratorProgress } from 'services/events/actions'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; const log = logger('socketio'); @@ -27,7 +27,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis } if (origin === 'canvas') { - $lastProgressEvent.set(action.payload.data); + $lastCanvasProgressEvent.set(action.payload.data); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx index 5ba1013bb7..0b37104ca6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx @@ -1,7 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; -import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; @@ -9,6 +8,7 @@ import { toast } from 'features/toast/toast'; import { forEach } from 'lodash-es'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { socketQueueItemStatusChanged } from 'services/events/actions'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; const log = logger('socketio'); @@ -17,13 +17,13 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni startAppListening({ matcher: queueApi.endpoints.clearQueue.matchFulfilled, effect: () => { - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); }, }); startAppListening({ actionCreator: socketQueueItemStatusChanged, - effect: async (action, { dispatch, getState }) => { + effect: (action, { dispatch, getState }) => { // we've got new status for the queue item, batch and queue const { item_id, @@ -103,7 +103,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni const isLocal = getState().config.isLocal ?? true; const sessionId = session_id; if (origin === 'canvas') { - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); } toast({ @@ -122,7 +122,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni ), }); } else if (status === 'canceled' && origin === 'canvas') { - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); } }, }); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index eb2ff5bc27..d30ee3f964 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -1,4 +1,5 @@ import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; @@ -25,7 +26,7 @@ const LAYER_TYPE_TO_TKEY = { control_layer: 'controlLayers.globalControlAdapter', } as const; -const createSelector = (templates: Templates) => +const createSelector = (templates: Templates, isConnected: boolean) => createMemoizedSelector( [ selectSystemSlice, @@ -41,8 +42,6 @@ const createSelector = (templates: Templates) => const { bbox } = canvasV2; const { model, positivePrompt } = canvasV2.params; - const { isConnected } = system; - const reasons: { prefix?: string; content: string }[] = []; // Cannot generate if not connected @@ -240,7 +239,8 @@ const createSelector = (templates: Templates) => export const useIsReadyToEnqueue = () => { const templates = useStore($templates); - const selector = useMemo(() => createSelector(templates), [templates]); + const isConnected = useStore($isConnected) + const selector = useMemo(() => createSelector(templates, isConnected), [templates, isConnected]); const value = useAppSelector(selector); return value; }; diff --git a/invokeai/frontend/web/src/common/types.ts b/invokeai/frontend/web/src/common/types.ts index f3037dcc2b..dd23638b8f 100644 --- a/invokeai/frontend/web/src/common/types.ts +++ b/invokeai/frontend/web/src/common/types.ts @@ -3,3 +3,8 @@ type JSONValue = string | number | boolean | null | JSONValue[] | { [key: string export interface JSONObject { [k: string]: JSONValue; } + +type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject; +export type SerializableObject = { + [k: string | number]: SerializableValue; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx index 9e76aa1b91..e1f6b07857 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx @@ -1,5 +1,7 @@ import { Flex, useShiftModifier } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; @@ -22,79 +24,85 @@ type Props = { postUploadAction: PostUploadAction; }; -export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const isConnected = useAppSelector((s) => s.system.isConnected); - const optimalDimension = useAppSelector(selectOptimalDimension); - const shift = useShiftModifier(); +export const IPAdapterImagePreview = memo( + ({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const isConnected = useStore($isConnected); + const optimalDimension = useAppSelector(selectOptimalDimension); + const shift = useShiftModifier(); - const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.image_name ?? skipToken); - const handleResetControlImage = useCallback(() => { - onChangeImage(null); - }, [onChangeImage]); + const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery( + image?.image_name ?? skipToken + ); + const handleResetControlImage = useCallback(() => { + onChangeImage(null); + }, [onChangeImage]); - const handleSetControlImageToDimensions = useCallback(() => { - if (!controlImage) { - return; - } + const handleSetControlImageToDimensions = useCallback(() => { + if (!controlImage) { + return; + } - const options = { updateAspectRatio: true, clamp: true }; - if (shift) { - const { width, height } = controlImage; - dispatch(bboxWidthChanged({ width, ...options })); - dispatch(bboxHeightChanged({ height, ...options })); - } else { - const { width, height } = calculateNewSize( - controlImage.width / controlImage.height, - optimalDimension * optimalDimension - ); - dispatch(bboxWidthChanged({ width, ...options })); - dispatch(bboxHeightChanged({ height, ...options })); - } - }, [controlImage, dispatch, optimalDimension, shift]); + const options = { updateAspectRatio: true, clamp: true }; + if (shift) { + const { width, height } = controlImage; + dispatch(bboxWidthChanged({ width, ...options })); + dispatch(bboxHeightChanged({ height, ...options })); + } else { + const { width, height } = calculateNewSize( + controlImage.width / controlImage.height, + optimalDimension * optimalDimension + ); + dispatch(bboxWidthChanged({ width, ...options })); + dispatch(bboxHeightChanged({ height, ...options })); + } + }, [controlImage, dispatch, optimalDimension, shift]); - const draggableData = useMemo(() => { - if (controlImage) { - return { - id: ipAdapterId, - payloadType: 'IMAGE_DTO', - payload: { imageDTO: controlImage }, - }; - } - }, [controlImage, ipAdapterId]); + const draggableData = useMemo(() => { + if (controlImage) { + return { + id: ipAdapterId, + payloadType: 'IMAGE_DTO', + payload: { imageDTO: controlImage }, + }; + } + }, [controlImage, ipAdapterId]); - useEffect(() => { - if (isConnected && isErrorControlImage) { - handleResetControlImage(); - } - }, [handleResetControlImage, isConnected, isErrorControlImage]); + useEffect(() => { + if (isConnected && isErrorControlImage) { + handleResetControlImage(); + } + }, [handleResetControlImage, isConnected, isErrorControlImage]); - return ( - - + return ( + + - {controlImage && ( - - } - tooltip={t('controlnet.resetControlImage')} - /> - } - tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')} - /> - - )} - - ); -}); + {controlImage && ( + + } + tooltip={t('controlnet.resetControlImage')} + /> + } + tooltip={ + shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions') + } + /> + + )} + + ); + } +); IPAdapterImagePreview.displayName = 'IPAdapterImagePreview'; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts index 00c796b2c2..0739c267b5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts @@ -48,7 +48,7 @@ export class CanvasProgressImage { image: null, }; - this.manager.stateApi.$lastProgressEvent.listen((event) => { + this.manager.stateApi.$lastCanvasProgressEvent.listen((event) => { this.lastProgressEvent = event; this.render(); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts index c58186a14d..2ce21d239e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts @@ -76,7 +76,7 @@ export class CanvasStagingArea { if (!this.image.isLoading && !this.image.isError) { await this.image.updateImageSource(imageDTO.image_name); - this.manager.stateApi.$lastProgressEvent.set(null); + this.manager.stateApi.$lastCanvasProgressEvent.set(null); } this.image.konva.group.visible(shouldShowStagedImage); } else { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts index 262e28dce5..31a2aee8b2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts @@ -11,7 +11,6 @@ import { $lastAddedPoint, $lastCursorPos, $lastMouseDownPos, - $lastProgressEvent, $shouldShowStagedImage, $spaceKey, $stageAttrs, @@ -51,6 +50,7 @@ import type { import { RGBA_RED } from 'features/controlLayers/store/types'; import type { WritableAtom } from 'nanostores'; import { atom } from 'nanostores'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; type EntityStateAndAdapter = | { @@ -263,7 +263,7 @@ export class CanvasStateApi { $lastAddedPoint = $lastAddedPoint; $lastMouseDownPos = $lastMouseDownPos; $lastCursorPos = $lastCursorPos; - $lastProgressEvent = $lastProgressEvent; + $lastCanvasProgressEvent = $lastCanvasProgressEvent; $spaceKey = $spaceKey; $altKey = $alt; $ctrlKey = $ctrl; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index bc939f0894..d0ae06c28a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -20,7 +20,6 @@ import { initialAspectRatioState } from 'features/parameters/components/Document import { getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { isEqual, pick } from 'lodash-es'; import { atom } from 'nanostores'; -import type { InvocationDenoiseProgressEvent } from 'services/events/types'; import { assert } from 'tsafe'; import type { @@ -622,7 +621,6 @@ export const $stageAttrs = atom({ scale: 0, }); export const $shouldShowStagedImage = atom(true); -export const $lastProgressEvent = atom(null); export const $isDrawing = atom(false); export const $isMouseDown = atom(false); export const $lastAddedPoint = atom(null); diff --git a/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx index 6855cb8e55..452d101fa2 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx +++ b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx @@ -1,5 +1,7 @@ import type { IconButtonProps } from '@invoke-ai/ui-library'; import { IconButton } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -12,7 +14,7 @@ type DeleteImageButtonProps = Omit & { export const DeleteImageButton = memo((props: DeleteImageButtonProps) => { const { onClick, isDisabled } = props; const { t } = useTranslation(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length); const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx index 1ef91e7e2e..9ccd69b898 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx @@ -1,7 +1,7 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton'; @@ -10,17 +10,15 @@ import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMe import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { sentImageToImg2Img } from 'features/gallery/store/actions'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; -import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers'; import { $templates } from 'features/nodes/store/nodesSlice'; import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover'; import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { selectSystemSlice } from 'features/system/store/systemSlice'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; import { size } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { @@ -33,23 +31,17 @@ import { PiRulerBold, } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; - -const selectShouldDisableToolbarButtons = createSelector( - selectSystemSlice, - selectGallerySlice, - selectLastSelectedImage, - (system, gallery, lastSelectedImage) => { - const hasProgressImage = Boolean(system.denoiseProgress?.progress_image); - return hasProgressImage || !lastSelectedImage; - } -); +import { $progressImage } from 'services/events/setEventListeners'; const CurrentImageButtons = () => { const dispatch = useAppDispatch(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const lastSelectedImage = useAppSelector(selectLastSelectedImage); + const progressImage = useStore($progressImage); const selection = useAppSelector((s) => s.gallery.selection); - const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); + const shouldDisableToolbarButtons = useMemo(() => { + return Boolean(progressImage) || !lastSelectedImage; + }, [lastSelectedImage, progressImage]); const templates = useStore($templates); const isUpscalingEnabled = useFeatureStatus('upscaling'); const isQueueMutationInProgress = useIsQueueMutationInProgress(); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index a812391992..23e75498ec 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -1,4 +1,5 @@ import { Box, Flex } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; @@ -14,6 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiImageBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { $hasProgress } from 'services/events/setEventListeners'; import ProgressImage from './ProgressImage'; @@ -26,7 +28,7 @@ const CurrentImagePreview = () => { const { t } = useTranslation(); const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails); const imageName = useAppSelector(selectLastSelectedImageName); - const hasDenoiseProgress = useAppSelector((s) => Boolean(s.system.denoiseProgress)); + const hasDenoiseProgress = useStore($hasProgress); const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer); const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx index 0ee75fbcd4..46c1bd71c2 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx @@ -1,10 +1,12 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Image } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import { memo, useMemo } from 'react'; +import { $progressImage } from 'services/events/setEventListeners'; const CurrentImagePreview = () => { - const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image); + const progressImage = useStore($progressImage); const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage); const sx = useMemo( @@ -14,15 +16,15 @@ const CurrentImagePreview = () => { [shouldAntialiasProgressImage] ); - if (!progress_image) { + if (!progressImage) { return null; } return ( { - const imageDTO = gallery.selection[gallery.selection.length - 1]; - - return { - imageDTO, - progressImage: system.denoiseProgress?.progress_image, - }; -}); +import { $lastProgressEvent } from 'services/events/setEventListeners'; const CurrentImageNode = (props: NodeProps) => { - const { progressImage, imageDTO } = useAppSelector(selector); + const imageDTO = useAppSelector((s) => s.gallery.selection[s.gallery.selection.length - 1]); + const lastProgressEvent = useStore($lastProgressEvent); - if (progressImage) { + if (lastProgressEvent?.progress_image) { return ( - + ); } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx index c3224238c5..1ec0b575f6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx @@ -1,6 +1,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; @@ -17,7 +19,7 @@ import type { FieldComponentProps } from './types'; const ImageFieldInputComponent = (props: FieldComponentProps) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { currentData: imageDTO, isError } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); const handleReset = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts index 9d92eabff8..d9ad1a736f 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts @@ -1,11 +1,12 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue'; export const useCancelBatch = (batch_id: string) => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { isCanceled } = useGetBatchStatusQuery( { batch_id }, { diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts index 057490ed99..9ae8e2dd2e 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; @@ -6,7 +7,7 @@ import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue'; export const useCancelCurrentQueueItem = () => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const [trigger, { isLoading }] = useCancelQueueItemMutation(); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts index 268eca75cc..bf0af41605 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts @@ -1,11 +1,12 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation } from 'services/api/endpoints/queue'; export const useCancelQueueItem = (item_id: number) => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useCancelQueueItemMutation(); const { t } = useTranslation(); const cancelQueueItem = useCallback(async () => { diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts index 7ef9d93742..d177a72f5f 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -7,7 +8,7 @@ import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } fro export const useClearInvocationCache = () => { const { t } = useTranslation(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useClearInvocationCacheMutation({ fixedCacheKey: 'clearInvocationCache', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts index ca7d1e4894..bb80f7aa10 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts @@ -1,4 +1,6 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { useAppDispatch } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; @@ -9,7 +11,7 @@ export const useClearQueue = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const { data: queueStatus } = useGetQueueStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useClearQueueMutation({ fixedCacheKey: 'clearQueue', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts index 371e9198e7..cf71e4bd4b 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -7,7 +8,7 @@ import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } f export const useDisableInvocationCache = () => { const { t } = useTranslation(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useDisableInvocationCacheMutation({ fixedCacheKey: 'disableInvocationCache', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts index fb39cf7347..7f28bddd78 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -7,7 +8,7 @@ import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } fr export const useEnableInvocationCache = () => { const { t } = useTranslation(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useEnableInvocationCacheMutation({ fixedCacheKey: 'enableInvocationCache', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts index f5424c6b18..d25c8051e5 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -6,7 +7,7 @@ import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/ export const usePauseProcessor = () => { const { t } = useTranslation(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const [trigger, { isLoading }] = usePauseProcessorMutation({ fixedCacheKey: 'pauseProcessor', diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts index eaeabe5423..f9426291be 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts @@ -1,4 +1,6 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { useAppDispatch } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; @@ -8,7 +10,7 @@ import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endp export const usePruneQueue = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = usePruneQueueMutation({ fixedCacheKey: 'pruneQueue', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts index 851b268416..72d787103b 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts @@ -1,11 +1,12 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue'; export const useResumeProcessor = () => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const { t } = useTranslation(); const [trigger, { isLoading }] = useResumeProcessorMutation({ diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 4389431813..06c7e70c7f 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -1,28 +1,28 @@ import { Progress } from '@invoke-ai/ui-library'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectSystemSlice } from 'features/system/store/systemSlice'; -import { memo } from 'react'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; - -const selectProgressValue = createSelector( - selectSystemSlice, - (system) => (system.denoiseProgress?.percentage ?? 0) * 100 -); +import { $lastProgressEvent } from 'services/events/setEventListeners'; const ProgressBar = () => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); - const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress)); - const value = useAppSelector(selectProgressValue); + const isConnected = useStore($isConnected); + const lastProgressEvent = useStore($lastProgressEvent); + const value = useMemo(() => { + if (!lastProgressEvent) { + return 0; + } + return (lastProgressEvent.percentage ?? 0) * 100; + }, [lastProgressEvent]); return ( { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { t } = useTranslation(); if (!isConnected) { diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.ts b/invokeai/frontend/web/src/services/events/setEventListeners.ts deleted file mode 100644 index 8c8c9da2e8..0000000000 --- a/invokeai/frontend/web/src/services/events/setEventListeners.ts +++ /dev/null @@ -1,136 +0,0 @@ -import { $baseUrl } from 'app/store/nanostores/baseUrl'; -import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; -import { $queueId } from 'app/store/nanostores/queueId'; -import type { AppDispatch } from 'app/store/store'; -import { toast } from 'features/toast/toast'; -import { - socketBatchEnqueued, - socketBulkDownloadComplete, - socketBulkDownloadError, - socketBulkDownloadStarted, - socketConnected, - socketDisconnected, - socketDownloadCancelled, - socketDownloadComplete, - socketDownloadError, - socketDownloadProgress, - socketDownloadStarted, - socketGeneratorProgress, - socketInvocationComplete, - socketInvocationError, - socketInvocationStarted, - socketModelInstallCancelled, - socketModelInstallComplete, - socketModelInstallDownloadProgress, - socketModelInstallDownloadsComplete, - socketModelInstallError, - socketModelInstallStarted, - socketModelLoadComplete, - socketModelLoadStarted, - socketQueueCleared, - socketQueueItemStatusChanged, -} from 'services/events/actions'; -import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; -import type { Socket } from 'socket.io-client'; - -type SetEventListenersArg = { - socket: Socket; - dispatch: AppDispatch; -}; - -export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => { - socket.on('connect', () => { - dispatch(socketConnected()); - const queue_id = $queueId.get(); - socket.emit('subscribe_queue', { queue_id }); - if (!$baseUrl.get()) { - const bulk_download_id = $bulkDownloadId.get(); - socket.emit('subscribe_bulk_download', { bulk_download_id }); - } - }); - socket.on('connect_error', (error) => { - if (error && error.message) { - const data: string | undefined = (error as unknown as { data: string | undefined }).data; - if (data === 'ERR_UNAUTHENTICATED') { - toast({ - id: `connect-error-${error.message}`, - title: error.message, - status: 'error', - duration: 10000, - }); - } - } - }); - socket.on('disconnect', () => { - dispatch(socketDisconnected()); - }); - socket.on('invocation_started', (data) => { - dispatch(socketInvocationStarted({ data })); - }); - socket.on('invocation_denoise_progress', (data) => { - dispatch(socketGeneratorProgress({ data })); - }); - socket.on('invocation_error', (data) => { - dispatch(socketInvocationError({ data })); - }); - socket.on('invocation_complete', (data) => { - dispatch(socketInvocationComplete({ data })); - }); - socket.on('model_load_started', (data) => { - dispatch(socketModelLoadStarted({ data })); - }); - socket.on('model_load_complete', (data) => { - dispatch(socketModelLoadComplete({ data })); - }); - socket.on('download_started', (data) => { - dispatch(socketDownloadStarted({ data })); - }); - socket.on('download_progress', (data) => { - dispatch(socketDownloadProgress({ data })); - }); - socket.on('download_complete', (data) => { - dispatch(socketDownloadComplete({ data })); - }); - socket.on('download_cancelled', (data) => { - dispatch(socketDownloadCancelled({ data })); - }); - socket.on('download_error', (data) => { - dispatch(socketDownloadError({ data })); - }); - socket.on('model_install_started', (data) => { - dispatch(socketModelInstallStarted({ data })); - }); - socket.on('model_install_download_progress', (data) => { - dispatch(socketModelInstallDownloadProgress({ data })); - }); - socket.on('model_install_downloads_complete', (data) => { - dispatch(socketModelInstallDownloadsComplete({ data })); - }); - socket.on('model_install_complete', (data) => { - dispatch(socketModelInstallComplete({ data })); - }); - socket.on('model_install_error', (data) => { - dispatch(socketModelInstallError({ data })); - }); - socket.on('model_install_cancelled', (data) => { - dispatch(socketModelInstallCancelled({ data })); - }); - socket.on('queue_item_status_changed', (data) => { - dispatch(socketQueueItemStatusChanged({ data })); - }); - socket.on('queue_cleared', (data) => { - dispatch(socketQueueCleared({ data })); - }); - socket.on('batch_enqueued', (data) => { - dispatch(socketBatchEnqueued({ data })); - }); - socket.on('bulk_download_started', (data) => { - dispatch(socketBulkDownloadStarted({ data })); - }); - socket.on('bulk_download_complete', (data) => { - dispatch(socketBulkDownloadComplete({ data })); - }); - socket.on('bulk_download_error', (data) => { - dispatch(socketBulkDownloadError({ data })); - }); -}; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx new file mode 100644 index 0000000000..379b280032 --- /dev/null +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -0,0 +1,621 @@ +import { ExternalLink } from '@invoke-ai/ui-library'; +import { logger } from 'app/logging/logger'; +import { $baseUrl } from 'app/store/nanostores/baseUrl'; +import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; +import { $queueId } from 'app/store/nanostores/queueId'; +import type { AppDispatch, RootState } from 'app/store/store'; +import type { SerializableObject } from 'common/types'; +import { deepClone } from 'common/util/deepClone'; +import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice'; +import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; +import { toast } from 'features/toast/toast'; +import { t } from 'i18next'; +import { forEach } from 'lodash-es'; +import { atom, computed } from 'nanostores'; +import { api, LIST_TAG } from 'services/api'; +import { boardsApi } from 'services/api/endpoints/boards'; +import { imagesApi } from 'services/api/endpoints/images'; +import { modelsApi } from 'services/api/endpoints/models'; +import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; +import { getCategories, getListImagesUrl } from 'services/api/util'; +import { socketConnected } from 'services/events/actions'; +import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types'; +import type { Socket } from 'socket.io-client'; + +const log = logger('socketio'); + +type SetEventListenersArg = { + socket: Socket; + dispatch: AppDispatch; + getState: () => RootState; + setIsConnected: (isConnected: boolean) => void; +}; + +const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); +const nodeTypeDenylist = ['load_image', 'image']; +export const $lastProgressEvent = atom(null); +export const $lastCanvasProgressEvent = atom(null); +export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); +export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); +const cancellations = new Set(); + +export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => { + socket.on('connect', () => { + log.debug('Connected'); + setIsConnected(true); + dispatch(socketConnected()); + const queue_id = $queueId.get(); + socket.emit('subscribe_queue', { queue_id }); + if (!$baseUrl.get()) { + const bulk_download_id = $bulkDownloadId.get(); + socket.emit('subscribe_bulk_download', { bulk_download_id }); + } + $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); + cancellations.clear(); + }); + + socket.on('connect_error', (error) => { + log.debug('Connect error'); + setIsConnected(false); + $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); + if (error && error.message) { + const data: string | undefined = (error as unknown as { data: string | undefined }).data; + if (data === 'ERR_UNAUTHENTICATED') { + toast({ + id: `connect-error-${error.message}`, + title: error.message, + status: 'error', + duration: 10000, + }); + } + } + cancellations.clear(); + }); + + socket.on('disconnect', () => { + log.debug('Disconnected'); + $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); + setIsConnected(false); + cancellations.clear(); + }); + + socket.on('invocation_started', (data) => { + const { invocation_source_id, invocation } = data; + log.debug({ data } as SerializableObject, `Invocation started (${invocation.type}, ${invocation_source_id})`); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.IN_PROGRESS; + upsertExecutionState(nes.nodeId, nes); + } + cancellations.clear(); + }); + + socket.on('invocation_denoise_progress', (data) => { + const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage, session_id } = + data; + + if (cancellations.has(session_id)) { + // Do not update the progress if this session has been cancelled. This prevents a race condition where we get a + // progress update after the session has been cancelled. + return; + } + + log.trace( + { data } as SerializableObject, + `Denoise ${Math.round(percentage * 100)}% (${invocation.type}, ${invocation_source_id})` + ); + + $lastProgressEvent.set(data); + + if (origin === 'workflows') { + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.IN_PROGRESS; + nes.progress = (step + 1) / total_steps; + nes.progressImage = progress_image ?? null; + upsertExecutionState(nes.nodeId, nes); + } + } + + if (origin === 'canvas') { + $lastCanvasProgressEvent.set(data); + } + }); + + socket.on('invocation_error', (data) => { + const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data; + log.error({ data } as SerializableObject, `Invocation error (${invocation.type}, ${invocation_source_id})`); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.FAILED; + nes.progress = null; + nes.progressImage = null; + nes.error = { + error_type, + error_message, + error_traceback, + }; + upsertExecutionState(nes.nodeId, nes); + } + }); + + socket.on('invocation_complete', async (data) => { + log.debug( + { data } as SerializableObject, + `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})` + ); + + const { result, invocation_source_id } = data; + + if (data.origin === 'workflows') { + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.COMPLETED; + if (nes.progress !== null) { + nes.progress = 1; + } + nes.outputs.push(result); + upsertExecutionState(nes.nodeId, nes); + } + } + + // This complete event has an associated image output + if ( + (data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') && + !nodeTypeDenylist.includes(data.invocation.type) + ) { + const { image_name } = data.result.image; + const { gallery, canvasV2 } = getState(); + + // This populates the `getImageDTO` cache + const imageDTORequest = dispatch( + imagesApi.endpoints.getImageDTO.initiate(image_name, { + forceRefetch: true, + }) + ); + + const imageDTO = await imageDTORequest.unwrap(); + imageDTORequest.unsubscribe(); + + // handle tab-specific logic + if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') { + if (data.result.type === 'canvas_v2_mask_and_crop_output') { + const { offset_x, offset_y } = data.result; + if (canvasV2.session.isStaging) { + dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } })); + } + } else if (data.result.type === 'image_output') { + if (canvasV2.session.isStaging) { + dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); + } + } + } + + if (!imageDTO.is_intermediate) { + // update the total images for the board + dispatch( + boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + draft.total += 1; + }) + ); + + dispatch( + imagesApi.util.invalidateTags([ + { type: 'Board', id: imageDTO.board_id ?? 'none' }, + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: imageDTO.board_id ?? 'none', + categories: getCategories(imageDTO), + }), + }, + ]) + ); + + const { shouldAutoSwitch } = gallery; + + // If auto-switch is enabled, select the new image + if (shouldAutoSwitch) { + // if auto-add is enabled, switch the gallery view and board if needed as the image comes in + if (gallery.galleryView !== 'images') { + dispatch(galleryViewChanged('images')); + } + + if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) { + dispatch( + boardIdSelected({ + boardId: imageDTO.board_id, + selectedImageName: imageDTO.image_name, + }) + ); + } + + dispatch(offsetChanged({ offset: 0 })); + + if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') { + dispatch( + boardIdSelected({ + boardId: 'none', + selectedImageName: imageDTO.image_name, + }) + ); + } + + dispatch(imageSelected(imageDTO)); + } + } + } + + $lastProgressEvent.set(null); + }); + + socket.on('model_load_started', (data) => { + const { config, submodel_type } = data; + const { name, base, type } = config; + + const extras: string[] = [base, type]; + + if (submodel_type) { + extras.push(submodel_type); + } + + const message = `Model load started: ${name} (${extras.join(', ')})`; + + log.debug({ data }, message); + }); + + socket.on('model_load_complete', (data) => { + const { config, submodel_type } = data; + const { name, base, type } = config; + + const extras: string[] = [base, type]; + if (submodel_type) { + extras.push(submodel_type); + } + + const message = `Model load complete: ${name} (${extras.join(', ')})`; + + log.debug({ data }, message); + }); + + socket.on('download_started', (data) => { + log.debug({ data }, 'Download started'); + }); + + socket.on('download_progress', (data) => { + log.trace({ data }, 'Download progress'); + }); + + socket.on('download_complete', (data) => { + log.debug({ data }, 'Download complete'); + }); + + socket.on('download_cancelled', (data) => { + log.warn({ data }, 'Download cancelled'); + }); + + socket.on('download_error', (data) => { + log.error({ data }, 'Download error'); + }); + + socket.on('model_install_started', (data) => { + log.debug({ data }, 'Model install started'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'running'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_download_started', (data) => { + log.debug({ data }, 'Model install download started'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_download_progress', (data) => { + log.trace({ data }, 'Model install download progress'); + + const { bytes, total_bytes, id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.bytes = bytes; + modelImport.total_bytes = total_bytes; + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_downloads_complete', (data) => { + log.debug({ data }, 'Model install downloads complete'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloads_done'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_complete', (data) => { + log.debug({ data }, 'Model install complete'); + + const { id } = data; + + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'completed'; + } + return draft; + }) + ); + } + + dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); + dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); + }); + + socket.on('model_install_error', (data) => { + log.error({ data }, 'Model install error'); + + const { id, error, error_type } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'error'; + modelImport.error_reason = error_type; + modelImport.error = error; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_cancelled', (data) => { + log.warn({ data }, 'Model install cancelled'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'cancelled'; + } + return draft; + }) + ); + } + }); + + socket.on('queue_item_status_changed', (data) => { + // we've got new status for the queue item, batch and queue + const { + item_id, + session_id, + status, + started_at, + updated_at, + completed_at, + batch_status, + queue_status, + error_type, + error_message, + error_traceback, + origin, + } = data; + + log.debug({ data }, `Queue item ${item_id} status updated: ${status}`); + + // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) + dispatch( + queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { + queueItemsAdapter.updateOne(draft, { + id: String(item_id), + changes: { + status, + started_at, + updated_at: updated_at ?? undefined, + completed_at: completed_at ?? undefined, + error_type, + error_message, + error_traceback, + }, + }); + }) + ); + + // Update the queue status (we do not get the processor status here) + dispatch( + queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => { + if (!draft) { + return; + } + Object.assign(draft.queue, queue_status); + }) + ); + + // Update the batch status + dispatch(queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)); + + // Invalidate caches for things we cannot update + // TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again + dispatch( + queueApi.util.invalidateTags([ + 'CurrentSessionQueueItem', + 'NextSessionQueueItem', + 'InvocationCacheStatus', + { type: 'SessionQueueItem', id: item_id }, + ]) + ); + + if (status === 'in_progress') { + forEach($nodeExecutionStates.get(), (nes) => { + if (!nes) { + return; + } + const clone = deepClone(nes); + clone.status = zNodeStatus.enum.PENDING; + clone.error = null; + clone.progress = null; + clone.progressImage = null; + clone.outputs = []; + $nodeExecutionStates.setKey(clone.nodeId, clone); + }); + } else if (status === 'failed' && error_type) { + const isLocal = getState().config.isLocal ?? true; + const sessionId = session_id; + $lastProgressEvent.set(null); + + if (origin === 'canvas') { + $lastCanvasProgressEvent.set(null); + } + + toast({ + id: `INVOCATION_ERROR_${error_type}`, + title: getTitleFromErrorType(error_type), + status: 'error', + duration: null, + updateDescription: isLocal, + description: ( + + ), + }); + cancellations.add(session_id); + } else if (status === 'canceled') { + $lastProgressEvent.set(null); + if (origin === 'canvas') { + $lastCanvasProgressEvent.set(null); + } + cancellations.add(session_id); + } else if (status === 'completed') { + $lastProgressEvent.set(null); + cancellations.add(session_id); + } + }); + + socket.on('queue_cleared', (data) => { + log.debug({ data }, 'Queue cleared'); + }); + + socket.on('batch_enqueued', (data) => { + log.debug({ data }, 'Batch enqueued'); + }); + + socket.on('bulk_download_started', (data) => { + log.debug({ data }, 'Bulk gallery download preparation started'); + }); + + socket.on('bulk_download_complete', (data) => { + log.debug({ data }, 'Bulk gallery download ready'); + const { bulk_download_item_name } = data; + + // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first + const url = `/api/v1/images/download/${bulk_download_item_name}`; + + toast({ + id: bulk_download_item_name, + title: t('gallery.bulkDownloadReady', 'Download ready'), + status: 'success', + description: ( + + ), + duration: null, + }); + }); + + socket.on('bulk_download_error', (data) => { + log.error({ data }, 'Bulk gallery download error'); + + const { bulk_download_item_name, error } = data; + + toast({ + id: bulk_download_item_name, + title: t('gallery.bulkDownloadFailed'), + status: 'error', + description: error, + duration: null, + }); + }); +};