feat(nodes,ui): fix soft locks on session/invocation retrieval (#3910)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No, n/a


## Description

When a queue item is popped for processing, we need to retrieve its
session from the DB. Pydantic serializes the graph at this stage.

It's possible for a graph to have been made invalid during the graph
preparation stage (e.g. an ancestor node executes, and its output is not
valid for its successor node's input field).

When this occurs, the session in the DB will fail validation, but we
don't have a chance to find out until it is retrieved and parsed by
pydantic.

This logic was previously not wrapped in any exception handling.

Just after retrieving a session, we retrieve the specific invocation to
execute from the session. It's possible that this could also have some
sort of error, though it should be impossible for it to be a pydantic
validation error (that would have been caught during session
validation). There was also no exception handling here.

When either of these processes fail, the processor gets soft-locked
because the processor's cleanup logic is never run. (I didn't dig deeper
into exactly what cleanup is not happening, because the fix is to just
handle the exceptions.)

This PR adds exception handling to both the session retrieval and node
retrieval and events for each: `session_retrieval_error` and
`invocation_retrieval_error`.

These events are caught and displayed in the UI as toasts, along with
the type of the python exception (e.g. `Validation Error`). The events
are also logged to the browser console.


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

Closes #3860 , #3412

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

Create an valid graph that will become invalid during execution. Here's
an example:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/50aa824c-fb0c-4bd9-82f4-38a4c89436f9)

This is valid before execution, but the `width` field of the `Noise`
node will end up with an invalid value (`0`). Previously, this would
soft-lock the app and you'd have to restart it.

Now, with this graph, you will get an error toast, and the app will not
get locked up.

## Added/updated tests?

- [x] Yes (ish)
- [ ] No

@Kyle0654  @brandonrising 
It seems because the processor runs in its own thread, `pytest` cannot
catch exceptions raised in the processor.

I added a test that does work, insofar as it does recreate the issue.
But, because the exception occurs in a separate thread, the test doesn't
see it. The result is that the test passes even without the fix.

So when running the test, we see the exception:
```py
Exception in thread invoker_processor:
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/bat/Documents/Code/InvokeAI/invokeai/app/services/processor.py", line 50, in __process
    self.__invoker.services.graph_execution_manager.get(
  File "/home/bat/Documents/Code/InvokeAI/invokeai/app/services/sqlite.py", line 79, in get
    return self._parse_item(result[0])

  File "/home/bat/Documents/Code/InvokeAI/invokeai/app/services/sqlite.py", line 52, in _parse_item
    return parse_raw_as(item_type, item)
  File "pydantic/tools.py", line 82, in pydantic.tools.parse_raw_as
  File "pydantic/tools.py", line 38, in pydantic.tools.parse_obj_as
  File "pydantic/main.py", line 341, in pydantic.main.BaseModel.__init__
```

But `pytest` doesn't actually see it as an exception. Not sure how to
fix this, it's a bit beyond me.

## [optional] Are there any post deployment tasks we need to perform?

nope don't think so
This commit is contained in:
blessedcoolant 2023-07-24 20:17:39 +12:00 committed by GitHub
commit d42c394ab7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 273 additions and 57 deletions

View File

