From fdc444ed616a27f9cd45a7c1dafd83a2583cd115 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 23 Jul 2023 15:24:04 +1200 Subject: [PATCH 01/10] fix: Fix app crashing when you upload an incorrect JSON to node editor --- invokeai/frontend/web/public/locales/en.json | 2 + .../nodes/components/ui/LoadNodesButton.tsx | 78 ++++++++++++------- 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 59cf87fbda..ab10276491 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -615,6 +615,8 @@ "initialImageNotSetDesc": "Could not load initial image", "nodesSaved": "Nodes Saved", "nodesLoaded": "Nodes Loaded", + "nodesNotValidGraph": "Not a valid InvokeAI Node Graph", + "nodesNotValidJSON": "Not a valid JSON", "nodesLoadedFailed": "Failed To Load Nodes", "nodesCleared": "Nodes Cleared" }, diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx index 706fbd8b31..2aa369bc11 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx @@ -1,14 +1,28 @@ import { FileButton } from '@mantine/core'; -import { makeToast } from 'features/system/util/makeToast'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice'; import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { FaUpload } from 'react-icons/fa'; import { useReactFlow } from 'reactflow'; +interface JsonFile { + [key: string]: unknown; +} + +function validateInvokeAIGraph(jsonFile: JsonFile): boolean { + const keys = ['nodes', 'edges', 'viewport']; + for (const key of keys) { + if (!(key in jsonFile)) { + return false; + } + } + return true; +} + const LoadNodesButton = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); @@ -22,33 +36,45 @@ const LoadNodesButton = () => { const reader = new FileReader(); reader.onload = async () => { const json = reader.result; - const retrievedNodeTree = await JSON.parse(String(json)); - if (!retrievedNodeTree) { - dispatch( - addToast( - makeToast({ - title: t('toast.nodesLoadedFailed'), - status: 'error', - }) - ) - ); + try { + const retrievedNodeTree = await JSON.parse(String(json)); + const isValidNodeTree = validateInvokeAIGraph(retrievedNodeTree); + + if (isValidNodeTree) { + dispatch(loadFileNodes(retrievedNodeTree.nodes)); + dispatch(loadFileEdges(retrievedNodeTree.edges)); + fitView(); + + dispatch( + addToast( + makeToast({ title: t('toast.nodesLoaded'), status: 'success' }) + ) + ); + } else { + dispatch( + addToast( + makeToast({ + title: t('toast.nodesNotValidGraph'), + status: 'error', + }) + ) + ); + } + // Cleanup + reader.abort(); + } catch (error) { + if (error) { + dispatch( + addToast( + makeToast({ + title: t('toast.nodesNotValidJSON'), + status: 'error', + }) + ) + ); + } } - - if (retrievedNodeTree) { - dispatch(loadFileNodes(retrievedNodeTree.nodes)); - dispatch(loadFileEdges(retrievedNodeTree.edges)); - fitView(); - - dispatch( - addToast( - makeToast({ title: t('toast.nodesLoaded'), status: 'success' }) - ) - ); - } - - // Cleanup - reader.abort(); }; reader.readAsText(v); From 225f60855647147e6e4a67461df36c6fc437cf3d Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 23 Jul 2023 16:49:52 +1200 Subject: [PATCH 02/10] fix: Add more sanity checks & rename buttons to Graphs --- invokeai/frontend/web/public/locales/en.json | 6 ++--- .../components/panels/TopCenterPanel.tsx | 12 ++++----- ...arNodesButton.tsx => ClearGraphButton.tsx} | 16 ++++++------ ...oadNodesButton.tsx => LoadGraphButton.tsx} | 25 +++++++++++++------ ...aveNodesButton.tsx => SaveGraphButton.tsx} | 8 +++--- 5 files changed, 39 insertions(+), 28 deletions(-) rename invokeai/frontend/web/src/features/nodes/components/ui/{ClearNodesButton.tsx => ClearGraphButton.tsx} (87%) rename invokeai/frontend/web/src/features/nodes/components/ui/{LoadNodesButton.tsx => LoadGraphButton.tsx} (82%) rename invokeai/frontend/web/src/features/nodes/components/ui/{SaveNodesButton.tsx => SaveGraphButton.tsx} (90%) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ab10276491..404c2013e4 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -702,9 +702,9 @@ }, "nodes": { "reloadSchema": "Reload Schema", - "saveNodes": "Save Nodes", - "loadNodes": "Load Nodes", - "clearNodes": "Clear Nodes", + "saveGraph": "Save Graph", + "loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)", + "clearGraph": "Clear Graph", "zoomInNodes": "Zoom In", "zoomOutNodes": "Zoom Out", "fitViewportNodes": "Fit View", diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopCenterPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/panels/TopCenterPanel.tsx index 90f8039285..21076e16f5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/panels/TopCenterPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/panels/TopCenterPanel.tsx @@ -2,11 +2,11 @@ import { HStack } from '@chakra-ui/react'; import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton'; import { memo } from 'react'; import { Panel } from 'reactflow'; -import LoadNodesButton from '../ui/LoadNodesButton'; +import ClearGraphButton from '../ui/ClearGraphButton'; +import LoadGraphButton from '../ui/LoadGraphButton'; import NodeInvokeButton from '../ui/NodeInvokeButton'; import ReloadSchemaButton from '../ui/ReloadSchemaButton'; -import SaveNodesButton from '../ui/SaveNodesButton'; -import ClearNodesButton from '../ui/ClearNodesButton'; +import SaveGraphButton from '../ui/SaveGraphButton'; const TopCenterPanel = () => { return ( @@ -15,9 +15,9 @@ const TopCenterPanel = () => { - - - + + + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/ClearNodesButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/ui/ClearNodesButton.tsx rename to invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx index 86d9d08a84..88fb60ee0f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/ClearNodesButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx @@ -9,17 +9,17 @@ import { Text, useDisclosure, } from '@chakra-ui/react'; -import { makeToast } from 'features/system/util/makeToast'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { addToast } from 'features/system/store/systemSlice'; -import { memo, useRef, useCallback } from 'react'; +import { makeToast } from 'features/system/util/makeToast'; +import { memo, useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { FaTrash } from 'react-icons/fa'; -const ClearNodesButton = () => { +const ClearGraphButton = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const { isOpen, onOpen, onClose } = useDisclosure(); @@ -46,8 +46,8 @@ const ClearNodesButton = () => { <> } - tooltip={t('nodes.clearNodes')} - aria-label={t('nodes.clearNodes')} + tooltip={t('nodes.clearGraph')} + aria-label={t('nodes.clearGraph')} onClick={onOpen} isDisabled={nodes.length === 0} /> @@ -62,11 +62,11 @@ const ClearNodesButton = () => { - {t('nodes.clearNodes')} + {t('nodes.clearGraph')} - {t('common.clearNodes')} + {t('common.clearGraph')} @@ -83,4 +83,4 @@ const ClearNodesButton = () => { ); }; -export default memo(ClearNodesButton); +export default memo(ClearGraphButton); diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx similarity index 82% rename from invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx rename to invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx index 2aa369bc11..437418e18a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx @@ -13,17 +13,28 @@ interface JsonFile { [key: string]: unknown; } -function validateInvokeAIGraph(jsonFile: JsonFile): boolean { +function sanityCheckInvokeAIGraph(jsonFile: JsonFile): boolean { const keys = ['nodes', 'edges', 'viewport']; for (const key of keys) { if (!(key in jsonFile)) { return false; } } + + if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) { + return false; + } + + for (const node of jsonFile.nodes) { + if (!('data' in node)) { + return false; + } + } + return true; } -const LoadNodesButton = () => { +const LoadGraphButton = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const { fitView } = useReactFlow(); @@ -39,9 +50,9 @@ const LoadNodesButton = () => { try { const retrievedNodeTree = await JSON.parse(String(json)); - const isValidNodeTree = validateInvokeAIGraph(retrievedNodeTree); + const isSaneNodeTree = sanityCheckInvokeAIGraph(retrievedNodeTree); - if (isValidNodeTree) { + if (isSaneNodeTree) { dispatch(loadFileNodes(retrievedNodeTree.nodes)); dispatch(loadFileEdges(retrievedNodeTree.edges)); fitView(); @@ -93,8 +104,8 @@ const LoadNodesButton = () => { {(props) => ( } - tooltip={t('nodes.loadNodes')} - aria-label={t('nodes.loadNodes')} + tooltip={t('nodes.loadGraph')} + aria-label={t('nodes.loadGraph')} {...props} /> )} @@ -102,4 +113,4 @@ const LoadNodesButton = () => { ); }; -export default memo(LoadNodesButton); +export default memo(LoadGraphButton); diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/SaveNodesButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx similarity index 90% rename from invokeai/frontend/web/src/features/nodes/components/ui/SaveNodesButton.tsx rename to invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx index 5833182456..42e545258e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/SaveNodesButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx @@ -6,7 +6,7 @@ import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { FaSave } from 'react-icons/fa'; -const SaveNodesButton = () => { +const SaveGraphButton = () => { const { t } = useTranslation(); const editorInstance = useAppSelector( (state: RootState) => state.nodes.editorInstance @@ -37,12 +37,12 @@ const SaveNodesButton = () => { } fontSize={18} - tooltip={t('nodes.saveNodes')} - aria-label={t('nodes.saveNodes')} + tooltip={t('nodes.saveGraph')} + aria-label={t('nodes.saveGraph')} onClick={saveEditorToJSON} isDisabled={nodes.length === 0} /> ); }; -export default memo(SaveNodesButton); +export default memo(SaveGraphButton); From af4579b4d494d37ae1615f94ae61c7631c28bc79 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 23 Jul 2023 18:12:25 +1200 Subject: [PATCH 03/10] feat: Add more sanity checks for graph loading --- invokeai/frontend/web/public/locales/en.json | 7 +- .../nodes/components/ui/ClearGraphButton.tsx | 2 +- .../nodes/components/ui/LoadGraphButton.tsx | 71 +++++++++++++++---- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 404c2013e4..0640ab9ef0 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -102,8 +102,7 @@ "openInNewTab": "Open in New Tab", "dontAskMeAgain": "Don't ask me again", "areYouSure": "Are you sure?", - "imagePrompt": "Image Prompt", - "clearNodes": "Are you sure you want to clear all nodes?" + "imagePrompt": "Image Prompt" }, "gallery": { "generations": "Generations", @@ -617,6 +616,9 @@ "nodesLoaded": "Nodes Loaded", "nodesNotValidGraph": "Not a valid InvokeAI Node Graph", "nodesNotValidJSON": "Not a valid JSON", + "nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.", + "nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types", + "nodesBrokenConnections": "Cannot load. Some connections are broken.", "nodesLoadedFailed": "Failed To Load Nodes", "nodesCleared": "Nodes Cleared" }, @@ -705,6 +707,7 @@ "saveGraph": "Save Graph", "loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)", "clearGraph": "Clear Graph", + "clearGraphDesc": "Are you sure you want to clear all nodes?", "zoomInNodes": "Zoom In", "zoomOutNodes": "Zoom Out", "fitViewportNodes": "Fit View", diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx index 88fb60ee0f..432675c5cd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/ClearGraphButton.tsx @@ -66,7 +66,7 @@ const ClearGraphButton = () => { - {t('common.clearGraph')} + {t('nodes.clearGraphDesc')} diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx index 437418e18a..44d93bb8fe 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx @@ -4,6 +4,7 @@ import IAIIconButton from 'common/components/IAIIconButton'; import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; +import i18n from 'i18n'; import { memo, useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { FaUpload } from 'react-icons/fa'; @@ -13,25 +14,70 @@ interface JsonFile { [key: string]: unknown; } -function sanityCheckInvokeAIGraph(jsonFile: JsonFile): boolean { +function sanityCheckInvokeAIGraph(jsonFile: JsonFile): { + isValid: boolean; + message: string; +} { + // Check if primary keys exist const keys = ['nodes', 'edges', 'viewport']; for (const key of keys) { if (!(key in jsonFile)) { - return false; + return { + isValid: false, + message: i18n.t('toast.nodesNotValidGraph'), + }; } } + // Check if nodes and edges are arrays if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) { - return false; + return { + isValid: false, + message: i18n.t('toast.nodesNotValidGraph'), + }; } - for (const node of jsonFile.nodes) { - if (!('data' in node)) { - return false; + // Check if data is present in nodes + const nodeKeys = ['data', 'type']; + const nodeTypes = ['invocation', 'progress_image']; + if (jsonFile.nodes.length > 0) { + for (const node of jsonFile.nodes) { + for (const nodeKey of nodeKeys) { + if (!(nodeKey in node)) { + return { + isValid: false, + message: i18n.t('toast.nodesNotValidGraph'), + }; + } + if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) { + return { + isValid: false, + message: i18n.t('toast.nodesUnrecognizedTypes'), + }; + } + } } } - return true; + // Check Edge Object + const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle']; + if (jsonFile.edges.length > 0) { + for (const edge of jsonFile.edges) { + for (const edgeKey of edgeKeys) { + if (!(edgeKey in edge)) { + return { + isValid: false, + message: i18n.t('toast.nodesBrokenConnections'), + }; + } + } + } + } + + return { + isValid: true, + message: i18n.t('toast.nodesLoaded'), + }; } const LoadGraphButton = () => { @@ -50,23 +96,22 @@ const LoadGraphButton = () => { try { const retrievedNodeTree = await JSON.parse(String(json)); - const isSaneNodeTree = sanityCheckInvokeAIGraph(retrievedNodeTree); + const { isValid, message } = + sanityCheckInvokeAIGraph(retrievedNodeTree); - if (isSaneNodeTree) { + if (isValid) { dispatch(loadFileNodes(retrievedNodeTree.nodes)); dispatch(loadFileEdges(retrievedNodeTree.edges)); fitView(); dispatch( - addToast( - makeToast({ title: t('toast.nodesLoaded'), status: 'success' }) - ) + addToast(makeToast({ title: message, status: 'success' })) ); } else { dispatch( addToast( makeToast({ - title: t('toast.nodesNotValidGraph'), + title: message, status: 'error', }) ) From 4b334be7d030d369bbac2d0dc4d30fb27954fa2b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 23 Jul 2023 12:27:59 +1000 Subject: [PATCH 04/10] feat(nodes,ui): fix soft locks on session/invocation retrieval When a queue item is popped for processing, we need to retrieve its session from the DB. Pydantic serializes the graph at this stage. It's possible for a graph to have been made invalid during the graph preparation stage (e.g. an ancestor node executes, and its output is not valid for its successor node's input field). When this occurs, the session in the DB will fail validation, but we don't have a chance to find out until it is retrieved and parsed by pydantic. This logic was previously not wrapped in any exception handling. Just after retrieving a session, we retrieve the specific invocation to execute from the session. It's possible that this could also have some sort of error, though it should be impossible for it to be a pydantic validation error (that would have been caught during session validation). There was also no exception handling here. When either of these processes fail, the processor gets soft-locked because the processor's cleanup logic is never run. (I didn't dig deeper into exactly what cleanup is not happening, because the fix is to just handle the exceptions.) This PR adds exception handling to both the session retrieval and node retrieval and events for each: `session_retrieval_error` and `invocation_retrieval_error`. These events are caught and displayed in the UI as toasts, along with the type of the python exception (e.g. `Validation Error`). The events are also logged to the browser console. --- invokeai/app/services/events.py | 76 +++++++++++++++---- invokeai/app/services/processor.py | 41 +++++++--- .../middleware/listenerMiddleware/index.ts | 4 + .../listeners/sessionCreated.ts | 7 +- .../listeners/sessionInvoked.ts | 3 +- .../socketInvocationRetrievalError.ts | 20 +++++ .../socketio/socketSessionRetrievalError.ts | 20 +++++ .../src/features/system/store/systemSlice.ts | 62 +++++++++------ .../web/src/services/api/thunks/session.ts | 13 +++- .../web/src/services/events/actions.ts | 34 +++++++++ .../frontend/web/src/services/events/types.ts | 26 +++++++ .../services/events/util/setEventListeners.ts | 24 ++++++ 12 files changed, 273 insertions(+), 57 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 35003536e6..73d74de2d9 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,7 +3,13 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp -from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo +from invokeai.app.services.model_manager_service import ( + BaseModelType, + ModelType, + SubModelType, + ModelInfo, +) + class EventServiceBase: session_event: str = "session_event" @@ -38,7 +44,9 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, - progress_image=progress_image.dict() if progress_image is not None else None, + progress_image=progress_image.dict() + if progress_image is not None + else None, step=step, total_steps=total_steps, ), @@ -67,6 +75,7 @@ class EventServiceBase: graph_execution_state_id: str, node: dict, source_node_id: str, + error_type: str, error: str, ) -> None: """Emitted when an invocation has completed""" @@ -76,6 +85,7 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, + error_type=error_type, error=error, ), ) @@ -102,13 +112,13 @@ class EventServiceBase: ), ) - def emit_model_load_started ( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + def emit_model_load_started( + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, ) -> None: """Emitted when a model is requested""" self.__emit_session_event( @@ -123,13 +133,13 @@ class EventServiceBase: ) def emit_model_load_completed( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, + model_info: ModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_session_event( @@ -145,3 +155,37 @@ class EventServiceBase: precision=str(model_info.precision), ), ) + + def emit_session_retrieval_error( + self, + graph_execution_state_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when session retrieval fails""" + self.__emit_session_event( + event_name="session_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + error_type=error_type, + error=error, + ), + ) + + def emit_invocation_retrieval_error( + self, + graph_execution_state_id: str, + node_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when invocation retrieval fails""" + self.__emit_session_event( + event_name="invocation_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + node_id=node_id, + error_type=error_type, + error=error, + ), + ) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index e11eb84b3d..5995e4ffc3 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() except Exception as e: - logger.debug("Exception while getting from queue: %s" % e) + self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e) if not queue_item: # Probably stopping # do not hammer the queue time.sleep(0.5) continue - graph_execution_state = ( - self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id + try: + graph_execution_state = ( + self.__invoker.services.graph_execution_manager.get( + queue_item.graph_execution_state_id + ) ) - ) - invocation = graph_execution_state.execution_graph.get_node( - queue_item.invocation_id - ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) + self.__invoker.services.events.emit_session_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue + + try: + invocation = graph_execution_state.execution_graph.get_node( + queue_item.invocation_id + ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) + self.__invoker.services.events.emit_invocation_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + node_id=queue_item.invocation_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue # get the source node id to provide to clients (the prepared node id is not as useful) source_node_id = graph_execution_state.prepared_source_mapping[invocation.id] @@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state ) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) # Send error event self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=error, ) @@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): try: self.__invoker.invoke(graph_execution_state, invoke_all=True) except Exception as e: - logger.error("Error while invoking: %s" % e) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=traceback.format_exc() ) elif is_complete: diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 04f0ce7a0b..5adc4f5e5e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; +import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; +import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; export const listenerMiddleware = createListenerMiddleware(); @@ -153,6 +155,8 @@ addSocketDisconnectedListener(); addSocketSubscribedListener(); addSocketUnsubscribedListener(); addModelLoadEventListener(); +addSessionRetrievalErrorEventListener(); +addInvocationRetrievalErrorEventListener(); // Session Created addSessionCreatedPendingListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts index 5709d87d22..e89acb7542 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts @@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => { effect: (action) => { const log = logger('session'); if (action.payload) { - const { error } = action.payload; + const { error, status } = action.payload; const graph = parseify(action.meta.arg); - const stringifiedError = JSON.stringify(error); log.error( - { graph, error: serializeError(error) }, - `Problem creating session: ${stringifiedError}` + { graph, status, error: serializeError(error) }, + `Problem creating session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts index 60009ed194..a62f75d957 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts @@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => { const { session_id } = action.meta.arg; if (action.payload) { const { error } = action.payload; - const stringifiedError = JSON.stringify(error); log.error( { session_id, error: serializeError(error), }, - `Problem invoking session: ${stringifiedError}` + `Problem invoking session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts new file mode 100644 index 0000000000..aa88457eb7 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketInvocationRetrievalError, + socketInvocationRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addInvocationRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketInvocationRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Invocation retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketInvocationRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts new file mode 100644 index 0000000000..7efb7f463a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketSessionRetrievalError, + socketSessionRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addSessionRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketSessionRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Session retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketSessionRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 629a4f0139..b7a5e606e2 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,5 +1,5 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { InvokeLogLevel } from 'app/logging/logger'; import { userInvoked } from 'app/store/actions'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; @@ -16,13 +16,16 @@ import { appSocketGraphExecutionStateComplete, appSocketInvocationComplete, appSocketInvocationError, + appSocketInvocationRetrievalError, appSocketInvocationStarted, + appSocketSessionRetrievalError, appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { makeToast } from '../util/makeToast'; import { LANGUAGES } from './constants'; +import { startCase } from 'lodash-es'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -288,25 +291,6 @@ export const systemSlice = createSlice({ } }); - /** - * Invocation Error - */ - builder.addCase(appSocketInvocationError, (state) => { - state.isProcessing = false; - state.isCancelable = true; - // state.currentIteration = 0; - // state.totalIterations = 0; - state.currentStatusHasSteps = false; - state.currentStep = 0; - state.totalSteps = 0; - state.statusTranslationKey = 'common.statusError'; - state.progressImage = null; - - state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) - ); - }); - /** * Graph Execution State Complete */ @@ -362,7 +346,7 @@ export const systemSlice = createSlice({ * Session Invoked - REJECTED * Session Created - REJECTED */ - builder.addMatcher(isAnySessionRejected, (state) => { + builder.addMatcher(isAnySessionRejected, (state, action) => { state.isProcessing = false; state.isCancelable = false; state.isCancelScheduled = false; @@ -372,7 +356,35 @@ export const systemSlice = createSlice({ state.progressImage = null; state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: + action.payload?.status === 422 ? 'Validation Error' : undefined, + }) + ); + }); + + /** + * Any server error + */ + builder.addMatcher(isAnyServerError, (state, action) => { + state.isProcessing = false; + state.isCancelable = true; + // state.currentIteration = 0; + // state.totalIterations = 0; + state.currentStatusHasSteps = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusError'; + state.progressImage = null; + + state.toastQueue.push( + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: startCase(action.payload.data.error_type), + }) ); }); }, @@ -400,3 +412,9 @@ export const { } = systemSlice.actions; export default systemSlice.reducer; + +const isAnyServerError = isAnyOf( + appSocketInvocationError, + appSocketSessionRetrievalError, + appSocketInvocationRetrievalError +); diff --git a/invokeai/frontend/web/src/services/api/thunks/session.ts b/invokeai/frontend/web/src/services/api/thunks/session.ts index 6d20b9dd33..5588f25b46 100644 --- a/invokeai/frontend/web/src/services/api/thunks/session.ts +++ b/invokeai/frontend/web/src/services/api/thunks/session.ts @@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required< >; type CreateSessionThunkConfig = { - rejectValue: { arg: CreateSessionArg; error: unknown }; + rejectValue: { arg: CreateSessionArg; status: number; error: unknown }; }; /** @@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk< }); if (error) { - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } return data; @@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = { rejectValue: { arg: InvokedSessionArg; error: unknown; + status: number; }; }; @@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk< if (error) { if (isErrorWithStatus(error) && error.status === 403) { - return rejectWithValue({ arg, error: (error as any).body.detail }); + return rejectWithValue({ + arg, + status: response.status, + error: (error as any).body.detail, + }); } - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } }); diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index b6316c5e95..35ebb725cb 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -4,9 +4,11 @@ import { GraphExecutionStateCompleteEvent, InvocationCompleteEvent, InvocationErrorEvent, + InvocationRetrievalErrorEvent, InvocationStartedEvent, ModelLoadCompletedEvent, ModelLoadStartedEvent, + SessionRetrievalErrorEvent, } from 'services/events/types'; // Create actions for each socket @@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{ export const appSocketModelLoadCompleted = createAction<{ data: ModelLoadCompletedEvent; }>('socket/appSocketModelLoadCompleted'); + +/** + * Socket.IO Session Retrieval Error + * + * Do not use. Only for use in middleware. + */ +export const socketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/socketSessionRetrievalError'); + +/** + * App-level Session Retrieval Error + */ +export const appSocketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/appSocketSessionRetrievalError'); + +/** + * Socket.IO Invocation Retrieval Error + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/socketInvocationRetrievalError'); + +/** + * App-level Invocation Retrieval Error + */ +export const appSocketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/appSocketInvocationRetrievalError'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index ec1b55e3fe..37f5f24eac 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -87,6 +87,7 @@ export type InvocationErrorEvent = { graph_execution_state_id: string; node: BaseNode; source_node_id: string; + error_type: string; error: string; }; @@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = { graph_execution_state_id: string; }; +/** + * A `session_retrieval_error` socket.io event. + * + * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... } + */ +export type SessionRetrievalErrorEvent = { + graph_execution_state_id: string; + error_type: string; + error: string; +}; + +/** + * A `invocation_retrieval_error` socket.io event. + * + * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... } + */ +export type InvocationRetrievalErrorEvent = { + graph_execution_state_id: string; + node_id: string; + error_type: string; + error: string; +}; + export type ClientEmitSubscribe = { session: string; }; @@ -128,6 +152,8 @@ export type ServerToClientEvents = { ) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; model_load_completed: (payload: ModelLoadCompletedEvent) => void; + session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; + invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void; }; export type ClientToServerEvents = { diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index d44a549183..9ebb7ffbff 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -11,9 +11,11 @@ import { socketGraphExecutionStateComplete, socketInvocationComplete, socketInvocationError, + socketInvocationRetrievalError, socketInvocationStarted, socketModelLoadCompleted, socketModelLoadStarted, + socketSessionRetrievalError, socketSubscribed, } from '../actions'; import { ClientToServerEvents, ServerToClientEvents } from '../types'; @@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => { }) ); }); + + /** + * Session retrieval error + */ + socket.on('session_retrieval_error', (data) => { + dispatch( + socketSessionRetrievalError({ + data, + }) + ); + }); + + /** + * Invocation retrieval error + */ + socket.on('invocation_retrieval_error', (data) => { + dispatch( + socketInvocationRetrievalError({ + data, + }) + ); + }); }; From 07a90c019800ce485c88dabe421f3896ebf38f72 Mon Sep 17 00:00:00 2001 From: Alexandre Macabies Date: Sun, 23 Jul 2023 14:49:28 +0200 Subject: [PATCH 05/10] Fix incorrect use of a singleton list. This was found through pylance type errors. Go types! --- invokeai/app/api/routers/models.py | 2 +- invokeai/app/services/model_manager_service.py | 2 +- invokeai/backend/model_management/model_search.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 870ca33534..759f6c9f59 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -298,7 +298,7 @@ async def search_for_models( )->List[pathlib.Path]: if not search_path.is_dir(): raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") - return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) + return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) @models_router.get( "/ckpt_confs", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index b1b995309e..f7d3b3a7a7 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase): """ Return list of all models found in the designated directory. """ - search = FindModels(directory,self.logger) + search = FindModels([directory], self.logger) return search.list_models() def sync_to_config(self): diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py index 1e282b4bb8..5657bd9549 100644 --- a/invokeai/backend/model_management/model_search.py +++ b/invokeai/backend/model_management/model_search.py @@ -98,6 +98,6 @@ class FindModels(ModelSearch): def list_models(self) -> List[Path]: self.search() - return self.models_found + return list(self.models_found) From 0beec08d3822b398f855ab11ee5064ed5b479bce Mon Sep 17 00:00:00 2001 From: Alexandre Macabies Date: Sun, 23 Jul 2023 14:46:16 +0200 Subject: [PATCH 06/10] Add missing import. --- invokeai/backend/model_management/models/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 5387ade0e5..eb771841ec 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -10,6 +10,7 @@ from .base import ( SubModelType, classproperty, InvalidModelException, + ModelNotFoundException, ) # TODO: naming from ..lora import LoRAModel as LoRAModelRaw From d5a75eb83301f89a2240f2ca4a9719d04ef574b5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 24 Jul 2023 16:34:34 +1000 Subject: [PATCH 07/10] feat: increase seed from int32 to uint32 At some point I typo'd this and set the max seed to signed int32 max. It should be *un*signed int32 max. This restored the seed range to what it was in v2.3. --- invokeai/app/util/misc.py | 2 +- invokeai/frontend/web/src/app/constants.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 7c674674e2..be5b698258 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -14,7 +14,7 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: return datetime.datetime.fromisoformat(iso_timestamp) -SEED_MAX = np.iinfo(np.int32).max +SEED_MAX = np.iinfo(np.uint32).max def get_random_seed(): diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 1194ea467b..b8fab16c1c 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,2 +1,2 @@ export const NUMPY_RAND_MIN = 0; -export const NUMPY_RAND_MAX = 2147483647; +export const NUMPY_RAND_MAX = 4294967295; From 66cdeba8a1e8ba78f255ea3ca3269c9331a889cf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 24 Jul 2023 16:44:32 +1000 Subject: [PATCH 08/10] fix(nodes): fix seed modulus operation This was incorect and resulted in the max seed being one less than intended. --- invokeai/app/invocations/noise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 442557520a..fff0f29f14 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation): @validator("seed", pre=True) def modulo_seed(cls, v): - """Returns the seed modulo SEED_MAX to ensure it is within the valid range.""" - return v % SEED_MAX + """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" + return v % (SEED_MAX + 1) def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( From 0cf7a10c5ccd413b001fb360ff7b0fcb6e961883 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 24 Jul 2023 18:58:24 +1200 Subject: [PATCH 09/10] fix: Other lora missing type --- invokeai/backend/model_management/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index b0481f3cfa..222169afbb 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -474,7 +474,7 @@ class ModelPatcher: @staticmethod def _lora_forward_hook( - applied_loras: List[Tuple[LoraModel, float]], + applied_loras: List[Tuple[LoRAModel, float]], layer_name: str, ): @@ -519,7 +519,7 @@ class ModelPatcher: def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoraModel, float]], + loras: List[Tuple[LoRAModel, float]], prefix: str, ): original_weights = dict() From e766ddbcf45b2eca1436fe8b51294e91769a4405 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 24 Jul 2023 19:38:21 +1200 Subject: [PATCH 10/10] fix: Generate random seed using the generator instead of RandomState --- invokeai/app/util/misc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index be5b698258..503f3af4c8 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -18,4 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max def get_random_seed(): - return np.random.randint(0, SEED_MAX) + rng = np.random.default_rng(seed=0) + return int(rng.integers(0, SEED_MAX))