feat: Add more sanity checks for graph loading

This commit is contained in:
blessedcoolant 2023-07-23 18:12:25 +12:00
parent 35acb5de76
commit af4579b4d4
3 changed files with 64 additions and 16 deletions

View File

@ -102,8 +102,7 @@
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
"imagePrompt": "Image Prompt"
},
"gallery": {
"generations": "Generations",
@ -617,6 +616,9 @@
"nodesLoaded": "Nodes Loaded",
"nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
"nodesNotValidJSON": "Not a valid JSON",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
@ -705,6 +707,7 @@
"saveGraph": "Save Graph",
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
"clearGraph": "Clear Graph",
"clearGraphDesc": "Are you sure you want to clear all nodes?",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",

View File

@ -66,7 +66,7 @@ const ClearGraphButton = () => {
</AlertDialogHeader>
<AlertDialogBody>
<Text>{t('common.clearGraph')}</Text>
<Text>{t('nodes.clearGraphDesc')}</Text>
</AlertDialogBody>
<AlertDialogFooter>

View File

@ -4,6 +4,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import i18n from 'i18n';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
@ -13,25 +14,70 @@ interface JsonFile {
[key: string]: unknown;
}
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): boolean {
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
isValid: boolean;
message: string;
} {
// Check if primary keys exist
const keys = ['nodes', 'edges', 'viewport'];
for (const key of keys) {
if (!(key in jsonFile)) {
return false;
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
}
// Check if nodes and edges are arrays
if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
return false;
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
for (const node of jsonFile.nodes) {
if (!('data' in node)) {
return false;
// Check if data is present in nodes
const nodeKeys = ['data', 'type'];
const nodeTypes = ['invocation', 'progress_image'];
if (jsonFile.nodes.length > 0) {
for (const node of jsonFile.nodes) {
for (const nodeKey of nodeKeys) {
if (!(nodeKey in node)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
return {
isValid: false,
message: i18n.t('toast.nodesUnrecognizedTypes'),
};
}
}
}
}
return true;
// Check Edge Object
const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
if (jsonFile.edges.length > 0) {
for (const edge of jsonFile.edges) {
for (const edgeKey of edgeKeys) {
if (!(edgeKey in edge)) {
return {
isValid: false,
message: i18n.t('toast.nodesBrokenConnections'),
};
}
}
}
}
return {
isValid: true,
message: i18n.t('toast.nodesLoaded'),
};
}
const LoadGraphButton = () => {
@ -50,23 +96,22 @@ const LoadGraphButton = () => {
try {
const retrievedNodeTree = await JSON.parse(String(json));
const isSaneNodeTree = sanityCheckInvokeAIGraph(retrievedNodeTree);
const { isValid, message } =
sanityCheckInvokeAIGraph(retrievedNodeTree);
if (isSaneNodeTree) {
if (isValid) {
dispatch(loadFileNodes(retrievedNodeTree.nodes));
dispatch(loadFileEdges(retrievedNodeTree.edges));
fitView();
dispatch(
addToast(
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
)
addToast(makeToast({ title: message, status: 'success' }))
);
} else {
dispatch(
addToast(
makeToast({
title: t('toast.nodesNotValidGraph'),
title: message,
status: 'error',
})
)