From 7fcf475aec80b44c1cb58767254d87bff0b22432 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 16 Nov 2023 12:42:25 +1100
Subject: [PATCH] feat(ui): add Update All Nodes button
---
invokeai/frontend/web/public/locales/en.json | 3 +
.../middleware/listenerMiddleware/index.ts | 2 +
.../listeners/updateAllNodesRequested.ts | 52 +++++++
.../flow/panels/TopLeftPanel/TopLeftPanel.tsx | 14 +-
.../inspector/InspectorDetailsTab.tsx | 2 +-
.../nodes/hooks/useGetNodesNeedUpdate.ts | 25 ++++
.../features/nodes/hooks/useNodeVersion.ts | 127 ++++++++++++------
.../web/src/features/nodes/store/actions.ts | 4 +
8 files changed, 183 insertions(+), 46 deletions(-)
create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts
create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 94e05f791c..561b577a46 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -920,7 +920,10 @@
"unknownTemplate": "Unknown Template",
"unkownInvocation": "Unknown Invocation type",
"updateNode": "Update Node",
+ "updateAllNodes": "Update All Nodes",
"updateApp": "Update App",
+ "unableToUpdateNodes_one": "Unable to update {{count}} node",
+ "unableToUpdateNodes_other": "Unable to update {{count}} nodes",
"vaeField": "Vae",
"vaeFieldDescription": "Vae submodel.",
"vaeModelField": "VAE",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
index 772ea216c0..9c1727fc79 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
@@ -72,6 +72,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
+import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested';
export const listenerMiddleware = createListenerMiddleware();
@@ -178,6 +179,7 @@ addReceivedOpenAPISchemaListener();
// Workflows
addWorkflowLoadedListener();
+addUpdateAllNodesRequestedListener();
// DND
addImageDroppedListener();
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts
new file mode 100644
index 0000000000..ece6702ceb
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts
@@ -0,0 +1,52 @@
+import {
+ getNeedsUpdate,
+ updateNode,
+} from 'features/nodes/hooks/useNodeVersion';
+import { updateAllNodesRequested } from 'features/nodes/store/actions';
+import { nodeReplaced } from 'features/nodes/store/nodesSlice';
+import { startAppListening } from '..';
+import { logger } from 'app/logging/logger';
+import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
+import { t } from 'i18next';
+
+export const addUpdateAllNodesRequestedListener = () => {
+ startAppListening({
+ actionCreator: updateAllNodesRequested,
+ effect: (action, { dispatch, getState }) => {
+ const log = logger('nodes');
+ const nodes = getState().nodes.nodes;
+ const templates = getState().nodes.nodeTemplates;
+
+ let unableToUpdateCount = 0;
+
+ nodes.forEach((node) => {
+ const template = templates[node.data.type];
+ const needsUpdate = getNeedsUpdate(node, template);
+ const updatedNode = updateNode(node, template);
+ if (!updatedNode) {
+ if (needsUpdate) {
+ unableToUpdateCount++;
+ }
+ return;
+ }
+ dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
+ });
+
+ if (unableToUpdateCount) {
+ log.warn(
+ `Unable to update ${unableToUpdateCount} nodes. Please report this issue.`
+ );
+ dispatch(
+ addToast(
+ makeToast({
+ title: t('nodes.unableToUpdateNodes', {
+ count: unableToUpdateCount,
+ }),
+ })
+ )
+ );
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx
index d355eab348..38aa9bbad7 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx
@@ -3,15 +3,22 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice';
import { memo, useCallback } from 'react';
-import { FaPlus } from 'react-icons/fa';
+import { FaPlus, FaSync } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
+import IAIButton from 'common/components/IAIButton';
+import { useGetNodesNeedUpdate } from 'features/nodes/hooks/useGetNodesNeedUpdate';
+import { updateAllNodesRequested } from 'features/nodes/store/actions';
const TopLeftPanel = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
+ const nodesNeedUpdate = useGetNodesNeedUpdate();
const handleOpenAddNodePopover = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
+ const handleClickUpdateNodes = useCallback(() => {
+ dispatch(updateAllNodesRequested());
+ }, [dispatch]);
return (
@@ -21,6 +28,11 @@ const TopLeftPanel = () => {
icon={}
onClick={handleOpenAddNodePopover}
/>
+ {nodesNeedUpdate && (
+ } onClick={handleClickUpdateNodes}>
+ {t('nodes.updateAllNodes')}
+
+ )}
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx
index 9e765ff01e..397d1295b0 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx
@@ -107,7 +107,7 @@ const Content = (props: {
{props.node.data.version}
- {mayUpdate && (
+ {needsUpdate && (
{
+ const nodes = state.nodes.nodes;
+ const templates = state.nodes.nodeTemplates;
+
+ const needsUpdate = nodes.some((node) => {
+ const template = templates[node.data.type];
+ return getNeedsUpdate(node, template);
+ });
+ return needsUpdate;
+ },
+ defaultSelectorOptions
+);
+
+export const useGetNodesNeedUpdate = () => {
+ const getNeedsUpdate = useAppSelector(selector);
+ return getNeedsUpdate;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts
index 60192de61d..1f213d6481 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts
@@ -3,20 +3,80 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { satisfies } from 'compare-versions';
+import { cloneDeep, defaultsDeep } from 'lodash-es';
import { useCallback, useMemo } from 'react';
+import { Node } from 'reactflow';
+import { AnyInvocationType } from 'services/events/types';
+import { nodeReplaced } from '../store/nodesSlice';
+import { buildNodeData } from '../store/util/buildNodeData';
import {
InvocationNodeData,
+ InvocationTemplate,
+ NodeData,
isInvocationNode,
zParsedSemver,
} from '../types/types';
-import { cloneDeep, defaultsDeep } from 'lodash-es';
-import { buildNodeData } from '../store/util/buildNodeData';
-import { AnyInvocationType } from 'services/events/types';
-import { Node } from 'reactflow';
-import { nodeReplaced } from '../store/nodesSlice';
+import { useAppToaster } from 'app/components/Toaster';
+import { useTranslation } from 'react-i18next';
+
+export const getNeedsUpdate = (
+ node?: Node,
+ template?: InvocationTemplate
+) => {
+ if (!isInvocationNode(node) || !template) {
+ return false;
+ }
+ return node.data.version !== template.version;
+};
+
+export const getMayUpdateNode = (
+ node?: Node,
+ template?: InvocationTemplate
+) => {
+ const needsUpdate = getNeedsUpdate(node, template);
+ if (
+ !needsUpdate ||
+ !isInvocationNode(node) ||
+ !template ||
+ !node.data.version
+ ) {
+ return false;
+ }
+ const templateMajor = zParsedSemver.parse(template.version).major;
+
+ return satisfies(node.data.version, `^${templateMajor}`);
+};
+
+export const updateNode = (
+ node?: Node,
+ template?: InvocationTemplate
+) => {
+ const mayUpdate = getMayUpdateNode(node, template);
+ if (
+ !mayUpdate ||
+ !isInvocationNode(node) ||
+ !template ||
+ !node.data.version
+ ) {
+ return;
+ }
+
+ const defaults = buildNodeData(
+ node.data.type as AnyInvocationType,
+ node.position,
+ template
+ ) as Node;
+
+ const clone = cloneDeep(node);
+ clone.data.version = template.version;
+ defaultsDeep(clone, defaults);
+ return clone;
+};
export const useNodeVersion = (nodeId: string) => {
const dispatch = useAppDispatch();
+ const toast = useAppToaster();
+ const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
@@ -33,48 +93,27 @@ export const useNodeVersion = (nodeId: string) => {
const { node, nodeTemplate } = useAppSelector(selector);
- const needsUpdate = useMemo(() => {
- if (!isInvocationNode(node) || !nodeTemplate) {
- return false;
- }
- return node.data.version !== nodeTemplate.version;
- }, [node, nodeTemplate]);
+ const needsUpdate = useMemo(
+ () => getNeedsUpdate(node, nodeTemplate),
+ [node, nodeTemplate]
+ );
- const mayUpdate = useMemo(() => {
- if (
- !needsUpdate ||
- !isInvocationNode(node) ||
- !nodeTemplate ||
- !node.data.version
- ) {
- return false;
- }
- const templateMajor = zParsedSemver.parse(nodeTemplate.version).major;
+ const mayUpdate = useMemo(
+ () => getMayUpdateNode(node, nodeTemplate),
+ [node, nodeTemplate]
+ );
- return satisfies(node.data.version, `^${templateMajor}`);
- }, [needsUpdate, node, nodeTemplate]);
-
- const updateNode = useCallback(() => {
- if (
- !mayUpdate ||
- !isInvocationNode(node) ||
- !nodeTemplate ||
- !node.data.version
- ) {
+ const _updateNode = useCallback(() => {
+ const needsUpdate = getNeedsUpdate(node, nodeTemplate);
+ const updatedNode = updateNode(node, nodeTemplate);
+ if (!updatedNode) {
+ if (needsUpdate) {
+ toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) });
+ }
return;
}
+ dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
+ }, [dispatch, node, nodeTemplate, t, toast]);
- const defaults = buildNodeData(
- node.data.type as AnyInvocationType,
- node.position,
- nodeTemplate
- ) as Node;
-
- const clone = cloneDeep(node);
- clone.data.version = nodeTemplate.version;
- defaultsDeep(clone, defaults);
- dispatch(nodeReplaced({ nodeId: clone.id, node: clone }));
- }, [dispatch, mayUpdate, node, nodeTemplate]);
-
- return { needsUpdate, mayUpdate, updateNode };
+ return { needsUpdate, mayUpdate, updateNode: _updateNode };
};
diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts
index cf7ccf8238..0d75e6934d 100644
--- a/invokeai/frontend/web/src/features/nodes/store/actions.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts
@@ -21,3 +21,7 @@ export const isAnyGraphBuilt = isAnyOf(
export const workflowLoadRequested = createAction(
'nodes/workflowLoadRequested'
);
+
+export const updateAllNodesRequested = createAction(
+ 'nodes/updateAllNodesRequested'
+);