feat(ui): edge labels

Add setting to render labels with format `Source Node label -> Target Node label` on edges.
This commit is contained in:
psychedelicious 2024-04-15 19:26:29 +10:00
parent 7cf788e658
commit b508945b11
6 changed files with 103 additions and 27 deletions

View File

@ -770,6 +770,8 @@
"float": "Float", "float": "Float",
"fullyContainNodes": "Fully Contain Nodes to Select", "fullyContainNodes": "Fully Contain Nodes to Select",
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected", "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
"showEdgeLabels": "Show Edge Labels",
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"hideLegendNodes": "Hide Field Type Legend", "hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap", "hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection", "inputMayOnlyHaveOneConnection": "Input may only have one connection",

View File

@ -1,8 +1,9 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { CSSProperties } from 'react'; import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow'; import type { EdgeProps } from 'reactflow';
import { BaseEdge, getBezierPath } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
import { makeEdgeSelector } from './util/makeEdgeSelector'; import { makeEdgeSelector } from './util/makeEdgeSelector';
@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
[source, sourceHandleId, target, targetHandleId, selected] [source, sourceHandleId, target, targetHandleId, selected]
); );
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector); const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
const [edgePath] = getBezierPath({ const [edgePath, labelX, labelY] = getBezierPath({
sourceX, sourceX,
sourceY, sourceY,
sourcePosition, sourcePosition,
@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
[isSelected, shouldAnimate, stroke] [isSelected, shouldAnimate, stroke]
); );
return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />; return (
<>
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
{label && shouldShowEdgeLabels && (
<EdgeLabelRenderer>
<Flex
className="nodrag nopan"
pointerEvents="all"
position="absolute"
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
bg="base.800"
borderRadius="base"
borderWidth={1}
borderColor={isSelected ? 'undefined' : 'transparent'}
opacity={isSelected ? 1 : 0.5}
py={1}
px={3}
shadow="md"
>
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
{label}
</Text>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
}; };
export default memo(InvocationDefaultEdge); export default memo(InvocationDefaultEdge);

View File

@ -1,7 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor'; import { getFieldColor } from './getEdgeColor';
@ -10,6 +10,7 @@ const defaultReturnValue = {
isSelected: false, isSelected: false,
shouldAnimate: false, shouldAnimate: false,
stroke: colorTokenToCssVar('base.500'), stroke: colorTokenToCssVar('base.500'),
label: '',
}; };
export const makeEdgeSelector = ( export const makeEdgeSelector = (
@ -19,25 +20,34 @@ export const makeEdgeSelector = (
targetHandleId: string | null | undefined, targetHandleId: string | null | undefined,
selected?: boolean selected?: boolean
) => ) =>
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { createMemoizedSelector(
const sourceNode = nodes.nodes.find((node) => node.id === source); selectNodesSlice,
const targetNode = nodes.nodes.find((node) => node.id === target); (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId) { if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return defaultReturnValue; return defaultReturnValue;
}
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
label,
};
} }
);
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});

View File

@ -24,6 +24,7 @@ import {
selectNodesSlice, selectNodesSlice,
shouldAnimateEdgesChanged, shouldAnimateEdgesChanged,
shouldColorEdgesChanged, shouldColorEdgesChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged, shouldSnapToGridChanged,
shouldValidateGraphChanged, shouldValidateGraphChanged,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
const formLabelProps: FormLabelProps = { flexGrow: 1 }; const formLabelProps: FormLabelProps = { flexGrow: 1 };
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes; const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionMode,
} = nodes;
return { return {
shouldAnimateEdges, shouldAnimateEdges,
shouldValidateGraph, shouldValidateGraph,
shouldSnapToGrid, shouldSnapToGrid,
shouldColorEdges, shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked: selectionMode === SelectionMode.Full, selectionModeIsChecked: selectionMode === SelectionMode.Full,
}; };
}); });
@ -52,8 +61,14 @@ type Props = {
const WorkflowEditorSettings = ({ children }: Props) => { const WorkflowEditorSettings = ({ children }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } = const {
useAppSelector(selector); shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked,
} = useAppSelector(selector);
const handleChangeShouldValidate = useCallback( const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => { (e: ChangeEvent<HTMLInputElement>) => {
@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
[dispatch] [dispatch]
); );
const handleChangeShouldShowEdgeLabels = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
},
[dispatch]
);
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText> <FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
</FormControl> </FormControl>
<Divider /> <Divider />
<FormControl>
<Flex w="full">
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
</Flex>
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
</FormControl>
<Divider />
<Heading size="sm" pt={4}> <Heading size="sm" pt={4}>
{t('common.advanced')} {t('common.advanced')}
</Heading> </Heading>

View File

@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
shouldAnimateEdges: true, shouldAnimateEdges: true,
shouldSnapToGrid: false, shouldSnapToGrid: false,
shouldColorEdges: true, shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false, isAddNodePopoverOpen: false,
nodeOpacity: 1, nodeOpacity: 1,
selectedNodes: [], selectedNodes: [],
@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => { shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload; state.shouldAnimateEdges = action.payload;
}, },
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => { shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload; state.shouldSnapToGrid = action.payload;
}, },
@ -831,6 +835,7 @@ export const {
viewportChanged, viewportChanged,
edgeAdded, edgeAdded,
nodeTemplatesBuilt, nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
} = nodesSlice.actions; } = nodesSlice.actions;
// This is used for tracking `state.workflow.isTouched` // This is used for tracking `state.workflow.isTouched`

View File

@ -32,6 +32,7 @@ export type NodesState = {
isAddNodePopoverOpen: boolean; isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null; addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode; selectionMode: SelectionMode;
shouldShowEdgeLabels: boolean;
}; };
export type WorkflowMode = 'edit' | 'view'; export type WorkflowMode = 'edit' | 'view';