feat(ui): node styling, controls

- custom node controls
- fix some types
- fix badge colors via colorScheme
- style nodes
This commit is contained in:
psychedelicious 2023-04-22 20:29:42 +10:00
parent 94a07a8da7
commit 44a653925a
23 changed files with 193 additions and 79 deletions

View File

@ -10,7 +10,7 @@ import {
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { FaEllipsisV, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { cloneDeep, map } from 'lodash';
@ -18,6 +18,8 @@ import { RootState } from 'app/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/hooks/useToastWatcher';
import { IAIIconButton } from 'exports';
import { AnyInvocationType } from 'services/events/types';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
@ -29,7 +31,7 @@ export const AddNodeMenu = () => {
const buildInvocation = useBuildInvocation();
const addNode = useCallback(
(nodeType: string) => {
(nodeType: AnyInvocationType) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
@ -48,7 +50,11 @@ export const AddNodeMenu = () => {
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuButton
as={IAIIconButton}
aria-label="Add Node"
icon={<FaEllipsisV />}
/>
<MenuList overflowY="scroll" height={400}>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (

View File

@ -19,11 +19,11 @@ const handleBaseStyles: CSSProperties = {
};
const inputHandleStyles: CSSProperties = {
left: '-1.6rem',
left: '-1rem',
};
const outputHandleStyles: CSSProperties = {
right: '-0.6rem',
right: '-0.5rem',
};
const requiredConnectionStyles: CSSProperties = {

View File

@ -20,12 +20,16 @@ import {
edgesChanged,
nodesChanged,
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { useCallback, useState } from 'react';
import { InvocationComponent } from './InvocationComponent';
import { AddNodeMenu } from './AddNodeMenu';
import { FieldTypeLegend } from './FieldTypeLegend';
import { Button } from '@chakra-ui/react';
import { nodesGraphBuilt } from 'services/thunks/session';
import { IAIIconButton } from 'exports';
import { InfoIcon } from '@chakra-ui/icons';
import { ViewportControls } from './ViewportControls';
import NodeGraphOverlay from './NodeGraphOverlay';
const nodeTypes = { invocation: InvocationComponent };
@ -33,6 +37,9 @@ export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay
);
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
@ -95,9 +102,12 @@ export const Flow = () => {
</Panel>
<Panel position="top-right">
<FieldTypeLegend />
{shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel>
<Panel position="bottom-left">
<ViewportControls />
</Panel>
<Background />
<Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable />
</ReactFlow>
);

View File

@ -12,7 +12,7 @@ export default function IAINodeHeader(props: IAINodeHeaderProps) {
const { nodeId, template } = props;
return (
<Flex
borderRadius="sm"
borderTopRadius="md"
justifyContent="space-between"
background="base.700"
px={2}
@ -20,7 +20,7 @@ export default function IAINodeHeader(props: IAINodeHeaderProps) {
alignItems="center"
>
<Tooltip label={nodeId}>
<Heading size="sm" fontWeight={600} color="base.100">
<Heading size="xs" fontWeight={600} color="base.100">
{template.current?.title}
</Heading>
</Tooltip>
@ -30,7 +30,7 @@ export default function IAINodeHeader(props: IAINodeHeaderProps) {
hasArrow
shouldWrapChildren
>
<Icon color="base.300" as={FaInfoCircle} />
<Icon color="base.300" as={FaInfoCircle} h="min-content" />
</Tooltip>
</Flex>
);

View File

@ -15,11 +15,13 @@ import {
HStack,
Tooltip,
Icon,
Divider,
} from '@chakra-ui/react';
import { FieldHandle } from '../FieldHandle';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { InputFieldComponent } from '../InputFieldComponent';
import { FaInfoCircle } from 'react-icons/fa';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
interface IAINodeInputProps {
nodeId: string;
@ -35,13 +37,8 @@ function IAINodeInput(props: IAINodeInputProps) {
return (
<Box
p={2}
key={input.id}
position="relative"
borderWidth={1}
borderRadius="md"
borderLeft="none"
borderRight="none"
borderColor={
!template
? 'error.400'
@ -63,14 +60,14 @@ function IAINodeInput(props: IAINodeInputProps) {
<>
<HStack justifyContent="space-between" alignItems="center">
<HStack>
<FormLabel>{template?.title}</FormLabel>
<Tooltip
label={template?.description}
placement="top"
hasArrow
shouldWrapChildren
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Icon color="base.400" as={FaInfoCircle} />
<FormLabel>{template?.title}</FormLabel>
</Tooltip>
</HStack>
<InputFieldComponent
@ -114,7 +111,7 @@ export default function IAINodeInputs(props: IAINodeInputsProps) {
const IAINodeInputsToRender: ReactNode[] = [];
const inputSockets = map(inputs);
inputSockets.forEach((inputSocket) => {
inputSockets.forEach((inputSocket, index) => {
const inputTemplate = template.current?.inputs[inputSocket.name];
const isConnected = Boolean(
@ -126,6 +123,10 @@ export default function IAINodeInputs(props: IAINodeInputsProps) {
}).length
);
if (index < inputSockets.length) {
IAINodeInputsToRender.push(<Divider />);
}
IAINodeInputsToRender.push(
<IAINodeInput
nodeId={nodeId}

View File

@ -5,12 +5,13 @@ export default function IAINodeResizer(props: NodeResizerProps) {
return (
<NodeResizeControl
style={{
position: 'relative',
position: 'absolute',
border: 'none',
background: 'none',
width: 10,
height: 10,
top: 10,
background: 'transparent',
width: 15,
height: 15,
bottom: 0,
right: 0,
}}
minWidth={350}
{...rest}

View File

@ -1,6 +1,6 @@
import { Box } from '@chakra-ui/react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import { ArrayInputFieldComponent } from './fields/ArrayInputField.tsx';
import { ArrayInputFieldComponent } from './fields/ArrayInputFieldComponent';
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';

View File

@ -1,9 +1,9 @@
import { NodeProps, NodeResizeControl } from 'reactflow';
import { Box, Flex, Icon } from '@chakra-ui/react';
import { Box, Flex, Icon, useToken } from '@chakra-ui/react';
import { FaExclamationCircle } from 'react-icons/fa';
import { InvocationValue } from '../types/types';
import { memo, useRef } from 'react';
import { memo, PropsWithChildren, useRef } from 'react';
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
import IAINodeOutputs from './IAINode/IAINodeOutputs';
import IAINodeInputs from './IAINode/IAINodeInputs';
@ -11,6 +11,31 @@ import IAINodeHeader from './IAINode/IAINodeHeader';
import { IoResize } from 'react-icons/io5';
import IAINodeResizer from './IAINode/IAINodeResizer';
type InvocationComponentWrapperProps = PropsWithChildren & {
selected: boolean;
};
const InvocationComponentWrapper = (props: InvocationComponentWrapperProps) => {
const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [
'nodeSelectedOutline',
'dark-lg',
]);
return (
<Box
sx={{
position: 'relative',
borderRadius: 'md',
boxShadow: props.selected
? `${nodeSelectedOutline}, ${nodeShadow}`
: `${nodeShadow}`,
}}
>
{props.children}
</Box>
);
};
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
const { id: nodeId, data, selected } = props;
const { type, inputs, outputs } = data;
@ -22,41 +47,31 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
if (!template.current) {
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<InvocationComponentWrapper selected={selected}>
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
<Icon color="base.400" boxSize={32} as={FaExclamationCircle}></Icon>
<IAINodeResizer />
</Flex>
</Box>
</InvocationComponentWrapper>
);
}
return (
<Box
sx={{
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex flexDirection="column" gap={2}>
<IAINodeHeader nodeId={nodeId} template={template} />
<InvocationComponentWrapper selected={selected}>
<IAINodeHeader nodeId={nodeId} template={template} />
<Flex
sx={{
flexDirection: 'column',
borderBottomRadius: 'md',
bg: 'base.800',
py: 2,
}}
>
<IAINodeOutputs nodeId={nodeId} outputs={outputs} template={template} />
<IAINodeInputs nodeId={nodeId} inputs={inputs} template={template} />
<IAINodeResizer />
</Flex>
</Box>
<IAINodeResizer />
</InvocationComponentWrapper>
);
});

View File

@ -3,15 +3,8 @@ import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import { Flow } from './Flow';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
const NodeEditor = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
sx={{
@ -25,20 +18,6 @@ const NodeEditor = () => {
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={2}
left={2}
width="full"
height="full"
userSelect="none"
pointerEvents="none"
opacity={0.7}
>
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
</Box>
</Box>
);
};

View File

@ -0,0 +1,28 @@
import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
export default function NodeGraphOverlay() {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={10}
right={2}
userSelect="none"
opacity={0.7}
background="base.800"
p={2}
maxHeight={500}
overflowY="scroll"
borderRadius="md"
>
{JSON.stringify(graph, null, 2)}
</Box>
);
}

View File

@ -0,0 +1,57 @@
import { ButtonGroup } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { IAIIconButton } from 'exports';
import { useCallback } from 'react';
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice';
export const ViewportControls = () => {
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay
);
const handleClickedZoomIn = useCallback(() => {
zoomIn();
}, [zoomIn]);
const handleClickedZoomOut = useCallback(() => {
zoomOut();
}, [zoomOut]);
const handleClickedFitView = useCallback(() => {
fitView();
}, [fitView]);
const handleClickedToggleGraphOverlay = useCallback(() => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]);
return (
<ButtonGroup isAttached orientation="vertical">
<IAIIconButton
onClick={handleClickedZoomIn}
aria-label="Zoom In"
icon={<FaPlus />}
/>
<IAIIconButton
onClick={handleClickedZoomOut}
aria-label="Zoom Out"
icon={<FaMinus />}
/>
<IAIIconButton
onClick={handleClickedFitView}
aria-label="Fit to Viewport"
icon={<FaExpand />}
/>
<IAIIconButton
isChecked={shouldShowGraphOverlay}
onClick={handleClickedToggleGraphOverlay}
aria-label="Show/Hide Graph"
icon={<FaCode />}
/>
</ButtonGroup>
);
};

View File

@ -1,7 +1,7 @@
import {
ArrayInputFieldTemplate,
ArrayInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { FaImage, FaList } from 'react-icons/fa';
import { FieldComponentProps } from './types';

View File

@ -4,7 +4,7 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';

View File

@ -4,7 +4,7 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';

View File

@ -8,7 +8,7 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ImageInputFieldTemplate,
ImageInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { DragEvent, useCallback, useState } from 'react';
import { FaImage } from 'react-icons/fa';
import { ImageType } from 'services/api';

View File

@ -1,7 +1,7 @@
import {
LatentsInputFieldTemplate,
LatentsInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { FieldComponentProps } from './types';
export const LatentsInputFieldComponent = (

View File

@ -6,7 +6,7 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ModelInputFieldTemplate,
ModelInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import {
selectModelsById,
selectModelsIds,

View File

@ -12,7 +12,7 @@ import {
FloatInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { FieldComponentProps } from './types';
export const NumberInputFieldComponent = (

View File

@ -4,7 +4,7 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
StringInputFieldTemplate,
StringInputFieldValue,
} from 'features/nodes/types';
} from 'features/nodes/types/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';

View File

@ -24,6 +24,7 @@ export type NodesState = {
invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
lastGraph: Graph | null;
shouldShowGraphOverlay: boolean;
};
export const initialNodesState: NodesState = {
@ -33,6 +34,7 @@ export const initialNodesState: NodesState = {
invocationTemplates: {},
connectionStartParams: null,
lastGraph: null,
shouldShowGraphOverlay: false,
};
const nodesSlice = createSlice({
@ -77,6 +79,9 @@ const nodesSlice = createSlice({
state.nodes[nodeIndex].data.inputs[fieldName].value = value;
}
},
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload;
},
},
extraReducers(builder) {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
@ -98,6 +103,7 @@ export const {
connectionMade,
connectionStarted,
connectionEnded,
shouldShowGraphOverlayChanged,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@ -22,46 +22,55 @@ const getColorTokenCssVariable = (color: string) =>
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
color: 'red',
colorCssVar: getColorTokenCssVariable('red'),
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
},
float: {
color: 'orange',
colorCssVar: getColorTokenCssVariable('orange'),
title: 'Float',
description: 'Floats are numbers with a decimal point.',
},
string: {
color: 'yellow',
colorCssVar: getColorTokenCssVariable('yellow'),
title: 'String',
description: 'Strings are text.',
},
boolean: {
color: 'green',
colorCssVar: getColorTokenCssVariable('green'),
title: 'Boolean',
description: 'Booleans are true or false.',
},
enum: {
color: 'blue',
colorCssVar: getColorTokenCssVariable('blue'),
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
},
image: {
color: 'purple',
colorCssVar: getColorTokenCssVariable('purple'),
title: 'Image',
description: 'Images may be passed between nodes.',
},
latents: {
color: 'pink',
colorCssVar: getColorTokenCssVariable('pink'),
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Model',
description: 'Models are models.',
},
array: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),
title: 'Array',
description: 'TODO: Array type description.',

View File

@ -39,6 +39,7 @@ export type InvocationTemplate = {
};
export type FieldUIConfig = {
color: string;
colorCssVar: string;
title: string;
description: string;

View File

@ -53,6 +53,7 @@ export const theme: ThemeOverride = {
working: `0 0 7px var(--invokeai-colors-working-400)`,
error: `0 0 7px var(--invokeai-colors-error-400)`,
},
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-base-500)`,
},
colors: {
...invokeAIThemeColors,