diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py
index 870ca33534..759f6c9f59 100644
--- a/invokeai/app/api/routers/models.py
+++ b/invokeai/app/api/routers/models.py
@@ -298,7 +298,7 @@ async def search_for_models(
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
- return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
+ return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
@models_router.get(
"/ckpt_confs",
diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py
index 35003536e6..73d74de2d9 100644
--- a/invokeai/app/services/events.py
+++ b/invokeai/app/services/events.py
@@ -3,7 +3,13 @@
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
-from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
+from invokeai.app.services.model_manager_service import (
+ BaseModelType,
+ ModelType,
+ SubModelType,
+ ModelInfo,
+)
+
class EventServiceBase:
session_event: str = "session_event"
@@ -38,7 +44,9 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
- progress_image=progress_image.dict() if progress_image is not None else None,
+ progress_image=progress_image.dict()
+ if progress_image is not None
+ else None,
step=step,
total_steps=total_steps,
),
@@ -67,6 +75,7 @@ class EventServiceBase:
graph_execution_state_id: str,
node: dict,
source_node_id: str,
+ error_type: str,
error: str,
) -> None:
"""Emitted when an invocation has completed"""
@@ -76,6 +85,7 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
+ error_type=error_type,
error=error,
),
)
@@ -102,13 +112,13 @@ class EventServiceBase:
),
)
- def emit_model_load_started (
- self,
- graph_execution_state_id: str,
- model_name: str,
- base_model: BaseModelType,
- model_type: ModelType,
- submodel: SubModelType,
+ def emit_model_load_started(
+ self,
+ graph_execution_state_id: str,
+ model_name: str,
+ base_model: BaseModelType,
+ model_type: ModelType,
+ submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
@@ -123,13 +133,13 @@ class EventServiceBase:
)
def emit_model_load_completed(
- self,
- graph_execution_state_id: str,
- model_name: str,
- base_model: BaseModelType,
- model_type: ModelType,
- submodel: SubModelType,
- model_info: ModelInfo,
+ self,
+ graph_execution_state_id: str,
+ model_name: str,
+ base_model: BaseModelType,
+ model_type: ModelType,
+ submodel: SubModelType,
+ model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
@@ -145,3 +155,37 @@ class EventServiceBase:
precision=str(model_info.precision),
),
)
+
+ def emit_session_retrieval_error(
+ self,
+ graph_execution_state_id: str,
+ error_type: str,
+ error: str,
+ ) -> None:
+ """Emitted when session retrieval fails"""
+ self.__emit_session_event(
+ event_name="session_retrieval_error",
+ payload=dict(
+ graph_execution_state_id=graph_execution_state_id,
+ error_type=error_type,
+ error=error,
+ ),
+ )
+
+ def emit_invocation_retrieval_error(
+ self,
+ graph_execution_state_id: str,
+ node_id: str,
+ error_type: str,
+ error: str,
+ ) -> None:
+ """Emitted when invocation retrieval fails"""
+ self.__emit_session_event(
+ event_name="invocation_retrieval_error",
+ payload=dict(
+ graph_execution_state_id=graph_execution_state_id,
+ node_id=node_id,
+ error_type=error_type,
+ error=error,
+ ),
+ )
diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index b1b995309e..f7d3b3a7a7 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Return list of all models found in the designated directory.
"""
- search = FindModels(directory,self.logger)
+ search = FindModels([directory], self.logger)
return search.list_models()
def sync_to_config(self):
diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py
index e11eb84b3d..5995e4ffc3 100644
--- a/invokeai/app/services/processor.py
+++ b/invokeai/app/services/processor.py
@@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
- logger.debug("Exception while getting from queue: %s" % e)
+ self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
- graph_execution_state = (
- self.__invoker.services.graph_execution_manager.get(
- queue_item.graph_execution_state_id
+ try:
+ graph_execution_state = (
+ self.__invoker.services.graph_execution_manager.get(
+ queue_item.graph_execution_state_id
+ )
)
- )
- invocation = graph_execution_state.execution_graph.get_node(
- queue_item.invocation_id
- )
+ except Exception as e:
+ self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
+ self.__invoker.services.events.emit_session_retrieval_error(
+ graph_execution_state_id=queue_item.graph_execution_state_id,
+ error_type=e.__class__.__name__,
+ error=traceback.format_exc(),
+ )
+ continue
+
+ try:
+ invocation = graph_execution_state.execution_graph.get_node(
+ queue_item.invocation_id
+ )
+ except Exception as e:
+ self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
+ self.__invoker.services.events.emit_invocation_retrieval_error(
+ graph_execution_state_id=queue_item.graph_execution_state_id,
+ node_id=queue_item.invocation_id,
+ error_type=e.__class__.__name__,
+ error=traceback.format_exc(),
+ )
+ continue
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state
)
+ self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
+ error_type=e.__class__.__name__,
error=error,
)
@@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
- logger.error("Error while invoking: %s" % e)
+ self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
+ error_type=e.__class__.__name__,
error=traceback.format_exc()
)
elif is_complete:
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index b0481f3cfa..222169afbb 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -474,7 +474,7 @@ class ModelPatcher:
@staticmethod
def _lora_forward_hook(
- applied_loras: List[Tuple[LoraModel, float]],
+ applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str,
):
@@ -519,7 +519,7 @@ class ModelPatcher:
def apply_lora(
cls,
model: torch.nn.Module,
- loras: List[Tuple[LoraModel, float]],
+ loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
original_weights = dict()
diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py
index 1e282b4bb8..5657bd9549 100644
--- a/invokeai/backend/model_management/model_search.py
+++ b/invokeai/backend/model_management/model_search.py
@@ -98,6 +98,6 @@ class FindModels(ModelSearch):
def list_models(self) -> List[Path]:
self.search()
- return self.models_found
+ return list(self.models_found)
diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py
index 5387ade0e5..eb771841ec 100644
--- a/invokeai/backend/model_management/models/lora.py
+++ b/invokeai/backend/model_management/models/lora.py
@@ -10,6 +10,7 @@ from .base import (
SubModelType,
classproperty,
InvalidModelException,
+ ModelNotFoundException,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
index 04f0ce7a0b..5adc4f5e5e 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
@@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
+import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
+import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
export const listenerMiddleware = createListenerMiddleware();
@@ -153,6 +155,8 @@ addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
addModelLoadEventListener();
+addSessionRetrievalErrorEventListener();
+addInvocationRetrievalErrorEventListener();
// Session Created
addSessionCreatedPendingListener();
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts
index 5709d87d22..e89acb7542 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts
@@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
effect: (action) => {
const log = logger('session');
if (action.payload) {
- const { error } = action.payload;
+ const { error, status } = action.payload;
const graph = parseify(action.meta.arg);
- const stringifiedError = JSON.stringify(error);
log.error(
- { graph, error: serializeError(error) },
- `Problem creating session: ${stringifiedError}`
+ { graph, status, error: serializeError(error) },
+ `Problem creating session`
);
}
},
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts
index 60009ed194..a62f75d957 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts
@@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
- const stringifiedError = JSON.stringify(error);
log.error(
{
session_id,
error: serializeError(error),
},
- `Problem invoking session: ${stringifiedError}`
+ `Problem invoking session`
);
}
},
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts
new file mode 100644
index 0000000000..aa88457eb7
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts
@@ -0,0 +1,20 @@
+import { logger } from 'app/logging/logger';
+import {
+ appSocketInvocationRetrievalError,
+ socketInvocationRetrievalError,
+} from 'services/events/actions';
+import { startAppListening } from '../..';
+
+export const addInvocationRetrievalErrorEventListener = () => {
+ startAppListening({
+ actionCreator: socketInvocationRetrievalError,
+ effect: (action, { dispatch }) => {
+ const log = logger('socketio');
+ log.error(
+ action.payload,
+ `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
+ );
+ dispatch(appSocketInvocationRetrievalError(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts
new file mode 100644
index 0000000000..7efb7f463a
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts
@@ -0,0 +1,20 @@
+import { logger } from 'app/logging/logger';
+import {
+ appSocketSessionRetrievalError,
+ socketSessionRetrievalError,
+} from 'services/events/actions';
+import { startAppListening } from '../..';
+
+export const addSessionRetrievalErrorEventListener = () => {
+ startAppListening({
+ actionCreator: socketSessionRetrievalError,
+ effect: (action, { dispatch }) => {
+ const log = logger('socketio');
+ log.error(
+ action.payload,
+ `Session retrieval error (${action.payload.data.graph_execution_state_id})`
+ );
+ dispatch(appSocketSessionRetrievalError(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
index 2ef62aed7b..39bd742d7d 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
@@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => {
const state = getState();
+ const {
+ layerState,
+ boundingBoxCoordinates,
+ boundingBoxDimensions,
+ isMaskEnabled,
+ shouldPreserveMaskedArea,
+ } = state.canvas;
+
// Build canvas blobs
- const canvasBlobsAndImageData = await getCanvasData(state);
+ const canvasBlobsAndImageData = await getCanvasData(
+ layerState,
+ boundingBoxCoordinates,
+ boundingBoxDimensions,
+ isMaskEnabled,
+ shouldPreserveMaskedArea
+ );
if (!canvasBlobsAndImageData) {
log.error('Unable to create canvas data');
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx
index 69bf628a39..8c1dfbb86f 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx
@@ -2,8 +2,8 @@ import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
+import GenerationModeStatusText from 'features/parameters/components/Parameters/Canvas/GenerationModeStatusText';
import { isEqual } from 'lodash-es';
-
import { useTranslation } from 'react-i18next';
import roundToHundreth from '../util/roundToHundreth';
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
@@ -110,6 +110,7 @@ const IAICanvasStatusText = () => {
},
}}
>
+      <GenerationModeStatusText />
diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts
new file mode 100644
--- /dev/null
+++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts
@@ -0,0 +1,72 @@
+import { useAppSelector } from 'app/store/storeHooks';
+import { GenerationMode } from 'features/canvas/store/canvasTypes';
+import { getCanvasData } from 'features/canvas/util/getCanvasData';
+import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
+import { useEffect, useState } from 'react';
+import { useDebounce } from 'react-use';
+
+export const useCanvasGenerationMode = () => {
+ const layerState = useAppSelector((state) => state.canvas.layerState);
+
+ const boundingBoxCoordinates = useAppSelector(
+ (state) => state.canvas.boundingBoxCoordinates
+ );
+ const boundingBoxDimensions = useAppSelector(
+ (state) => state.canvas.boundingBoxDimensions
+ );
+ const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled);
+
+ const shouldPreserveMaskedArea = useAppSelector(
+ (state) => state.canvas.shouldPreserveMaskedArea
+ );
+ const [generationMode, setGenerationMode] = useState<
+ GenerationMode | undefined
+ >();
+
+ useEffect(() => {
+ setGenerationMode(undefined);
+ }, [
+ layerState,
+ boundingBoxCoordinates,
+ boundingBoxDimensions,
+ isMaskEnabled,
+ shouldPreserveMaskedArea,
+ ]);
+
+ useDebounce(
+ async () => {
+ // Build canvas blobs
+ const canvasBlobsAndImageData = await getCanvasData(
+ layerState,
+ boundingBoxCoordinates,
+ boundingBoxDimensions,
+ isMaskEnabled,
+ shouldPreserveMaskedArea
+ );
+
+ if (!canvasBlobsAndImageData) {
+ return;
+ }
+
+ const { baseImageData, maskImageData } = canvasBlobsAndImageData;
+
+ // Determine the generation mode
+ const generationMode = getCanvasGenerationMode(
+ baseImageData,
+ maskImageData
+ );
+
+ setGenerationMode(generationMode);
+ },
+ 1000,
+ [
+ layerState,
+ boundingBoxCoordinates,
+ boundingBoxDimensions,
+ isMaskEnabled,
+ shouldPreserveMaskedArea,
+ ]
+ );
+
+ return generationMode;
+};
diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
index 48d59395ab..ba85a7e132 100644
--- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
+++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
@@ -168,4 +168,7 @@ export interface CanvasState {
stageDimensions: Dimensions;
stageScale: number;
tool: CanvasTool;
+ generationMode?: GenerationMode;
}
+
+export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
index d37ee7b8d0..4e575791ed 100644
--- a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
+++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
@@ -1,6 +1,10 @@
import { logger } from 'app/logging/logger';
-import { RootState } from 'app/store/store';
-import { isCanvasMaskLine } from '../store/canvasTypes';
+import { Vector2d } from 'konva/lib/types';
+import {
+ CanvasLayerState,
+ Dimensions,
+ isCanvasMaskLine,
+} from '../store/canvasTypes';
import createMaskStage from './createMaskStage';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { konvaNodeToBlob } from './konvaNodeToBlob';
@@ -9,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData';
/**
* Gets Blob and ImageData objects for the base and mask layers
*/
-export const getCanvasData = async (state: RootState) => {
+export const getCanvasData = async (
+ layerState: CanvasLayerState,
+ boundingBoxCoordinates: Vector2d,
+ boundingBoxDimensions: Dimensions,
+ isMaskEnabled: boolean,
+ shouldPreserveMaskedArea: boolean
+) => {
const log = logger('canvas');
const canvasBaseLayer = getCanvasBaseLayer();
@@ -20,14 +30,6 @@ export const getCanvasData = async (state: RootState) => {
return;
}
- const {
- layerState: { objects },
- boundingBoxCoordinates,
- boundingBoxDimensions,
- isMaskEnabled,
- shouldPreserveMaskedArea,
- } = state.canvas;
-
const boundingBox = {
...boundingBoxCoordinates,
...boundingBoxDimensions,
@@ -58,7 +60,7 @@ export const getCanvasData = async (state: RootState) => {
// For the mask layer, use the normal boundingBox
const maskStage = await createMaskStage(
- isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
+ isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
boundingBox,
shouldPreserveMaskedArea
);
diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts
index 5b38ecf938..d3e8792690 100644
--- a/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts
+++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts
@@ -2,11 +2,12 @@ import {
areAnyPixelsBlack,
getImageDataTransparency,
} from 'common/util/arrayBuffer';
+import { GenerationMode } from '../store/canvasTypes';
export const getCanvasGenerationMode = (
baseImageData: ImageData,
maskImageData: ImageData
-) => {
+): GenerationMode => {
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx
new file mode 100644
index 0000000000..511e90f0f3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx
@@ -0,0 +1,21 @@
+import { Box } from '@chakra-ui/react';
+import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode';
+
+const GENERATION_MODE_NAME_MAP = {
+ txt2img: 'Text to Image',
+ img2img: 'Image to Image',
+ inpaint: 'Inpaint',
+ outpaint: 'Inpaint',
+};
+
+const GenerationModeStatusText = () => {
+ const generationMode = useCanvasGenerationMode();
+
+ return (
+    <Box>
+      Mode: {generationMode ? GENERATION_MODE_NAME_MAP[generationMode] : '...'}
+    </Box>
+ );
+};
+
+export default GenerationModeStatusText;
diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts
index 629a4f0139..b7a5e606e2 100644
--- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts
+++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts
@@ -1,5 +1,5 @@
import { UseToastOptions } from '@chakra-ui/react';
-import { PayloadAction, createSlice } from '@reduxjs/toolkit';
+import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
@@ -16,13 +16,16 @@ import {
appSocketGraphExecutionStateComplete,
appSocketInvocationComplete,
appSocketInvocationError,
+ appSocketInvocationRetrievalError,
appSocketInvocationStarted,
+ appSocketSessionRetrievalError,
appSocketSubscribed,
appSocketUnsubscribed,
} from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../util/makeToast';
import { LANGUAGES } from './constants';
+import { startCase } from 'lodash-es';
export type CancelStrategy = 'immediate' | 'scheduled';
@@ -288,25 +291,6 @@ export const systemSlice = createSlice({
}
});
- /**
- * Invocation Error
- */
- builder.addCase(appSocketInvocationError, (state) => {
- state.isProcessing = false;
- state.isCancelable = true;
- // state.currentIteration = 0;
- // state.totalIterations = 0;
- state.currentStatusHasSteps = false;
- state.currentStep = 0;
- state.totalSteps = 0;
- state.statusTranslationKey = 'common.statusError';
- state.progressImage = null;
-
- state.toastQueue.push(
- makeToast({ title: t('toast.serverError'), status: 'error' })
- );
- });
-
/**
* Graph Execution State Complete
*/
@@ -362,7 +346,7 @@ export const systemSlice = createSlice({
* Session Invoked - REJECTED
* Session Created - REJECTED
*/
- builder.addMatcher(isAnySessionRejected, (state) => {
+ builder.addMatcher(isAnySessionRejected, (state, action) => {
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
@@ -372,7 +356,35 @@ export const systemSlice = createSlice({
state.progressImage = null;
state.toastQueue.push(
- makeToast({ title: t('toast.serverError'), status: 'error' })
+ makeToast({
+ title: t('toast.serverError'),
+ status: 'error',
+ description:
+ action.payload?.status === 422 ? 'Validation Error' : undefined,
+ })
+ );
+ });
+
+ /**
+ * Any server error
+ */
+ builder.addMatcher(isAnyServerError, (state, action) => {
+ state.isProcessing = false;
+ state.isCancelable = true;
+ // state.currentIteration = 0;
+ // state.totalIterations = 0;
+ state.currentStatusHasSteps = false;
+ state.currentStep = 0;
+ state.totalSteps = 0;
+ state.statusTranslationKey = 'common.statusError';
+ state.progressImage = null;
+
+ state.toastQueue.push(
+ makeToast({
+ title: t('toast.serverError'),
+ status: 'error',
+ description: startCase(action.payload.data.error_type),
+ })
);
});
},
@@ -400,3 +412,9 @@ export const {
} = systemSlice.actions;
export default systemSlice.reducer;
+
+const isAnyServerError = isAnyOf(
+ appSocketInvocationError,
+ appSocketSessionRetrievalError,
+ appSocketInvocationRetrievalError
+);
diff --git a/invokeai/frontend/web/src/services/api/thunks/session.ts b/invokeai/frontend/web/src/services/api/thunks/session.ts
index 6d20b9dd33..5588f25b46 100644
--- a/invokeai/frontend/web/src/services/api/thunks/session.ts
+++ b/invokeai/frontend/web/src/services/api/thunks/session.ts
@@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required<
>;
type CreateSessionThunkConfig = {
- rejectValue: { arg: CreateSessionArg; error: unknown };
+ rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
};
/**
@@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk<
});
if (error) {
- return rejectWithValue({ arg, error });
+ return rejectWithValue({ arg, status: response.status, error });
}
return data;
@@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = {
rejectValue: {
arg: InvokedSessionArg;
error: unknown;
+ status: number;
};
};
@@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk<
if (error) {
if (isErrorWithStatus(error) && error.status === 403) {
- return rejectWithValue({ arg, error: (error as any).body.detail });
+ return rejectWithValue({
+ arg,
+ status: response.status,
+ error: (error as any).body.detail,
+ });
}
- return rejectWithValue({ arg, error });
+ return rejectWithValue({ arg, status: response.status, error });
}
});
diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts
index b6316c5e95..35ebb725cb 100644
--- a/invokeai/frontend/web/src/services/events/actions.ts
+++ b/invokeai/frontend/web/src/services/events/actions.ts
@@ -4,9 +4,11 @@ import {
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
+ InvocationRetrievalErrorEvent,
InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
+ SessionRetrievalErrorEvent,
} from 'services/events/types';
// Create actions for each socket
@@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{
export const appSocketModelLoadCompleted = createAction<{
data: ModelLoadCompletedEvent;
}>('socket/appSocketModelLoadCompleted');
+
+/**
+ * Socket.IO Session Retrieval Error
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketSessionRetrievalError = createAction<{
+ data: SessionRetrievalErrorEvent;
+}>('socket/socketSessionRetrievalError');
+
+/**
+ * App-level Session Retrieval Error
+ */
+export const appSocketSessionRetrievalError = createAction<{
+ data: SessionRetrievalErrorEvent;
+}>('socket/appSocketSessionRetrievalError');
+
+/**
+ * Socket.IO Invocation Retrieval Error
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketInvocationRetrievalError = createAction<{
+ data: InvocationRetrievalErrorEvent;
+}>('socket/socketInvocationRetrievalError');
+
+/**
+ * App-level Invocation Retrieval Error
+ */
+export const appSocketInvocationRetrievalError = createAction<{
+ data: InvocationRetrievalErrorEvent;
+}>('socket/appSocketInvocationRetrievalError');
diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts
index ec1b55e3fe..37f5f24eac 100644
--- a/invokeai/frontend/web/src/services/events/types.ts
+++ b/invokeai/frontend/web/src/services/events/types.ts
@@ -87,6 +87,7 @@ export type InvocationErrorEvent = {
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
+ error_type: string;
error: string;
};
@@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = {
graph_execution_state_id: string;
};
+/**
+ * A `session_retrieval_error` socket.io event.
+ *
+ * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
+ */
+export type SessionRetrievalErrorEvent = {
+ graph_execution_state_id: string;
+ error_type: string;
+ error: string;
+};
+
+/**
+ * An `invocation_retrieval_error` socket.io event.
+ *
+ * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
+ */
+export type InvocationRetrievalErrorEvent = {
+ graph_execution_state_id: string;
+ node_id: string;
+ error_type: string;
+ error: string;
+};
+
export type ClientEmitSubscribe = {
session: string;
};
@@ -128,6 +152,8 @@ export type ServerToClientEvents = {
) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
+ session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
+ invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
};
export type ClientToServerEvents = {
diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
index d44a549183..9ebb7ffbff 100644
--- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
+++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
@@ -11,9 +11,11 @@ import {
socketGraphExecutionStateComplete,
socketInvocationComplete,
socketInvocationError,
+ socketInvocationRetrievalError,
socketInvocationStarted,
socketModelLoadCompleted,
socketModelLoadStarted,
+ socketSessionRetrievalError,
socketSubscribed,
} from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types';
@@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
})
);
});
+
+ /**
+ * Session retrieval error
+ */
+ socket.on('session_retrieval_error', (data) => {
+ dispatch(
+ socketSessionRetrievalError({
+ data,
+ })
+ );
+ });
+
+ /**
+ * Invocation retrieval error
+ */
+ socket.on('invocation_retrieval_error', (data) => {
+ dispatch(
+ socketInvocationRetrievalError({
+ data,
+ })
+ );
+ });
};
diff --git a/mkdocs.yml b/mkdocs.yml
index 7d3e0e0b85..cbcaf52af6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -101,7 +101,7 @@ plugins:
nav:
- Home: 'index.md'
- - Installation:
+ - Installation:
- Overview: 'installation/index.md'
- Installing with the Automated Installer: 'installation/010_INSTALL_AUTOMATED.md'
- Installing manually: 'installation/020_INSTALL_MANUAL.md'
@@ -122,14 +122,14 @@ nav:
- Community Nodes:
- Community Nodes: 'nodes/communityNodes.md'
- Overview: 'nodes/overview.md'
- - Features:
+ - Features:
- Overview: 'features/index.md'
- Concepts: 'features/CONCEPTS.md'
- Configuration: 'features/CONFIGURATION.md'
- ControlNet: 'features/CONTROLNET.md'
- Image-to-Image: 'features/IMG2IMG.md'
- Controlling Logging: 'features/LOGGING.md'
- - Model Mergeing: 'features/MODEL_MERGING.md'
+ - Model Merging: 'features/MODEL_MERGING.md'
- Nodes Editor (Experimental): 'features/NODES.md'
- NSFW Checker: 'features/NSFW.md'
- Postprocessing: 'features/POSTPROCESS.md'
@@ -140,9 +140,9 @@ nav:
- InvokeAI Web Server: 'features/WEB.md'
- WebUI Hotkeys: "features/WEBUIHOTKEYS.md"
- Other: 'features/OTHER.md'
- - Contributing:
+ - Contributing:
- How to Contribute: 'contributing/CONTRIBUTING.md'
- - Development:
+ - Development:
- Overview: 'contributing/contribution_guides/development.md'
- InvokeAI Architecture: 'contributing/ARCHITECTURE.md'
- Frontend Documentation: 'contributing/contribution_guides/development_guides/contributingToFrontend.md'
@@ -161,5 +161,3 @@ nav:
- Other:
- Contributors: 'other/CONTRIBUTORS.md'
- CompViz-README: 'other/README-CompViz.md'
-
-