diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1a902a88b7..1391f35930 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -593,7 +593,10 @@ "metadataLoadFailed": "Failed to load metadata", "initialImageSet": "Initial Image Set", "initialImageNotSet": "Initial Image Not Set", - "initialImageNotSetDesc": "Could not load initial image" + "initialImageNotSetDesc": "Could not load initial image", + "nodesSaved": "Nodes Saved", + "nodesLoaded": "Nodes Loaded", + "nodesLoadedFailed": "Failed To Load Nodes" }, "tooltip": { "feature": { @@ -676,6 +679,8 @@ "swapSizes": "Swap Sizes" }, "nodes": { - "reloadSchema": "Reload Schema" + "reloadSchema": "Reload Schema", + "saveNodes": "Save Nodes", + "loadNodes": "Load Nodes" } } diff --git a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx index c6aa04bd24..7b0718182b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx @@ -10,7 +10,6 @@ import { OnInit, OnNodesChange, ReactFlow, - ReactFlowInstance, } from 'reactflow'; import { connectionEnded, @@ -18,6 +17,7 @@ import { connectionStarted, edgesChanged, nodesChanged, + setEditorInstance, } from '../store/nodesSlice'; import { InvocationComponent } from './InvocationComponent'; import ProgressImageNode from './ProgressImageNode'; @@ -69,11 +69,13 @@ export const Flow = () => { dispatch(connectionEnded()); }, [dispatch]); - const onInit: OnInit = useCallback((v: ReactFlowInstance) => { - if (v) { - v.fitView(); - } - }, []); + const onInit: OnInit = useCallback( + (v) => { + dispatch(setEditorInstance(v)); + if (v) v.fitView(); + }, + [dispatch] + ); return ( { return ( @@ -13,6 +14,8 @@ const TopCenterPanel = () => { + + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx new file mode 100644 index 0000000000..10aecc9fcc --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/ui/LoadNodesButton.tsx @@ -0,0 +1,79 @@ +import { FileButton } from '@mantine/core'; +import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice'; +import { addToast } from 'features/system/store/systemSlice'; +import { memo, useCallback, useRef } from 'react'; +import { useTranslation } from 'react-i18next'; +import { FaUpload } from 'react-icons/fa'; +import { useReactFlow } from 'reactflow'; + +const LoadNodesButton = () => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const { fitView } = useReactFlow(); + + const uploadedFileRef = useRef<() => void>(null); + + const restoreJSONToEditor = useCallback( + (v: File | null) => { + if (!v) return; + const reader = new FileReader(); + reader.onload = async () => { + const json = reader.result; + const retrievedNodeTree = await JSON.parse(String(json)); + + if (!retrievedNodeTree) { + dispatch( + addToast( + makeToast({ + title: t('toast.nodesLoadedFailed'), + status: 'error', + }) + ) + ); + } + + if (retrievedNodeTree) { + dispatch(loadFileNodes(retrievedNodeTree.nodes)); + dispatch(loadFileEdges(retrievedNodeTree.edges)); + fitView(); + + dispatch( + addToast( + makeToast({ title: t('toast.nodesLoaded'), status: 'success' }) + ) + ); + } + + // Cleanup + reader.abort(); + }; + + reader.readAsText(v); + + // Cleanup + uploadedFileRef.current?.(); + }, + [fitView, dispatch, t] + ); + return ( + + {(props) => ( + } + tooltip={t('nodes.loadNodes')} + aria-label={t('nodes.loadNodes')} + {...props} + /> + )} + + ); +}; + +export default memo(LoadNodesButton); diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/SaveNodesButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/SaveNodesButton.tsx new file mode 100644 index 0000000000..14bf0a1ce8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/ui/SaveNodesButton.tsx @@ -0,0 +1,45 @@ +import { RootState } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { map, omit } from 'lodash-es'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { FaSave } from 'react-icons/fa'; + +const SaveNodesButton = () => { + const { t } = useTranslation(); + const editorInstance = useAppSelector( + (state: RootState) => state.nodes.editorInstance + ); + + const saveEditorToJSON = useCallback(() => { + if (editorInstance) { + const editorState = editorInstance.toObject(); + + editorState.edges = map(editorState.edges, (edge) => { + return omit(edge, ['style']); + }); + + const nodeSetupJSON = new Blob([JSON.stringify(editorState)]); + const nodeDownloadElement = document.createElement('a'); + nodeDownloadElement.href = URL.createObjectURL(nodeSetupJSON); + nodeDownloadElement.download = 'MyNodes.json'; + document.body.appendChild(nodeDownloadElement); + nodeDownloadElement.click(); + // Cleanup + nodeDownloadElement.remove(); + } + }, [editorInstance]); + + return ( + } + fontSize={18} + tooltip={t('nodes.saveNodes')} + aria-label={t('nodes.saveNodes')} + onClick={saveEditorToJSON} + /> + ); +}; + +export default memo(SaveNodesButton); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 4fa69c626b..094a43b944 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -13,6 +13,7 @@ import { Node, NodeChange, OnConnectStartParams, + ReactFlowInstance, } from 'reactflow'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { ImageField } from 'services/api/types'; @@ -25,6 +26,7 @@ export type NodesState = { invocationTemplates: Record; connectionStartParams: OnConnectStartParams | null; shouldShowGraphOverlay: boolean; + editorInstance: ReactFlowInstance | undefined; }; export const initialNodesState: NodesState = { @@ -34,6 +36,7 @@ export const initialNodesState: NodesState = { invocationTemplates: {}, connectionStartParams: null, shouldShowGraphOverlay: false, + editorInstance: undefined, }; const nodesSlice = createSlice({ @@ -121,6 +124,15 @@ const nodesSlice = createSlice({ nodeEditorReset: () => { return { ...initialNodesState }; }, + setEditorInstance: (state, action) => { + state.editorInstance = action.payload; + }, + loadFileNodes: (state, action: PayloadAction[]>) => { + state.nodes = action.payload; + }, + loadFileEdges: (state, action: PayloadAction) => { + state.edges = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { @@ -141,6 +153,9 @@ export const { nodeTemplatesBuilt, nodeEditorReset, imageCollectionFieldValueChanged, + setEditorInstance, + loadFileNodes, + loadFileEdges, } = nodesSlice.actions; export default nodesSlice.reducer;