From 4b334be7d030d369bbac2d0dc4d30fb27954fa2b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 23 Jul 2023 12:27:59 +1000 Subject: [PATCH 1/7] feat(nodes,ui): fix soft locks on session/invocation retrieval When a queue item is popped for processing, we need to retrieve its session from the DB. Pydantic serializes the graph at this stage. It's possible for a graph to have been made invalid during the graph preparation stage (e.g. an ancestor node executes, and its output is not valid for its successor node's input field). When this occurs, the session in the DB will fail validation, but we don't have a chance to find out until it is retrieved and parsed by pydantic. This logic was previously not wrapped in any exception handling. Just after retrieving a session, we retrieve the specific invocation to execute from the session. It's possible that this could also have some sort of error, though it should be impossible for it to be a pydantic validation error (that would have been caught during session validation). There was also no exception handling here. When either of these processes fail, the processor gets soft-locked because the processor's cleanup logic is never run. (I didn't dig deeper into exactly what cleanup is not happening, because the fix is to just handle the exceptions.) This PR adds exception handling to both the session retrieval and node retrieval and events for each: `session_retrieval_error` and `invocation_retrieval_error`. These events are caught and displayed in the UI as toasts, along with the type of the python exception (e.g. `Validation Error`). The events are also logged to the browser console. --- invokeai/app/services/events.py | 76 +++++++++++++++---- invokeai/app/services/processor.py | 41 +++++++--- .../middleware/listenerMiddleware/index.ts | 4 + .../listeners/sessionCreated.ts | 7 +- .../listeners/sessionInvoked.ts | 3 +- .../socketInvocationRetrievalError.ts | 20 +++++ .../socketio/socketSessionRetrievalError.ts | 20 +++++ .../src/features/system/store/systemSlice.ts | 62 +++++++++------ .../web/src/services/api/thunks/session.ts | 13 +++- .../web/src/services/events/actions.ts | 34 +++++++++ .../frontend/web/src/services/events/types.ts | 26 +++++++ .../services/events/util/setEventListeners.ts | 24 ++++++ 12 files changed, 273 insertions(+), 57 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 35003536e6..73d74de2d9 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,7 +3,13 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp -from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo +from invokeai.app.services.model_manager_service import ( + BaseModelType, + ModelType, + SubModelType, + ModelInfo, +) + class EventServiceBase: session_event: str = "session_event" @@ -38,7 +44,9 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, - progress_image=progress_image.dict() if progress_image is not None else None, + progress_image=progress_image.dict() + if progress_image is not None + else None, step=step, total_steps=total_steps, ), @@ -67,6 +75,7 @@ class EventServiceBase: graph_execution_state_id: str, node: dict, source_node_id: str, + error_type: str, error: str, ) -> None: """Emitted when an invocation has completed""" @@ -76,6 +85,7 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, + error_type=error_type, error=error, ), ) @@ -102,13 +112,13 @@ class EventServiceBase: ), ) - def emit_model_load_started ( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + def emit_model_load_started( + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, ) -> None: """Emitted when a model is requested""" self.__emit_session_event( @@ -123,13 +133,13 @@ class EventServiceBase: ) def emit_model_load_completed( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, + model_info: ModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_session_event( @@ -145,3 +155,37 @@ class EventServiceBase: precision=str(model_info.precision), ), ) + + def emit_session_retrieval_error( + self, + graph_execution_state_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when session retrieval fails""" + self.__emit_session_event( + event_name="session_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + error_type=error_type, + error=error, + ), + ) + + def emit_invocation_retrieval_error( + self, + graph_execution_state_id: str, + node_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when invocation retrieval fails""" + self.__emit_session_event( + event_name="invocation_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + node_id=node_id, + error_type=error_type, + error=error, + ), + ) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index e11eb84b3d..5995e4ffc3 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() except Exception as e: - logger.debug("Exception while getting from queue: %s" % e) + self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e) if not queue_item: # Probably stopping # do not hammer the queue time.sleep(0.5) continue - graph_execution_state = ( - self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id + try: + graph_execution_state = ( + self.__invoker.services.graph_execution_manager.get( + queue_item.graph_execution_state_id + ) ) - ) - invocation = graph_execution_state.execution_graph.get_node( - queue_item.invocation_id - ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) + self.__invoker.services.events.emit_session_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue + + try: + invocation = graph_execution_state.execution_graph.get_node( + queue_item.invocation_id + ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) + self.__invoker.services.events.emit_invocation_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + node_id=queue_item.invocation_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue # get the source node id to provide to clients (the prepared node id is not as useful) source_node_id = graph_execution_state.prepared_source_mapping[invocation.id] @@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state ) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) # Send error event self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=error, ) @@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): try: self.__invoker.invoke(graph_execution_state, invoke_all=True) except Exception as e: - logger.error("Error while invoking: %s" % e) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=traceback.format_exc() ) elif is_complete: diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 04f0ce7a0b..5adc4f5e5e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; +import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; +import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; export const listenerMiddleware = createListenerMiddleware(); @@ -153,6 +155,8 @@ addSocketDisconnectedListener(); addSocketSubscribedListener(); addSocketUnsubscribedListener(); addModelLoadEventListener(); +addSessionRetrievalErrorEventListener(); +addInvocationRetrievalErrorEventListener(); // Session Created addSessionCreatedPendingListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts index 5709d87d22..e89acb7542 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts @@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => { effect: (action) => { const log = logger('session'); if (action.payload) { - const { error } = action.payload; + const { error, status } = action.payload; const graph = parseify(action.meta.arg); - const stringifiedError = JSON.stringify(error); log.error( - { graph, error: serializeError(error) }, - `Problem creating session: ${stringifiedError}` + { graph, status, error: serializeError(error) }, + `Problem creating session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts index 60009ed194..a62f75d957 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts @@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => { const { session_id } = action.meta.arg; if (action.payload) { const { error } = action.payload; - const stringifiedError = JSON.stringify(error); log.error( { session_id, error: serializeError(error), }, - `Problem invoking session: ${stringifiedError}` + `Problem invoking session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts new file mode 100644 index 0000000000..aa88457eb7 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketInvocationRetrievalError, + socketInvocationRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addInvocationRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketInvocationRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Invocation retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketInvocationRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts new file mode 100644 index 0000000000..7efb7f463a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketSessionRetrievalError, + socketSessionRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addSessionRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketSessionRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Session retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketSessionRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 629a4f0139..b7a5e606e2 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,5 +1,5 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { InvokeLogLevel } from 'app/logging/logger'; import { userInvoked } from 'app/store/actions'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; @@ -16,13 +16,16 @@ import { appSocketGraphExecutionStateComplete, appSocketInvocationComplete, appSocketInvocationError, + appSocketInvocationRetrievalError, appSocketInvocationStarted, + appSocketSessionRetrievalError, appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { makeToast } from '../util/makeToast'; import { LANGUAGES } from './constants'; +import { startCase } from 'lodash-es'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -288,25 +291,6 @@ export const systemSlice = createSlice({ } }); - /** - * Invocation Error - */ - builder.addCase(appSocketInvocationError, (state) => { - state.isProcessing = false; - state.isCancelable = true; - // state.currentIteration = 0; - // state.totalIterations = 0; - state.currentStatusHasSteps = false; - state.currentStep = 0; - state.totalSteps = 0; - state.statusTranslationKey = 'common.statusError'; - state.progressImage = null; - - state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) - ); - }); - /** * Graph Execution State Complete */ @@ -362,7 +346,7 @@ export const systemSlice = createSlice({ * Session Invoked - REJECTED * Session Created - REJECTED */ - builder.addMatcher(isAnySessionRejected, (state) => { + builder.addMatcher(isAnySessionRejected, (state, action) => { state.isProcessing = false; state.isCancelable = false; state.isCancelScheduled = false; @@ -372,7 +356,35 @@ export const systemSlice = createSlice({ state.progressImage = null; state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: + action.payload?.status === 422 ? 'Validation Error' : undefined, + }) + ); + }); + + /** + * Any server error + */ + builder.addMatcher(isAnyServerError, (state, action) => { + state.isProcessing = false; + state.isCancelable = true; + // state.currentIteration = 0; + // state.totalIterations = 0; + state.currentStatusHasSteps = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusError'; + state.progressImage = null; + + state.toastQueue.push( + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: startCase(action.payload.data.error_type), + }) ); }); }, @@ -400,3 +412,9 @@ export const { } = systemSlice.actions; export default systemSlice.reducer; + +const isAnyServerError = isAnyOf( + appSocketInvocationError, + appSocketSessionRetrievalError, + appSocketInvocationRetrievalError +); diff --git a/invokeai/frontend/web/src/services/api/thunks/session.ts b/invokeai/frontend/web/src/services/api/thunks/session.ts index 6d20b9dd33..5588f25b46 100644 --- a/invokeai/frontend/web/src/services/api/thunks/session.ts +++ b/invokeai/frontend/web/src/services/api/thunks/session.ts @@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required< >; type CreateSessionThunkConfig = { - rejectValue: { arg: CreateSessionArg; error: unknown }; + rejectValue: { arg: CreateSessionArg; status: number; error: unknown }; }; /** @@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk< }); if (error) { - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } return data; @@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = { rejectValue: { arg: InvokedSessionArg; error: unknown; + status: number; }; }; @@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk< if (error) { if (isErrorWithStatus(error) && error.status === 403) { - return rejectWithValue({ arg, error: (error as any).body.detail }); + return rejectWithValue({ + arg, + status: response.status, + error: (error as any).body.detail, + }); } - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } }); diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index b6316c5e95..35ebb725cb 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -4,9 +4,11 @@ import { GraphExecutionStateCompleteEvent, InvocationCompleteEvent, InvocationErrorEvent, + InvocationRetrievalErrorEvent, InvocationStartedEvent, ModelLoadCompletedEvent, ModelLoadStartedEvent, + SessionRetrievalErrorEvent, } from 'services/events/types'; // Create actions for each socket @@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{ export const appSocketModelLoadCompleted = createAction<{ data: ModelLoadCompletedEvent; }>('socket/appSocketModelLoadCompleted'); + +/** + * Socket.IO Session Retrieval Error + * + * Do not use. Only for use in middleware. + */ +export const socketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/socketSessionRetrievalError'); + +/** + * App-level Session Retrieval Error + */ +export const appSocketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/appSocketSessionRetrievalError'); + +/** + * Socket.IO Invocation Retrieval Error + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/socketInvocationRetrievalError'); + +/** + * App-level Invocation Retrieval Error + */ +export const appSocketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/appSocketInvocationRetrievalError'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index ec1b55e3fe..37f5f24eac 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -87,6 +87,7 @@ export type InvocationErrorEvent = { graph_execution_state_id: string; node: BaseNode; source_node_id: string; + error_type: string; error: string; }; @@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = { graph_execution_state_id: string; }; +/** + * A `session_retrieval_error` socket.io event. + * + * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... } + */ +export type SessionRetrievalErrorEvent = { + graph_execution_state_id: string; + error_type: string; + error: string; +}; + +/** + * A `invocation_retrieval_error` socket.io event. + * + * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... } + */ +export type InvocationRetrievalErrorEvent = { + graph_execution_state_id: string; + node_id: string; + error_type: string; + error: string; +}; + export type ClientEmitSubscribe = { session: string; }; @@ -128,6 +152,8 @@ export type ServerToClientEvents = { ) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; model_load_completed: (payload: ModelLoadCompletedEvent) => void; + session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; + invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void; }; export type ClientToServerEvents = { diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index d44a549183..9ebb7ffbff 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -11,9 +11,11 @@ import { socketGraphExecutionStateComplete, socketInvocationComplete, socketInvocationError, + socketInvocationRetrievalError, socketInvocationStarted, socketModelLoadCompleted, socketModelLoadStarted, + socketSessionRetrievalError, socketSubscribed, } from '../actions'; import { ClientToServerEvents, ServerToClientEvents } from '../types'; @@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => { }) ); }); + + /** + * Session retrieval error + */ + socket.on('session_retrieval_error', (data) => { + dispatch( + socketSessionRetrievalError({ + data, + }) + ); + }); + + /** + * Invocation retrieval error + */ + socket.on('invocation_retrieval_error', (data) => { + dispatch( + socketInvocationRetrievalError({ + data, + }) + ); + }); }; From 28031ead708fadf51ba38a9b162c2b0a72f851ee Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 23 Jul 2023 22:34:31 +1000 Subject: [PATCH 2/7] feat(ui): display canvas generation mode in status text - use the existing logic to determine if generation is txt2img, img2img, inpaint or outpaint - technically `outpaint` and `inpaint` are the same, just display "Inpaint" if its either - debounce this by 1s to prevent jank --- .../listeners/userInvokedCanvas.ts | 2 +- .../canvas/components/IAICanvasStatusText.tsx | 3 +- .../IAICanvasToolbar/IAICanvasToolbar.tsx | 21 ++++--- .../src/features/canvas/store/canvasSlice.ts | 5 ++ .../src/features/canvas/store/canvasTypes.ts | 3 + .../src/features/canvas/util/getCanvasData.ts | 7 +-- .../canvas/util/getCanvasGenerationMode.ts | 3 +- .../Canvas/GenerationModeStatusText.tsx | 55 +++++++++++++++++++ 8 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 2ef62aed7b..17b2eeed46 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -40,7 +40,7 @@ export const addUserInvokedCanvasListener = () => { const state = getState(); // Build canvas blobs - const canvasBlobsAndImageData = await getCanvasData(state); + const canvasBlobsAndImageData = await getCanvasData(state.canvas); if (!canvasBlobsAndImageData) { log.error('Unable to create canvas data'); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx index 69bf628a39..8c1dfbb86f 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx @@ -2,8 +2,8 @@ import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { canvasSelector } from 'features/canvas/store/canvasSelectors'; +import GenerationModeStatusText from 'features/parameters/components/Parameters/Canvas/GenerationModeStatusText'; import { isEqual } from 'lodash-es'; - import { useTranslation } from 'react-i18next'; import roundToHundreth from '../util/roundToHundreth'; import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos'; @@ -110,6 +110,7 @@ const IAICanvasStatusText = () => { }, }} > + { }} > - ) => { + state.generationMode = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(sessionCanceled.pending, (state) => { @@ -955,6 +959,7 @@ export const { stagingAreaInitialized, canvasSessionIdChanged, setShouldAntialias, + generationModeChanged, } = canvasSlice.actions; export default canvasSlice.reducer; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts index 48d59395ab..ba85a7e132 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts @@ -168,4 +168,7 @@ export interface CanvasState { stageDimensions: Dimensions; stageScale: number; tool: CanvasTool; + generationMode?: GenerationMode; } + +export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts index d37ee7b8d0..855420f78a 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts @@ -1,6 +1,5 @@ import { logger } from 'app/logging/logger'; -import { RootState } from 'app/store/store'; -import { isCanvasMaskLine } from '../store/canvasTypes'; +import { CanvasState, isCanvasMaskLine } from '../store/canvasTypes'; import createMaskStage from './createMaskStage'; import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider'; import { konvaNodeToBlob } from './konvaNodeToBlob'; @@ -9,7 +8,7 @@ import { konvaNodeToImageData } from './konvaNodeToImageData'; /** * Gets Blob and ImageData objects for the base and mask layers */ -export const getCanvasData = async (state: RootState) => { +export const getCanvasData = async (canvasState: CanvasState) => { const log = logger('canvas'); const canvasBaseLayer = getCanvasBaseLayer(); @@ -26,7 +25,7 @@ export const getCanvasData = async (state: RootState) => { boundingBoxDimensions, isMaskEnabled, shouldPreserveMaskedArea, - } = state.canvas; + } = canvasState; const boundingBox = { ...boundingBoxCoordinates, diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts index 5b38ecf938..d3e8792690 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts @@ -2,11 +2,12 @@ import { areAnyPixelsBlack, getImageDataTransparency, } from 'common/util/arrayBuffer'; +import { GenerationMode } from '../store/canvasTypes'; export const getCanvasGenerationMode = ( baseImageData: ImageData, maskImageData: ImageData -) => { +): GenerationMode => { const { isPartiallyTransparent: baseIsPartiallyTransparent, isFullyTransparent: baseIsFullyTransparent, diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx new file mode 100644 index 0000000000..5c6bbd0ba3 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx @@ -0,0 +1,55 @@ +import { Box } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { generationModeChanged } from 'features/canvas/store/canvasSlice'; +import { getCanvasData } from 'features/canvas/util/getCanvasData'; +import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; +import { useDebounce } from 'react-use'; + +const GENERATION_MODE_NAME_MAP = { + txt2img: 'Text to Image', + img2img: 'Image to Image', + inpaint: 'Inpaint', + outpaint: 'Inpaint', +}; + +export const useGenerationMode = () => { + const dispatch = useAppDispatch(); + const canvasState = useAppSelector((state) => state.canvas); + + useDebounce( + async () => { + // Build canvas blobs + const canvasBlobsAndImageData = await getCanvasData(canvasState); + + if (!canvasBlobsAndImageData) { + return; + } + + const { baseImageData, maskImageData } = canvasBlobsAndImageData; + + // Determine the generation mode + const generationMode = getCanvasGenerationMode( + baseImageData, + maskImageData + ); + + dispatch(generationModeChanged(generationMode)); + }, + 1000, + [dispatch, canvasState, generationModeChanged] + ); +}; + +const GenerationModeStatusText = () => { + const generationMode = useAppSelector((state) => state.canvas.generationMode); + + useGenerationMode(); + + return ( + + Mode: {generationMode ? GENERATION_MODE_NAME_MAP[generationMode] : '...'} + + ); +}; + +export default GenerationModeStatusText; From 07a90c019800ce485c88dabe421f3896ebf38f72 Mon Sep 17 00:00:00 2001 From: Alexandre Macabies Date: Sun, 23 Jul 2023 14:49:28 +0200 Subject: [PATCH 3/7] Fix incorrect use of a singleton list. This was found through pylance type errors. Go types! --- invokeai/app/api/routers/models.py | 2 +- invokeai/app/services/model_manager_service.py | 2 +- invokeai/backend/model_management/model_search.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 870ca33534..759f6c9f59 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -298,7 +298,7 @@ async def search_for_models( )->List[pathlib.Path]: if not search_path.is_dir(): raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") - return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) + return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) @models_router.get( "/ckpt_confs", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index b1b995309e..f7d3b3a7a7 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase): """ Return list of all models found in the designated directory. """ - search = FindModels(directory,self.logger) + search = FindModels([directory], self.logger) return search.list_models() def sync_to_config(self): diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py index 1e282b4bb8..5657bd9549 100644 --- a/invokeai/backend/model_management/model_search.py +++ b/invokeai/backend/model_management/model_search.py @@ -98,6 +98,6 @@ class FindModels(ModelSearch): def list_models(self) -> List[Path]: self.search() - return self.models_found + return list(self.models_found) From 0beec08d3822b398f855ab11ee5064ed5b479bce Mon Sep 17 00:00:00 2001 From: Alexandre Macabies Date: Sun, 23 Jul 2023 14:46:16 +0200 Subject: [PATCH 4/7] Add missing import. --- invokeai/backend/model_management/models/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 5387ade0e5..eb771841ec 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -10,6 +10,7 @@ from .base import ( SubModelType, classproperty, InvalidModelException, + ModelNotFoundException, ) # TODO: naming from ..lora import LoRAModel as LoRAModelRaw From 0cf7a10c5ccd413b001fb360ff7b0fcb6e961883 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 24 Jul 2023 18:58:24 +1200 Subject: [PATCH 5/7] fix: Other lora missing type --- invokeai/backend/model_management/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index b0481f3cfa..222169afbb 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -474,7 +474,7 @@ class ModelPatcher: @staticmethod def _lora_forward_hook( - applied_loras: List[Tuple[LoraModel, float]], + applied_loras: List[Tuple[LoRAModel, float]], layer_name: str, ): @@ -519,7 +519,7 @@ class ModelPatcher: def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoraModel, float]], + loras: List[Tuple[LoRAModel, float]], prefix: str, ): original_weights = dict() From 61fa960a18e2715c6b7cc4d730f500ad25545395 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 24 Jul 2023 18:16:15 +1000 Subject: [PATCH 6/7] feat(ui): make generation mode calculation more granular --- .../listeners/userInvokedCanvas.ts | 16 ++++- .../canvas/hooks/useCanvasGenerationMode.ts | 72 +++++++++++++++++++ .../src/features/canvas/store/canvasSlice.ts | 5 -- .../src/features/canvas/util/getCanvasData.ts | 25 ++++--- .../Canvas/GenerationModeStatusText.tsx | 38 +--------- 5 files changed, 103 insertions(+), 53 deletions(-) create mode 100644 invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 17b2eeed46..39bd742d7d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => { const state = getState(); + const { + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea, + } = state.canvas; + // Build canvas blobs - const canvasBlobsAndImageData = await getCanvasData(state.canvas); + const canvasBlobsAndImageData = await getCanvasData( + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea + ); if (!canvasBlobsAndImageData) { log.error('Unable to create canvas data'); diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts new file mode 100644 index 0000000000..55b04efca4 --- /dev/null +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasGenerationMode.ts @@ -0,0 +1,72 @@ +import { useAppSelector } from 'app/store/storeHooks'; +import { GenerationMode } from 'features/canvas/store/canvasTypes'; +import { getCanvasData } from 'features/canvas/util/getCanvasData'; +import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; +import { useEffect, useState } from 'react'; +import { useDebounce } from 'react-use'; + +export const useCanvasGenerationMode = () => { + const layerState = useAppSelector((state) => state.canvas.layerState); + + const boundingBoxCoordinates = useAppSelector( + (state) => state.canvas.boundingBoxCoordinates + ); + const boundingBoxDimensions = useAppSelector( + (state) => state.canvas.boundingBoxDimensions + ); + const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled); + + const shouldPreserveMaskedArea = useAppSelector( + (state) => state.canvas.shouldPreserveMaskedArea + ); + const [generationMode, setGenerationMode] = useState< + GenerationMode | undefined + >(); + + useEffect(() => { + setGenerationMode(undefined); + }, [ + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea, + ]); + + useDebounce( + async () => { + // Build canvas blobs + const canvasBlobsAndImageData = await getCanvasData( + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea + ); + + if (!canvasBlobsAndImageData) { + return; + } + + const { baseImageData, maskImageData } = canvasBlobsAndImageData; + + // Determine the generation mode + const generationMode = getCanvasGenerationMode( + baseImageData, + maskImageData + ); + + setGenerationMode(generationMode); + }, + 1000, + [ + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea, + ] + ); + + return generationMode; +}; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index dc91c1c769..3163e513e9 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -30,7 +30,6 @@ import { CanvasState, CanvasTool, Dimensions, - GenerationMode, isCanvasAnyLine, isCanvasBaseImage, isCanvasMaskLine, @@ -859,9 +858,6 @@ export const canvasSlice = createSlice({ state.isMovingBoundingBox = false; state.isTransformingBoundingBox = false; }, - generationModeChanged: (state, action: PayloadAction) => { - state.generationMode = action.payload; - }, }, extraReducers: (builder) => { builder.addCase(sessionCanceled.pending, (state) => { @@ -959,7 +955,6 @@ export const { stagingAreaInitialized, canvasSessionIdChanged, setShouldAntialias, - generationModeChanged, } = canvasSlice.actions; export default canvasSlice.reducer; diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts index 855420f78a..4e575791ed 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts @@ -1,5 +1,10 @@ import { logger } from 'app/logging/logger'; -import { CanvasState, isCanvasMaskLine } from '../store/canvasTypes'; +import { Vector2d } from 'konva/lib/types'; +import { + CanvasLayerState, + Dimensions, + isCanvasMaskLine, +} from '../store/canvasTypes'; import createMaskStage from './createMaskStage'; import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider'; import { konvaNodeToBlob } from './konvaNodeToBlob'; @@ -8,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData'; /** * Gets Blob and ImageData objects for the base and mask layers */ -export const getCanvasData = async (canvasState: CanvasState) => { +export const getCanvasData = async ( + layerState: CanvasLayerState, + boundingBoxCoordinates: Vector2d, + boundingBoxDimensions: Dimensions, + isMaskEnabled: boolean, + shouldPreserveMaskedArea: boolean +) => { const log = logger('canvas'); const canvasBaseLayer = getCanvasBaseLayer(); @@ -19,14 +30,6 @@ export const getCanvasData = async (canvasState: CanvasState) => { return; } - const { - layerState: { objects }, - boundingBoxCoordinates, - boundingBoxDimensions, - isMaskEnabled, - shouldPreserveMaskedArea, - } = canvasState; - const boundingBox = { ...boundingBoxCoordinates, ...boundingBoxDimensions, @@ -57,7 +60,7 @@ export const getCanvasData = async (canvasState: CanvasState) => { // For the mask layer, use the normal boundingBox const maskStage = await createMaskStage( - isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled + isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled boundingBox, shouldPreserveMaskedArea ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx index 5c6bbd0ba3..511e90f0f3 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx @@ -1,9 +1,5 @@ import { Box } from '@chakra-ui/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { generationModeChanged } from 'features/canvas/store/canvasSlice'; -import { getCanvasData } from 'features/canvas/util/getCanvasData'; -import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; -import { useDebounce } from 'react-use'; +import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode'; const GENERATION_MODE_NAME_MAP = { txt2img: 'Text to Image', @@ -12,38 +8,8 @@ const GENERATION_MODE_NAME_MAP = { outpaint: 'Inpaint', }; -export const useGenerationMode = () => { - const dispatch = useAppDispatch(); - const canvasState = useAppSelector((state) => state.canvas); - - useDebounce( - async () => { - // Build canvas blobs - const canvasBlobsAndImageData = await getCanvasData(canvasState); - - if (!canvasBlobsAndImageData) { - return; - } - - const { baseImageData, maskImageData } = canvasBlobsAndImageData; - - // Determine the generation mode - const generationMode = getCanvasGenerationMode( - baseImageData, - maskImageData - ); - - dispatch(generationModeChanged(generationMode)); - }, - 1000, - [dispatch, canvasState, generationModeChanged] - ); -}; - const GenerationModeStatusText = () => { - const generationMode = useAppSelector((state) => state.canvas.generationMode); - - useGenerationMode(); + const generationMode = useCanvasGenerationMode(); return ( From 437532f2f9a3208a4b43d486cb00e204d0e3d4d2 Mon Sep 17 00:00:00 2001 From: Josh Corbett Date: Mon, 24 Jul 2023 17:42:01 -0600 Subject: [PATCH 7/7] fix: :pencil2: fix docs generation typo and remove trailing white space --- mkdocs.yml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 7d3e0e0b85..cbcaf52af6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,7 +101,7 @@ plugins: nav: - Home: 'index.md' - - Installation: + - Installation: - Overview: 'installation/index.md' - Installing with the Automated Installer: 'installation/010_INSTALL_AUTOMATED.md' - Installing manually: 'installation/020_INSTALL_MANUAL.md' @@ -122,14 +122,14 @@ nav: - Community Nodes: - Community Nodes: 'nodes/communityNodes.md' - Overview: 'nodes/overview.md' - - Features: + - Features: - Overview: 'features/index.md' - Concepts: 'features/CONCEPTS.md' - Configuration: 'features/CONFIGURATION.md' - ControlNet: 'features/CONTROLNET.md' - Image-to-Image: 'features/IMG2IMG.md' - Controlling Logging: 'features/LOGGING.md' - - Model Mergeing: 'features/MODEL_MERGING.md' + - Model Merging: 'features/MODEL_MERGING.md' - Nodes Editor (Experimental): 'features/NODES.md' - NSFW Checker: 'features/NSFW.md' - Postprocessing: 'features/POSTPROCESS.md' @@ -140,9 +140,9 @@ nav: - InvokeAI Web Server: 'features/WEB.md' - WebUI Hotkeys: "features/WEBUIHOTKEYS.md" - Other: 'features/OTHER.md' - - Contributing: + - Contributing: - How to Contribute: 'contributing/CONTRIBUTING.md' - - Development: + - Development: - Overview: 'contributing/contribution_guides/development.md' - InvokeAI Architecture: 'contributing/ARCHITECTURE.md' - Frontend Documentation: 'contributing/contribution_guides/development_guides/contributingToFrontend.md' @@ -161,5 +161,3 @@ nav: - Other: - Contributors: 'other/CONTRIBUTORS.md' - CompViz-README: 'other/README-CompViz.md' - -