diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts index 694261d943..46f11b3823 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts @@ -9,7 +9,7 @@ import { buildNotesNode, } from '../store/util/buildNodeData'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants'; -import { AnyNodeData, InvocationTemplate } from '../types/invocation'; +import { AnyNode, InvocationTemplate } from '../types/invocation'; const templatesSelector = createSelector( [(state: RootState) => state.nodes], (nodes) => nodes.nodeTemplates @@ -26,7 +26,7 @@ export const useBuildNodeData = () => { return useCallback( // string here is "any invocation type" - (type: string | 'current_image' | 'notes'): Node => { + (type: string | 'current_image' | 'notes'): AnyNode => { let _x = window.innerWidth / 2; let _y = window.innerHeight / 2; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 0c21d02fed..f91fe04fde 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -28,7 +28,7 @@ import { appSocketQueueItemStatusChanged, } from 'services/events/actions'; import { v4 as uuidv4 } from 'uuid'; -import { DRAG_HANDLE_CLASSNAME } from '../types/constants'; +import { SHARED_NODE_PROPERTIES } from '../types/constants'; import { BoardFieldValue, BooleanFieldValue, @@ -50,7 +50,7 @@ import { VAEModelFieldValue, } from '../types/field'; import { - AnyNodeData, + AnyNode, InvocationTemplate, isInvocationNode, isNotesNode, @@ -157,7 +157,7 @@ const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: (state, action: PayloadAction>) => { + nodeAdded: (state, action: PayloadAction) => { const node = action.payload; const position = findUnoccupiedPosition( state.nodes, @@ -520,7 +520,7 @@ const nodesSlice = createSlice({ state.edges = applyEdgeChanges(edgeChanges, state.edges); } }, - nodesDeleted: (state, action: PayloadAction[]>) => { + nodesDeleted: (state, action: PayloadAction) => { action.payload.forEach((node) => { state.workflow.exposedFields = state.workflow.exposedFields.filter( (f) => f.nodeId !== node.id @@ -731,7 +731,7 @@ const nodesSlice = createSlice({ state.nodes = applyNodeChanges( nodes.map((node) => ({ - item: { ...node, dragHandle: `.${DRAG_HANDLE_CLASSNAME}` }, + item: { ...node, ...SHARED_NODE_PROPERTIES }, type: 'add', })), [] diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index b865b9d3a1..278be3c498 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -1,6 +1,4 @@ import { - Edge, - Node, OnConnectStartParams, SelectionMode, Viewport, @@ -8,16 +6,16 @@ import { } from 'reactflow'; import { FieldIdentifier, FieldType } from '../types/field'; import { - AnyNodeData, - InvocationEdgeExtra, + AnyNode, + InvocationNodeEdge, InvocationTemplate, NodeExecutionState, } from '../types/invocation'; import { WorkflowV2 } from '../types/workflow'; export type NodesState = { - nodes: Node[]; - edges: Edge[]; + nodes: AnyNode[]; + edges: InvocationNodeEdge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; @@ -37,8 +35,8 @@ export type NodesState = { isReady: boolean; mouseOverField: FieldIdentifier | null; mouseOverNode: string | null; - nodesToCopy: Node[]; - edgesToCopy: Edge[]; + nodesToCopy: AnyNode[]; + edgesToCopy: InvocationNodeEdge[]; isAddNodePopoverOpen: boolean; addNewNodePosition: XYPosition | null; selectionMode: SelectionMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts index 5328f789ad..c2582600c3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts @@ -1,4 +1,4 @@ -import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import { FieldInputInstance, FieldOutputInstance, @@ -14,10 +14,6 @@ import { reduce } from 'lodash-es'; import { Node, XYPosition } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; -export const SHARED_NODE_PROPERTIES: Partial = { - dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, -}; - export const buildNotesNode = (position: XYPosition): Node => { const nodeId = uuidv4(); const node: Node = { diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index a97899de91..27cb4fa778 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -1,3 +1,5 @@ +import { Node } from 'reactflow'; + /** * How long to wait before showing a tooltip when hovering a field handle. */ @@ -14,6 +16,13 @@ export const NODE_WIDTH = 320; */ export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; +/** + * reactflow-specifc properties shared between all node types. + */ +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + /** * Helper for getting the kind of a field. */ diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 70403169de..5d22e64545 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -1,4 +1,4 @@ -import { Node } from 'reactflow'; +import { Edge, Node } from 'reactflow'; import { z } from 'zod'; import { zProgressImage } from './common'; import { @@ -64,16 +64,16 @@ export type InvocationNodeData = z.infer; export type CurrentImageNodeData = z.infer; export type AnyNodeData = z.infer; -export const isInvocationNode = ( - node?: Node -): node is Node => +export type InvocationNode = Node; +export type NotesNode = Node; +export type CurrentImageNode = Node; +export type AnyNode = Node; + +export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = ( - node?: Node -): node is Node => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = ( - node?: Node -): node is Node => +export const isNotesNode = (node?: AnyNode): node is NotesNode => + Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); export const isInvocationNodeData = ( node?: AnyNodeData @@ -101,8 +101,9 @@ export type NodeStatus = z.infer; // #endregion // #region Edges -export const zInvocationEdgeExtra = z.object({ +export const zInvocationNodeEdgeExtra = z.object({ type: z.union([z.literal('default'), z.literal('collapsed')]), }); -export type InvocationEdgeExtra = z.infer; +export type InvocationNodeEdgeExtra = z.infer; +export type InvocationNodeEdge = Edge; // #endregion