@ -3,7 +3,13 @@
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.services.model_manager_service import (
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
)
class EventServiceBase:
session_event: str = "session_event"
@ -38,7 +44,9 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
progress_image=progress_image.dict()
if progress_image is not None
else None,
step=step,
total_steps=total_steps,
),
@ -67,6 +75,7 @@ class EventServiceBase:
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when an invocation has completed"""
@ -76,6 +85,7 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
error_type=error_type,
error=error,
),
)
@ -145,3 +155,37 @@ class EventServiceBase:
precision=str(model_info.precision),
),
)
def emit_session_retrieval_error(
self,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_session_event(
event_name="session_retrieval_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
),
)
def emit_invocation_retrieval_error(
self,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_session_event(
event_name="invocation_retrieval_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
)

View File

@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
logger.debug("Exception while getting from queue: %s" % e)
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
try:
graph_execution_state = (
self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
)
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
try:
invocation = graph_execution_state.execution_graph.get_node(
queue_item.invocation_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state
)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=error,
)
@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
logger.error("Error while invoking: %s" % e)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=traceback.format_exc()
)
elif is_complete:

View File

@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
export const listenerMiddleware = createListenerMiddleware();
@ -153,6 +155,8 @@ addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
addModelLoadEventListener();
addSessionRetrievalErrorEventListener();
addInvocationRetrievalErrorEventListener();
// Session Created
addSessionCreatedPendingListener();

View File

@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
effect: (action) => {
const log = logger('session');
if (action.payload) {
const { error } = action.payload;
const { error, status } = action.payload;
const graph = parseify(action.meta.arg);
const stringifiedError = JSON.stringify(error);
log.error(
{ graph, error: serializeError(error) },
`Problem creating session: ${stringifiedError}`
{ graph, status, error: serializeError(error) },
`Problem creating session`
);
}
},

View File

@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
const stringifiedError = JSON.stringify(error);
log.error(
{
session_id,
error: serializeError(error),
},
`Problem invoking session: ${stringifiedError}`
`Problem invoking session`
);
}
},

View File

@ -0,0 +1,20 @@
import { logger } from 'app/logging/logger';
import {
appSocketInvocationRetrievalError,
socketInvocationRetrievalError,
} from 'services/events/actions';
import { startAppListening } from '../..';
export const addInvocationRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.error(
action.payload,
`Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketInvocationRetrievalError(action.payload));
},
});
};

View File

@ -0,0 +1,20 @@
import { logger } from 'app/logging/logger';
import {
appSocketSessionRetrievalError,
socketSessionRetrievalError,
} from 'services/events/actions';
import { startAppListening } from '../..';
export const addSessionRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.error(
action.payload,
`Session retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketSessionRetrievalError(action.payload));
},
});
};

View File

@ -1,5 +1,5 @@
import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
@ -16,13 +16,16 @@ import {
appSocketGraphExecutionStateComplete,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationRetrievalError,
appSocketInvocationStarted,
appSocketSessionRetrievalError,
appSocketSubscribed,
appSocketUnsubscribed,
} from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../util/makeToast';
import { LANGUAGES } from './constants';
import { startCase } from 'lodash-es';
export type CancelStrategy = 'immediate' | 'scheduled';
@ -288,25 +291,6 @@ export const systemSlice = createSlice({
}
});
/**
* Invocation Error
*/
builder.addCase(appSocketInvocationError, (state) => {
state.isProcessing = false;
state.isCancelable = true;
// state.currentIteration = 0;
// state.totalIterations = 0;
state.currentStatusHasSteps = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusError';
state.progressImage = null;
state.toastQueue.push(
makeToast({ title: t('toast.serverError'), status: 'error' })
);
});
/**
* Graph Execution State Complete
*/
@ -362,7 +346,7 @@ export const systemSlice = createSlice({
* Session Invoked - REJECTED
* Session Created - REJECTED
*/
builder.addMatcher(isAnySessionRejected, (state) => {
builder.addMatcher(isAnySessionRejected, (state, action) => {
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
@ -372,7 +356,35 @@ export const systemSlice = createSlice({
state.progressImage = null;
state.toastQueue.push(
makeToast({ title: t('toast.serverError'), status: 'error' })
makeToast({
title: t('toast.serverError'),
status: 'error',
description:
action.payload?.status === 422 ? 'Validation Error' : undefined,
})
);
});
/**
* Any server error
*/
builder.addMatcher(isAnyServerError, (state, action) => {
state.isProcessing = false;
state.isCancelable = true;
// state.currentIteration = 0;
// state.totalIterations = 0;
state.currentStatusHasSteps = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusError';
state.progressImage = null;
state.toastQueue.push(
makeToast({
title: t('toast.serverError'),
status: 'error',
description: startCase(action.payload.data.error_type),
})
);
});
},
@ -400,3 +412,9 @@ export const {
} = systemSlice.actions;
export default systemSlice.reducer;
const isAnyServerError = isAnyOf(
appSocketInvocationError,
appSocketSessionRetrievalError,
appSocketInvocationRetrievalError
);

View File

@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required<
>;
type CreateSessionThunkConfig = {
rejectValue: { arg: CreateSessionArg; error: unknown };
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
};
/**
@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk<
});
if (error) {
return rejectWithValue({ arg, error });
return rejectWithValue({ arg, status: response.status, error });
}
return data;
@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = {
rejectValue: {
arg: InvokedSessionArg;
error: unknown;
status: number;
};
};
@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk<
if (error) {
if (isErrorWithStatus(error) && error.status === 403) {
return rejectWithValue({ arg, error: (error as any).body.detail });
return rejectWithValue({
arg,
status: response.status,
error: (error as any).body.detail,
});
}
return rejectWithValue({ arg, error });
return rejectWithValue({ arg, status: response.status, error });
}
});

View File

@ -4,9 +4,11 @@ import {
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationRetrievalErrorEvent,
InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
SessionRetrievalErrorEvent,
} from 'services/events/types';
// Create actions for each socket
@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{
export const appSocketModelLoadCompleted = createAction<{
data: ModelLoadCompletedEvent;
}>('socket/appSocketModelLoadCompleted');
/**
* Socket.IO Session Retrieval Error
*
* Do not use. Only for use in middleware.
*/
export const socketSessionRetrievalError = createAction<{
data: SessionRetrievalErrorEvent;
}>('socket/socketSessionRetrievalError');
/**
* App-level Session Retrieval Error
*/
export const appSocketSessionRetrievalError = createAction<{
data: SessionRetrievalErrorEvent;
}>('socket/appSocketSessionRetrievalError');
/**
* Socket.IO Invocation Retrieval Error
*
* Do not use. Only for use in middleware.
*/
export const socketInvocationRetrievalError = createAction<{
data: InvocationRetrievalErrorEvent;
}>('socket/socketInvocationRetrievalError');
/**
* App-level Invocation Retrieval Error
*/
export const appSocketInvocationRetrievalError = createAction<{
data: InvocationRetrievalErrorEvent;
}>('socket/appSocketInvocationRetrievalError');

View File

@ -87,6 +87,7 @@ export type InvocationErrorEvent = {
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
error_type: string;
error: string;
};
@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = {
graph_execution_state_id: string;
};
/**
* A `session_retrieval_error` socket.io event.
*
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
*/
export type SessionRetrievalErrorEvent = {
graph_execution_state_id: string;
error_type: string;
error: string;
};
/**
* A `invocation_retrieval_error` socket.io event.
*
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
*/
export type InvocationRetrievalErrorEvent = {
graph_execution_state_id: string;
node_id: string;
error_type: string;
error: string;
};
export type ClientEmitSubscribe = {
session: string;
};
@ -128,6 +152,8 @@ export type ServerToClientEvents = {
) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
};
export type ClientToServerEvents = {

View File

@ -11,9 +11,11 @@ import {
socketGraphExecutionStateComplete,
socketInvocationComplete,
socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted,
socketModelLoadCompleted,
socketModelLoadStarted,
socketSessionRetrievalError,
socketSubscribed,
} from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types';
@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
})
);
});
/**
* Session retrieval error
*/
socket.on('session_retrieval_error', (data) => {
dispatch(
socketSessionRetrievalError({
data,
})
);
});
/**
* Invocation retrieval error
*/
socket.on('invocation_retrieval_error', (data) => {
dispatch(
socketInvocationRetrievalError({
data,
})
);
});
};