fix(ui): improve node rendering performance

Previously the editor was using prop-drilling node data and templates to get values deep into nodes. This ended up causing very noticeable performance degradation. For example, any text entry fields were super laggy.

Refactor the whole thing to use memoized selectors via hooks. The hooks are mostly very narrow, returning only the data needed.

Data objects are never passed down, only node id and field name - sometimes the field kind ('input' or 'output').

The end result is a *much* smoother node editor with very minimal rerenders.
This commit is contained in:
psychedelicious 2023-08-16 22:18:48 +10:00
parent f7c92e1eff
commit f9b8b5cff2
42 changed files with 928 additions and 736 deletions

View File

@ -1,40 +1,34 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { import { useFieldNames, useWithFooter } from 'features/nodes/hooks/useNodeData';
InvocationNodeData, import { memo } from 'react';
InvocationTemplate,
} from 'features/nodes/types/types';
import { map, some } from 'lodash-es';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import InputField from '../fields/InputField'; import InputField from '../fields/InputField';
import OutputField from '../fields/OutputField'; import OutputField from '../fields/OutputField';
import NodeFooter, { FOOTER_FIELDS } from './NodeFooter'; import NodeFooter from './NodeFooter';
import NodeHeader from './NodeHeader'; import NodeHeader from './NodeHeader';
import NodeWrapper from './NodeWrapper'; import NodeWrapper from './NodeWrapper';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; isOpen: boolean;
label: string;
type: string;
selected: boolean;
}; };
const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => { const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const { id: nodeId, data } = nodeProps; const inputFieldNames = useFieldNames(nodeId, 'input');
const { inputs, outputs, isOpen } = data; const outputFieldNames = useFieldNames(nodeId, 'output');
const withFooter = useWithFooter(nodeId);
const inputFields = useMemo(
() => map(inputs).filter((i) => i.name !== 'is_intermediate'),
[inputs]
);
const outputFields = useMemo(() => map(outputs), [outputs]);
const withFooter = useMemo(
() => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)),
[outputs]
);
return ( return (
<NodeWrapper nodeProps={nodeProps}> <NodeWrapper nodeId={nodeId} selected={selected}>
<NodeHeader nodeProps={nodeProps} nodeTemplate={nodeTemplate} /> <NodeHeader
nodeId={nodeId}
isOpen={isOpen}
label={label}
selected={selected}
type={type}
/>
{isOpen && ( {isOpen && (
<> <>
<Flex <Flex
@ -54,27 +48,23 @@ const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
className="nopan" className="nopan"
sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }} sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}
> >
{outputFields.map((field) => ( {outputFieldNames.map((fieldName) => (
<OutputField <OutputField
key={`${nodeId}.${field.id}.input-field`} key={`${nodeId}.${fieldName}.output-field`}
nodeProps={nodeProps} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field}
/> />
))} ))}
{inputFields.map((field) => ( {inputFieldNames.map((fieldName) => (
<InputField <InputField
key={`${nodeId}.${field.id}.input-field`} key={`${nodeId}.${fieldName}.input-field`}
nodeProps={nodeProps} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field}
/> />
))} ))}
</Flex> </Flex>
</Flex> </Flex>
{withFooter && ( {withFooter && <NodeFooter nodeId={nodeId} />}
<NodeFooter nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
)}
</> </>
)} )}
</NodeWrapper> </NodeWrapper>

View File

@ -2,16 +2,15 @@ import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice'; import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { NodeProps, useUpdateNodeInternals } from 'reactflow'; import { useUpdateNodeInternals } from 'reactflow';
interface Props { interface Props {
nodeProps: NodeProps<NodeData>; nodeId: string;
isOpen: boolean;
} }
const NodeCollapseButton = (props: Props) => { const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
const { id: nodeId, isOpen } = props.nodeProps.data;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals(); const updateNodeInternals = useUpdateNodeInternals();

View File

@ -1,20 +1,17 @@
import { useColorModeValue } from '@chakra-ui/react'; import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { import { useNodeData } from 'features/nodes/hooks/useNodeData';
InvocationNodeData, import { isInvocationNodeData } from 'features/nodes/types/types';
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react'; import { CSSProperties, memo, useMemo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow'; import { Handle, Position } from 'reactflow';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate;
} }
const NodeCollapsedHandles = (props: Props) => { const NodeCollapsedHandles = ({ nodeId }: Props) => {
const { data } = props.nodeProps; const data = useNodeData(nodeId);
const { base400, base600 } = useChakraThemeTokens(); const { base400, base600 } = useChakraThemeTokens();
const backgroundColor = useColorModeValue(base400, base600); const backgroundColor = useColorModeValue(base400, base600);
@ -30,6 +27,10 @@ const NodeCollapsedHandles = (props: Props) => {
[backgroundColor] [backgroundColor]
); );
if (!isInvocationNodeData(data)) {
return null;
}
return ( return (
<> <>
<Handle <Handle
@ -44,7 +45,7 @@ const NodeCollapsedHandles = (props: Props) => {
key={`${data.id}-${input.name}-collapsed-input-handle`} key={`${data.id}-${input.name}-collapsed-input-handle`}
type="target" type="target"
id={input.name} id={input.name}
isValidConnection={() => false} isConnectable={false}
position={Position.Left} position={Position.Left}
style={{ visibility: 'hidden' }} style={{ visibility: 'hidden' }}
/> />
@ -52,7 +53,6 @@ const NodeCollapsedHandles = (props: Props) => {
<Handle <Handle
type="source" type="source"
id={`${data.id}-collapsed-source`} id={`${data.id}-collapsed-source`}
isValidConnection={() => false}
isConnectable={false} isConnectable={false}
position={Position.Right} position={Position.Right}
style={{ ...dummyHandleStyles, right: '-0.5rem' }} style={{ ...dummyHandleStyles, right: '-0.5rem' }}
@ -62,7 +62,7 @@ const NodeCollapsedHandles = (props: Props) => {
key={`${data.id}-${output.name}-collapsed-output-handle`} key={`${data.id}-${output.name}-collapsed-output-handle`}
type="source" type="source"
id={output.name} id={output.name}
isValidConnection={() => false} isConnectable={false}
position={Position.Right} position={Position.Right}
style={{ visibility: 'hidden' }} style={{ visibility: 'hidden' }}
/> />

View File

@ -6,49 +6,22 @@ import {
Spacer, Spacer,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import {
useHasImageOutput,
useIsIntermediate,
} from 'features/nodes/hooks/useNodeData';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { import { ChangeEvent, memo, useCallback } from 'react';
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection']; export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS; export const FOOTER_FIELDS = IMAGE_FIELDS;
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate;
}; };
const NodeFooter = (props: Props) => { const NodeFooter = ({ nodeId }: Props) => {
const { nodeProps, nodeTemplate } = props;
const dispatch = useAppDispatch();
const hasImageOutput = useMemo(
() =>
some(nodeTemplate?.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
),
[nodeTemplate?.outputs]
);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: nodeProps.data.id,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeProps.data.id]
);
return ( return (
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
@ -62,19 +35,45 @@ const NodeFooter = (props: Props) => {
}} }}
> >
<Spacer /> <Spacer />
{hasImageOutput && ( <SaveImageCheckbox nodeId={nodeId} />
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
/>
</FormControl>
)}
</Flex> </Flex>
); );
}; };
export default memo(NodeFooter); export default memo(NodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -1,10 +1,5 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton'; import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles'; import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
import NodeNotesEdit from '../Invocation/NodeNotesEdit'; import NodeNotesEdit from '../Invocation/NodeNotesEdit';
@ -12,14 +7,14 @@ import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
import NodeTitle from '../Invocation/NodeTitle'; import NodeTitle from '../Invocation/NodeTitle';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; isOpen: boolean;
label: string;
type: string;
selected: boolean;
}; };
const NodeHeader = (props: Props) => { const NodeHeader = ({ nodeId, isOpen, label, type, selected }: Props) => {
const { nodeProps, nodeTemplate } = props;
const { isOpen } = nodeProps.data;
return ( return (
<Flex <Flex
layerStyle="nodeHeader" layerStyle="nodeHeader"
@ -35,18 +30,13 @@ const NodeHeader = (props: Props) => {
_dark: { color: 'base.200' }, _dark: { color: 'base.200' },
}} }}
> >
<NodeCollapseButton nodeProps={nodeProps} /> <NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeData={nodeProps.data} title={nodeTemplate.title} /> <NodeTitle nodeId={nodeId} />
<Flex alignItems="center"> <Flex alignItems="center">
<NodeStatusIndicator nodeProps={nodeProps} /> <NodeStatusIndicator nodeId={nodeId} />
<NodeNotesEdit nodeProps={nodeProps} nodeTemplate={nodeTemplate} /> <NodeNotesEdit nodeId={nodeId} />
</Flex> </Flex>
{!isOpen && ( {!isOpen && <NodeCollapsedHandles nodeId={nodeId} />}
<NodeCollapsedHandles
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
/>
)}
</Flex> </Flex>
); );
}; };

View File

@ -16,41 +16,31 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea'; import IAITextarea from 'common/components/IAITextarea';
import {
useNodeData,
useNodeLabel,
useNodeTemplate,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice'; import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { import { isInvocationNodeData } from 'features/nodes/types/types';
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa'; import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate;
} }
const NodeNotesEdit = (props: Props) => { const NodeNotesEdit = ({ nodeId }: Props) => {
const { nodeProps, nodeTemplate } = props;
const { data } = nodeProps;
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch(); const label = useNodeLabel(nodeId);
const handleNotesChanged = useCallback( const title = useNodeTemplateTitle(nodeId);
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
},
[data.id, dispatch]
);
return ( return (
<> <>
<Tooltip <Tooltip
label={ label={<TooltipContent nodeId={nodeId} />}
nodeTemplate ? (
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
) : undefined
}
placement="top" placement="top"
shouldWrapChildren shouldWrapChildren
> >
@ -75,19 +65,10 @@ const NodeNotesEdit = (props: Props) => {
<Modal isOpen={isOpen} onClose={onClose} isCentered> <Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay /> <ModalOverlay />
<ModalContent> <ModalContent>
<ModalHeader> <ModalHeader>{label || title || 'Unknown Node'}</ModalHeader>
{data.label || nodeTemplate?.title || 'Unknown Node'}
</ModalHeader>
<ModalCloseButton /> <ModalCloseButton />
<ModalBody> <ModalBody>
<FormControl> <NotesTextarea nodeId={nodeId} />
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
</ModalBody> </ModalBody>
<ModalFooter /> <ModalFooter />
</ModalContent> </ModalContent>
@ -98,16 +79,49 @@ const NodeNotesEdit = (props: Props) => {
export default memo(NodeNotesEdit); export default memo(NodeNotesEdit);
type TooltipContentProps = Props; const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
if (!isInvocationNodeData(data)) {
return 'Unknown Node';
}
const TooltipContent = (props: TooltipContentProps) => {
return ( return (
<Flex sx={{ flexDir: 'column' }}> <Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text> <Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}> <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{props.nodeTemplate?.description} {nodeTemplate?.description}
</Text> </Text>
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>} {data?.notes && <Text>{data.notes}</Text>}
</Flex> </Flex>
); );
}; });
TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';

View File

@ -11,17 +11,12 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
InvocationNodeData,
NodeExecutionState,
NodeStatus,
} from 'features/nodes/types/types';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa'; import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
}; };
const iconBoxSize = 3; const iconBoxSize = 3;
@ -33,8 +28,7 @@ const circleStyles = {
'.chakra-progress__track': { stroke: 'transparent' }, '.chakra-progress__track': { stroke: 'transparent' },
}; };
const NodeStatusIndicator = (props: Props) => { const NodeStatusIndicator = ({ nodeId }: Props) => {
const nodeId = props.nodeProps.data.id;
const selectNodeExecutionState = useMemo( const selectNodeExecutionState = useMemo(
() => () =>
createSelector( createSelector(
@ -76,7 +70,7 @@ type TooltipLabelProps = {
nodeExecutionState: NodeExecutionState; nodeExecutionState: NodeExecutionState;
}; };
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => { const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState; const { status, progress, progressImage } = nodeExecutionState;
if (status === NodeStatus.PENDING) { if (status === NodeStatus.PENDING) {
return <Text>Pending</Text>; return <Text>Pending</Text>;
@ -118,13 +112,15 @@ const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
} }
return null; return null;
}; });
TooltipLabel.displayName = 'TooltipLabel';
type StatusIconProps = { type StatusIconProps = {
nodeExecutionState: NodeExecutionState; nodeExecutionState: NodeExecutionState;
}; };
const StatusIcon = (props: StatusIconProps) => { const StatusIcon = memo((props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState; const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) { if (status === NodeStatus.PENDING) {
return ( return (
@ -182,4 +178,6 @@ const StatusIcon = (props: StatusIconProps) => {
); );
} }
return null; return null;
}; });
StatusIcon.displayName = 'StatusIcon';

View File

@ -7,26 +7,29 @@ import {
useEditableControls, useEditableControls,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import {
useNodeLabel,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice'; import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeData } from 'features/nodes/types/types';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react'; import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
type Props = { type Props = {
nodeData: NodeData; nodeId: string;
title: string; title?: string;
}; };
const NodeTitle = (props: Props) => { const NodeTitle = ({ nodeId, title }: Props) => {
const { title } = props;
const { id: nodeId, label } = props.nodeData;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title); const label = useNodeLabel(nodeId);
const templateTitle = useNodeTemplateTitle(nodeId);
const [localTitle, setLocalTitle] = useState('');
const handleSubmit = useCallback( const handleSubmit = useCallback(
async (newTitle: string) => { async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle })); dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title); setLocalTitle(newTitle || title || 'Problem Setting Title');
}, },
[nodeId, dispatch, title] [nodeId, dispatch, title]
); );
@ -37,8 +40,8 @@ const NodeTitle = (props: Props) => {
useEffect(() => { useEffect(() => {
// Another component may change the title; sync local title with global state // Another component may change the title; sync local title with global state
setLocalTitle(label || title); setLocalTitle(label || title || templateTitle || 'Problem Setting Title');
}, [label, title]); }, [label, templateTitle, title]);
return ( return (
<Flex <Flex

View File

@ -6,10 +6,14 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeClicked } from 'features/nodes/store/nodesSlice'; import { nodeClicked } from 'features/nodes/store/nodesSlice';
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react'; import {
MouseEvent,
PropsWithChildren,
memo,
useCallback,
useMemo,
} from 'react';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
import { NodeData } from 'features/nodes/types/types';
import { NodeProps } from 'reactflow';
const useNodeSelect = (nodeId: string) => { const useNodeSelect = (nodeId: string) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -25,14 +29,13 @@ const useNodeSelect = (nodeId: string) => {
}; };
type NodeWrapperProps = PropsWithChildren & { type NodeWrapperProps = PropsWithChildren & {
nodeProps: NodeProps<NodeData>; nodeId: string;
selected: boolean;
width?: NonNullable<ChakraProps['sx']>['w']; width?: NonNullable<ChakraProps['sx']>['w'];
}; };
const NodeWrapper = (props: NodeWrapperProps) => { const NodeWrapper = (props: NodeWrapperProps) => {
const { width, children, nodeProps } = props; const { width, children, nodeId, selected } = props;
const { data, selected } = nodeProps;
const nodeId = data.id;
const [ const [
nodeSelectedOutlineLight, nodeSelectedOutlineLight,
@ -93,4 +96,4 @@ const NodeWrapper = (props: NodeWrapperProps) => {
); );
}; };
export default NodeWrapper; export default memo(NodeWrapper);

View File

@ -1,20 +1,26 @@
import { Box, Flex, Text } from '@chakra-ui/react'; import { Box, Flex, Text } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { InvocationNodeData } from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton'; import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeWrapper from '../Invocation/NodeWrapper'; import NodeWrapper from '../Invocation/NodeWrapper';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
}; };
const UnknownNodeFallback = ({ nodeProps }: Props) => { const UnknownNodeFallback = ({
const { data } = nodeProps; nodeId,
const { isOpen, label, type } = data; isOpen,
label,
type,
selected,
}: Props) => {
return ( return (
<NodeWrapper nodeProps={nodeProps}> <NodeWrapper nodeId={nodeId} selected={selected}>
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeHeader" layerStyle="nodeHeader"
@ -27,7 +33,7 @@ const UnknownNodeFallback = ({ nodeProps }: Props) => {
fontSize: 'sm', fontSize: 'sm',
}} }}
> >
<NodeCollapseButton nodeProps={nodeProps} /> <NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<Text <Text
sx={{ sx={{
w: 'full', w: 'full',

View File

@ -1,19 +1,12 @@
import { Tooltip } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react'; import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, NodeProps, Position } from 'reactflow'; import { Handle, HandleType, Position } from 'reactflow';
import { import {
FIELDS, FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY, HANDLE_TOOLTIP_OPEN_DELAY,
colorTokenToCssVar, colorTokenToCssVar,
} from '../../types/constants'; } from '../../types/constants';
import { import { InputFieldTemplate, OutputFieldTemplate } from '../../types/types';
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from '../../types/types';
export const handleBaseStyles: CSSProperties = { export const handleBaseStyles: CSSProperties = {
position: 'absolute', position: 'absolute',
@ -32,9 +25,6 @@ export const outputHandleStyles: CSSProperties = {
}; };
type FieldHandleProps = { type FieldHandleProps = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate; fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
handleType: HandleType; handleType: HandleType;
isConnectionInProgress: boolean; isConnectionInProgress: boolean;

View File

@ -8,13 +8,11 @@ import {
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIDraggable from 'common/components/IAIDraggable'; import IAIDraggable from 'common/components/IAIDraggable';
import { NodeFieldDraggableData } from 'features/dnd/types'; import { NodeFieldDraggableData } from 'features/dnd/types';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { import {
InputFieldTemplate, useFieldData,
InputFieldValue, useFieldTemplate,
InvocationNodeData, } from 'features/nodes/hooks/useNodeData';
InvocationTemplate, import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
} from 'features/nodes/types/types';
import { import {
MouseEvent, MouseEvent,
memo, memo,
@ -25,41 +23,43 @@ import {
} from 'react'; } from 'react';
interface Props { interface Props {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
isDraggable?: boolean; isDraggable?: boolean;
kind: 'input' | 'output';
} }
const FieldTitle = (props: Props) => { const FieldTitle = (props: Props) => {
const { nodeData, field, fieldTemplate, isDraggable = false } = props; const { nodeId, fieldName, isDraggable = false, kind } = props;
const { label } = field; const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const { title, input } = fieldTemplate; const field = useFieldData(nodeId, fieldName);
const { id: nodeId } = nodeData;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title); const [localTitle, setLocalTitle] = useState(
field?.label || fieldTemplate?.title || 'Unknown Field'
);
const draggableData: NodeFieldDraggableData | undefined = useMemo( const draggableData: NodeFieldDraggableData | undefined = useMemo(
() => () =>
input !== 'connection' && isDraggable field &&
fieldTemplate?.fieldKind === 'input' &&
fieldTemplate?.input !== 'connection' &&
isDraggable
? { ? {
id: `${nodeId}-${field.name}`, id: `${nodeId}-${fieldName}`,
payloadType: 'NODE_FIELD', payloadType: 'NODE_FIELD',
payload: { nodeId, field, fieldTemplate }, payload: { nodeId, field, fieldTemplate },
} }
: undefined, : undefined,
[field, fieldTemplate, input, isDraggable, nodeId] [field, fieldName, fieldTemplate, isDraggable, nodeId]
); );
const handleSubmit = useCallback( const handleSubmit = useCallback(
async (newTitle: string) => { async (newTitle: string) => {
dispatch( dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle }) setLocalTitle(newTitle || fieldTemplate?.title || 'Unknown Field');
);
setLocalTitle(newTitle || title);
}, },
[dispatch, nodeId, field.name, title] [dispatch, nodeId, fieldName, fieldTemplate?.title]
); );
const handleChange = useCallback((newTitle: string) => { const handleChange = useCallback((newTitle: string) => {
@ -68,8 +68,8 @@ const FieldTitle = (props: Props) => {
useEffect(() => { useEffect(() => {
// Another component may change the title; sync local title with global state // Another component may change the title; sync local title with global state
setLocalTitle(label || title); setLocalTitle(field?.label || fieldTemplate?.title || 'Unknown Field');
}, [label, title]); }, [field?.label, fieldTemplate?.title]);
return ( return (
<Flex <Flex
@ -120,7 +120,7 @@ type EditableControlsProps = {
draggableData?: NodeFieldDraggableData; draggableData?: NodeFieldDraggableData;
}; };
function EditableControls(props: EditableControlsProps) { const EditableControls = memo((props: EditableControlsProps) => {
const { isEditing, getEditButtonProps } = useEditableControls(); const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback( const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => { (e: MouseEvent<HTMLDivElement>) => {
@ -158,4 +158,6 @@ function EditableControls(props: EditableControlsProps) {
cursor="text" cursor="text"
/> />
); );
} });
EditableControls.displayName = 'EditableControls';

View File

@ -1,38 +1,53 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import {
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { FIELDS } from 'features/nodes/types/constants'; import { FIELDS } from 'features/nodes/types/constants';
import { import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
isInputFieldTemplate, isInputFieldTemplate,
isInputFieldValue, isInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { startCase } from 'lodash-es'; import { startCase } from 'lodash-es';
import { useMemo } from 'react';
interface Props { interface Props {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue | OutputFieldValue; kind: 'input' | 'output';
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
} }
const FieldTooltipContent = ({ field, fieldTemplate }: Props) => { const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const isInputTemplate = isInputFieldTemplate(fieldTemplate); const isInputTemplate = isInputFieldTemplate(fieldTemplate);
const fieldTitle = useMemo(() => {
if (isInputFieldValue(field)) {
if (field.label && fieldTemplate) {
return `${field.label} (${fieldTemplate.title})`;
}
if (field.label && !fieldTemplate) {
return field.label;
}
if (!field.label && fieldTemplate) {
return fieldTemplate.title;
}
return 'Unknown Field';
}
}, [field, fieldTemplate]);
return ( return (
<Flex sx={{ flexDir: 'column' }}> <Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}> <Text sx={{ fontWeight: 600 }}>{fieldTitle}</Text>
{isInputFieldValue(field) && field.label {fieldTemplate && (
? `${field.label} (${fieldTemplate.title})` <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
: fieldTemplate.title} {fieldTemplate.description}
</Text> </Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}> )}
{fieldTemplate.description} {fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
</Text>
<Text>Type: {FIELDS[fieldTemplate.type].title}</Text>
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>} {isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex> </Flex>
); );

View File

@ -1,27 +1,24 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { import {
InputFieldValue, useDoesInputHaveValue,
InvocationNodeData, useFieldTemplate,
InvocationTemplate, } from 'features/nodes/hooks/useNodeData';
} from 'features/nodes/types/types'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, useMemo } from 'react'; import { PropsWithChildren, memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle'; import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle'; import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer'; import InputFieldRenderer from './InputFieldRenderer';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
} }
const InputField = (props: Props) => { const InputField = ({ nodeId, fieldName }: Props) => {
const { nodeProps, nodeTemplate, field } = props; const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const { id: nodeId } = nodeProps.data; const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const { const {
isConnected, isConnected,
@ -29,15 +26,10 @@ const InputField = (props: Props) => {
isConnectionStartField, isConnectionStartField,
connectionError, connectionError,
shouldDim, shouldDim,
} = useConnectionState({ nodeId, field, kind: 'input' }); } = useConnectionState({ nodeId, fieldName, kind: 'input' });
const fieldTemplate = useMemo(
() => nodeTemplate.inputs[field.name],
[field.name, nodeTemplate.inputs]
);
const isMissingInput = useMemo(() => { const isMissingInput = useMemo(() => {
if (!fieldTemplate) { if (fieldTemplate?.fieldKind !== 'input') {
return false; return false;
} }
@ -49,18 +41,18 @@ const InputField = (props: Props) => {
return true; return true;
} }
if (!field.value && !isConnected && fieldTemplate.input === 'any') { if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') {
return true; return true;
} }
}, [fieldTemplate, isConnected, field.value]); }, [fieldTemplate, isConnected, doesFieldHaveValue]);
if (!fieldTemplate) { if (fieldTemplate?.fieldKind !== 'input') {
return ( return (
<InputFieldWrapper shouldDim={shouldDim}> <InputFieldWrapper shouldDim={shouldDim}>
<FormControl <FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }} sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
> >
Unknown input: {field.name} Unknown input: {fieldName}
</FormControl> </FormControl>
</InputFieldWrapper> </InputFieldWrapper>
); );
@ -82,10 +74,9 @@ const InputField = (props: Props) => {
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
nodeData={nodeProps.data} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="input"
fieldTemplate={fieldTemplate}
/> />
} }
openDelay={HANDLE_TOOLTIP_OPEN_DELAY} openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -95,27 +86,18 @@ const InputField = (props: Props) => {
> >
<FormLabel sx={{ mb: 0 }}> <FormLabel sx={{ mb: 0 }}>
<FieldTitle <FieldTitle
nodeData={nodeProps.data} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="input"
fieldTemplate={fieldTemplate}
isDraggable isDraggable
/> />
</FormLabel> </FormLabel>
</Tooltip> </Tooltip>
<InputFieldRenderer <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl> </FormControl>
{fieldTemplate.input !== 'direct' && ( {fieldTemplate.input !== 'direct' && (
<FieldHandle <FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
handleType="target" handleType="target"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
@ -133,21 +115,25 @@ type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean; shouldDim: boolean;
}>; }>;
const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => ( const InputFieldWrapper = memo(
<Flex ({ shouldDim, children }: InputFieldWrapperProps) => (
className="nopan" <Flex
sx={{ className="nopan"
position: 'relative', sx={{
minH: 8, position: 'relative',
py: 0.5, minH: 8,
alignItems: 'center', py: 0.5,
opacity: shouldDim ? 0.5 : 1, alignItems: 'center',
transitionProperty: 'opacity', opacity: shouldDim ? 0.5 : 1,
transitionDuration: '0.1s', transitionProperty: 'opacity',
w: 'full', transitionDuration: '0.1s',
h: 'full', w: 'full',
}} h: 'full',
> }}
{children} >
</Flex> {children}
</Flex>
)
); );
InputFieldWrapper.displayName = 'InputFieldWrapper';

View File

@ -1,11 +1,9 @@
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { memo } from 'react';
import { import {
InputFieldTemplate, useFieldData,
InputFieldValue, useFieldTemplate,
InvocationNodeData, } from 'features/nodes/hooks/useNodeData';
InvocationTemplate, import { memo } from 'react';
} from '../../types/types';
import BooleanInputField from './fieldTypes/BooleanInputField'; import BooleanInputField from './fieldTypes/BooleanInputField';
import ClipInputField from './fieldTypes/ClipInputField'; import ClipInputField from './fieldTypes/ClipInputField';
import CollectionInputField from './fieldTypes/CollectionInputField'; import CollectionInputField from './fieldTypes/CollectionInputField';
@ -29,33 +27,33 @@ import VaeInputField from './fieldTypes/VaeInputField';
import VaeModelInputField from './fieldTypes/VaeModelInputField'; import VaeModelInputField from './fieldTypes/VaeModelInputField';
type InputFieldProps = { type InputFieldProps = {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}; };
// build an individual input element based on the schema // build an individual input element based on the schema
const InputFieldRenderer = (props: InputFieldProps) => { const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const { nodeData, nodeTemplate, field, fieldTemplate } = props; const field = useFieldData(nodeId, fieldName);
const { type } = field; const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
if (type === 'string' && fieldTemplate.type === 'string') { if (fieldTemplate?.fieldKind === 'output') {
return <Box p={2}>Output field in input: {field?.type}</Box>;
}
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
return ( return (
<StringInputField <StringInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'boolean' && fieldTemplate.type === 'boolean') { if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
return ( return (
<BooleanInputField <BooleanInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -63,46 +61,32 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
(type === 'integer' && fieldTemplate.type === 'integer') || (field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(type === 'float' && fieldTemplate.type === 'float') (field?.type === 'float' && fieldTemplate?.type === 'float')
) { ) {
return ( return (
<NumberInputField <NumberInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'enum' && fieldTemplate.type === 'enum') { if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
return ( return (
<EnumInputField <EnumInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'ImageField' && fieldTemplate.type === 'ImageField') { if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
return ( return (
<ImageInputField <ImageInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
return (
<LatentsInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -110,68 +94,55 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'ConditioningField' && field?.type === 'LatentsField' &&
fieldTemplate.type === 'ConditioningField' fieldTemplate?.type === 'LatentsField'
) {
return (
<LatentsInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
) { ) {
return ( return (
<ConditioningInputField <ConditioningInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'UNetField' && fieldTemplate.type === 'UNetField') { if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
return ( return (
<UnetInputField <UnetInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'ClipField' && fieldTemplate.type === 'ClipField') { if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
return ( return (
<ClipInputField <ClipInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'VaeField' && fieldTemplate.type === 'VaeField') { if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
return ( return (
<VaeInputField <VaeInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
return (
<ControlInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
return (
<MainModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -179,35 +150,38 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'SDXLRefinerModelField' && field?.type === 'ControlField' &&
fieldTemplate.type === 'SDXLRefinerModelField' fieldTemplate?.type === 'ControlField'
) {
return (
<ControlInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
) {
return (
<MainModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLRefinerModelField' &&
fieldTemplate?.type === 'SDXLRefinerModelField'
) { ) {
return ( return (
<RefinerModelInputField <RefinerModelInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
return (
<VaeModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
return (
<LoRAModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -215,57 +189,48 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'ControlNetModelField' && field?.type === 'VaeModelField' &&
fieldTemplate.type === 'ControlNetModelField' fieldTemplate?.type === 'VaeModelField'
) {
return (
<VaeModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LoRAModelField' &&
fieldTemplate?.type === 'LoRAModelField'
) {
return (
<LoRAModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlNetModelField' &&
fieldTemplate?.type === 'ControlNetModelField'
) { ) {
return ( return (
<ControlNetModelInputField <ControlNetModelInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'Collection' && fieldTemplate.type === 'Collection') { if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
return ( return (
<CollectionInputField <CollectionInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
return (
<CollectionItemInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
return (
<ColorInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
return (
<ImageCollectionInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -273,20 +238,55 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'SDXLMainModelField' && field?.type === 'CollectionItem' &&
fieldTemplate.type === 'SDXLMainModelField' fieldTemplate?.type === 'CollectionItem'
) { ) {
return ( return (
<SDXLMainModelInputField <CollectionItemInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
return <Box p={2}>Unknown field type: {type}</Box>; if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
return (
<ImageCollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
) {
return (
<SDXLMainModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return <Box p={2}>Unknown field type: {field?.type}</Box>;
}; };
export default memo(InputFieldRenderer); export default memo(InputFieldRenderer);

View File

@ -1,39 +1,16 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import FieldTitle from './FieldTitle'; import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer'; import InputFieldRenderer from './InputFieldRenderer';
type Props = { type Props = {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}; };
const LinearViewField = ({ const LinearViewField = ({ nodeId, fieldName }: Props) => {
nodeData,
nodeTemplate,
field,
fieldTemplate,
}: Props) => {
// const dispatch = useAppDispatch();
// const handleRemoveField = useCallback(() => {
// dispatch(
// workflowExposedFieldRemoved({
// nodeId: nodeData.id,
// fieldName: field.name,
// })
// );
// }, [dispatch, field.name, nodeData.id]);
return ( return (
<Flex <Flex
layerStyle="second" layerStyle="second"
@ -48,10 +25,9 @@ const LinearViewField = ({
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="input"
fieldTemplate={fieldTemplate}
/> />
} }
openDelay={HANDLE_TOOLTIP_OPEN_DELAY} openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -66,20 +42,10 @@ const LinearViewField = ({
mb: 0, mb: 0,
}} }}
> >
<FieldTitle <FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormLabel> </FormLabel>
</Tooltip> </Tooltip>
<InputFieldRenderer <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl> </FormControl>
</Flex> </Flex>
); );

View File

@ -6,25 +6,19 @@ import {
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldTemplate } from 'features/nodes/hooks/useNodeData';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { import { PropsWithChildren, memo } from 'react';
InvocationNodeData,
InvocationTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle'; import FieldHandle from './FieldHandle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: OutputFieldValue;
} }
const OutputField = (props: Props) => { const OutputField = ({ nodeId, fieldName }: Props) => {
const { nodeTemplate, nodeProps, field } = props; const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'output');
const { const {
isConnected, isConnected,
@ -32,20 +26,15 @@ const OutputField = (props: Props) => {
isConnectionStartField, isConnectionStartField,
connectionError, connectionError,
shouldDim, shouldDim,
} = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' }); } = useConnectionState({ nodeId, fieldName, kind: 'output' });
const fieldTemplate = useMemo( if (fieldTemplate?.fieldKind !== 'output') {
() => nodeTemplate.outputs[field.name],
[field.name, nodeTemplate]
);
if (!fieldTemplate) {
return ( return (
<OutputFieldWrapper shouldDim={shouldDim}> <OutputFieldWrapper shouldDim={shouldDim}>
<FormControl <FormControl
sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }} sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }}
> >
Unknown output: {field.name} Unknown output: {fieldName}
</FormControl> </FormControl>
</OutputFieldWrapper> </OutputFieldWrapper>
); );
@ -57,10 +46,9 @@ const OutputField = (props: Props) => {
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
nodeData={nodeProps.data} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="output"
fieldTemplate={fieldTemplate}
/> />
} }
openDelay={HANDLE_TOOLTIP_OPEN_DELAY} openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -75,9 +63,6 @@ const OutputField = (props: Props) => {
</FormControl> </FormControl>
</Tooltip> </Tooltip>
<FieldHandle <FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
handleType="source" handleType="source"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
@ -88,27 +73,28 @@ const OutputField = (props: Props) => {
); );
}; };
export default OutputField; export default memo(OutputField);
type OutputFieldWrapperProps = PropsWithChildren<{ type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean; shouldDim: boolean;
}>; }>;
const OutputFieldWrapper = ({ const OutputFieldWrapper = memo(
shouldDim, ({ shouldDim, children }: OutputFieldWrapperProps) => (
children, <Flex
}: OutputFieldWrapperProps) => ( sx={{
<Flex position: 'relative',
sx={{ minH: 8,
position: 'relative', py: 0.5,
minH: 8, alignItems: 'center',
py: 0.5, opacity: shouldDim ? 0.5 : 1,
alignItems: 'center', transitionProperty: 'opacity',
opacity: shouldDim ? 0.5 : 1, transitionDuration: '0.1s',
transitionProperty: 'opacity', }}
transitionDuration: '0.1s', >
}} {children}
> </Flex>
{children} )
</Flex>
); );
OutputFieldWrapper.displayName = 'OutputFieldWrapper';

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const BooleanInputFieldComponent = ( const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate> props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const ColorInputFieldComponent = ( const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate> props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -19,8 +19,7 @@ const ControlNetModelInputFieldComponent = (
ControlNetModelInputFieldTemplate ControlNetModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const controlNetModel = field.value; const controlNetModel = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const EnumInputFieldComponent = ( const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate> props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => { ) => {
const { nodeData, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -19,8 +19,7 @@ const ImageCollectionInputFieldComponent = (
ImageCollectionInputFieldTemplate ImageCollectionInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
// const dispatch = useAppDispatch(); // const dispatch = useAppDispatch();

View File

@ -21,8 +21,7 @@ import { FieldComponentProps } from './types';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate> props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery( const { currentData: imageDTO } = useGetImageDTOQuery(

View File

@ -21,8 +21,7 @@ const LoRAModelInputFieldComponent = (
LoRAModelInputFieldTemplate LoRAModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const lora = field.value; const lora = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: loraModels } = useGetLoRAModelsQuery(); const { data: loraModels } = useGetLoRAModelsQuery();

View File

@ -26,8 +26,7 @@ const MainModelInputFieldComponent = (
MainModelInputFieldTemplate MainModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -23,8 +23,7 @@ const NumberInputFieldComponent = (
IntegerInputFieldTemplate | FloatInputFieldTemplate IntegerInputFieldTemplate | FloatInputFieldTemplate
> >
) => { ) => {
const { nodeData, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState<string>( const [valueAsString, setValueAsString] = useState<string>(
String(field.value) String(field.value)

View File

@ -24,8 +24,7 @@ const RefinerModelInputFieldComponent = (
SDXLRefinerModelInputFieldTemplate SDXLRefinerModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -27,8 +27,7 @@ const ModelInputFieldComponent = (
SDXLMainModelInputFieldTemplate SDXLMainModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -12,8 +12,7 @@ import { FieldComponentProps } from './types';
const StringInputFieldComponent = ( const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate> props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => { ) => {
const { nodeData, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleValueChanged = useCallback( const handleValueChanged = useCallback(

View File

@ -20,8 +20,7 @@ const VaeModelInputFieldComponent = (
VaeModelInputFieldTemplate VaeModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const vae = field.value; const vae = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: vaeModels } = useGetVaeModelsQuery(); const { data: vaeModels } = useGetVaeModelsQuery();

View File

@ -1,16 +1,13 @@
import { import {
InputFieldTemplate, InputFieldTemplate,
InputFieldValue, InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
export type FieldComponentProps< export type FieldComponentProps<
V extends InputFieldValue, V extends InputFieldValue,
T extends InputFieldTemplate T extends InputFieldTemplate
> = { > = {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate;
field: V; field: V;
fieldTemplate: T; fieldTemplate: T;
}; };

View File

@ -55,7 +55,11 @@ const CurrentImageNode = (props: NodeProps) => {
export default memo(CurrentImageNode); export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => ( const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
<NodeWrapper nodeProps={props.nodeProps} width={384}> <NodeWrapper
nodeId={props.nodeProps.data.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
sx={{ sx={{

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { makeTemplateSelector } from 'features/nodes/store/util/makeTemplateSelector';
import { InvocationNodeData } from 'features/nodes/types/types'; import { InvocationNodeData } from 'features/nodes/types/types';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow'; import { NodeProps } from 'reactflow';
@ -7,18 +8,40 @@ import InvocationNode from '../Invocation/InvocationNode';
import UnknownNodeFallback from '../Invocation/UnknownNodeFallback'; import UnknownNodeFallback from '../Invocation/UnknownNodeFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => { const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data } = props; const { data, selected } = props;
const { type } = data; const { id: nodeId, type, isOpen, label } = data;
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]); const hasTemplateSelector = useMemo(
() =>
createSelector(stateSelector, ({ nodes }) =>
Boolean(nodes.nodeTemplates[type])
),
[type]
);
const nodeTemplate = useAppSelector(templateSelector); const nodeTemplate = useAppSelector(hasTemplateSelector);
if (!nodeTemplate) { if (!nodeTemplate) {
return <UnknownNodeFallback nodeProps={props} />; return (
<UnknownNodeFallback
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
} }
return <InvocationNode nodeProps={props} nodeTemplate={nodeTemplate} />; return (
<InvocationNode
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
}; };
export default memo(InvocationNodeWrapper); export default memo(InvocationNodeWrapper);

View File

@ -10,7 +10,7 @@ import NodeTitle from '../Invocation/NodeTitle';
import NodeWrapper from '../Invocation/NodeWrapper'; import NodeWrapper from '../Invocation/NodeWrapper';
const NotesNode = (props: NodeProps<NotesNodeData>) => { const NotesNode = (props: NodeProps<NotesNodeData>) => {
const { id: nodeId, data } = props; const { id: nodeId, data, selected } = props;
const { notes, isOpen } = data; const { notes, isOpen } = data;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback( const handleChange = useCallback(
@ -21,7 +21,7 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
); );
return ( return (
<NodeWrapper nodeProps={props}> <NodeWrapper nodeId={nodeId} selected={selected}>
<Flex <Flex
layerStyle="nodeHeader" layerStyle="nodeHeader"
sx={{ sx={{
@ -32,8 +32,8 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
h: 8, h: 8,
}} }}
> >
<NodeCollapseButton nodeProps={props} /> <NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeData={props.data} title="Notes" /> <NodeTitle nodeId={nodeId} title="Notes" />
<Box minW={8} /> <Box minW={8} />
</Flex> </Flex>
{isOpen && ( {isOpen && (

View File

@ -6,39 +6,11 @@ import {
TabPanels, TabPanels,
Tabs, Tabs,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react'; import { memo } from 'react';
import NodeDataInspector from './NodeDataInspector';
const selector = createSelector( import NodeTemplateInspector from './NodeTemplateInspector';
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
node: lastSelectedNode,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const InspectorPanel = () => { const InspectorPanel = () => {
const { node, template } = useAppSelector(selector);
return ( return (
<Flex <Flex
layerStyle="first" layerStyle="first"
@ -60,37 +32,10 @@ const InspectorPanel = () => {
<TabPanels> <TabPanels>
<TabPanel> <TabPanel>
{template ? ( <NodeTemplateInspector />
<Flex
sx={{
flexDir: 'column',
alignItems: 'flex-start',
gap: 2,
h: 'full',
}}
>
<ImageMetadataJSON
jsonObject={template}
label="Node Template"
/>
</Flex>
) : (
<IAINoContentFallback
label={
node
? 'No template found for selected node'
: 'No node selected'
}
icon={null}
/>
)}
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
{node ? ( <NodeDataInspector />
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
) : (
<IAINoContentFallback label="No node selected" icon={null} />
)}
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</Tabs> </Tabs>

View File

@ -17,20 +17,20 @@ const selector = createSelector(
); );
return { return {
node: lastSelectedNode, data: lastSelectedNode?.data,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const NodeDataInspector = () => { const NodeDataInspector = () => {
const { node } = useAppSelector(selector); const { data } = useAppSelector(selector);
return node ? ( if (!data) {
<ImageMetadataJSON jsonObject={node.data} label="Node Data" /> return <IAINoContentFallback label="No node selected" icon={null} />;
) : ( }
<IAINoContentFallback label="No node data" icon={null} />
); return <ImageMetadataJSON jsonObject={data} label="Node Data" />;
}; };
export default memo(NodeDataInspector); export default memo(NodeDataInspector);

View File

@ -0,0 +1,40 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
if (!template) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <ImageMetadataJSON jsonObject={template} label="Node Template" />;
};
export default memo(NodeTemplateInspector);

View File

@ -6,14 +6,6 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable'; import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AddFieldToLinearViewDropData } from 'features/dnd/types'; import { AddFieldToLinearViewDropData } from 'features/dnd/types';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/types';
import { forEach } from 'lodash-es';
import { memo } from 'react'; import { memo } from 'react';
import LinearViewField from '../../fields/LinearViewField'; import LinearViewField from '../../fields/LinearViewField';
import ScrollableContent from '../ScrollableContent'; import ScrollableContent from '../ScrollableContent';
@ -21,41 +13,8 @@ import ScrollableContent from '../ScrollableContent';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ nodes }) => { ({ nodes }) => {
const fields: {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}[] = [];
const { exposedFields } = nodes.workflow;
nodes.nodes.filter(isInvocationNode).forEach((node) => {
const nodeTemplate = nodes.nodeTemplates[node.data.type];
if (!nodeTemplate) {
return;
}
forEach(node.data.inputs, (field) => {
if (
!exposedFields.some(
(f) => f.nodeId === node.id && f.fieldName === field.name
)
) {
return;
}
const fieldTemplate = nodeTemplate.inputs[field.name];
if (!fieldTemplate) {
return;
}
fields.push({
nodeData: node.data,
nodeTemplate,
field,
fieldTemplate,
});
});
});
return { return {
fields, fields: nodes.workflow.exposedFields,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -89,13 +48,11 @@ const LinearTabContent = () => {
}} }}
> >
{fields.length ? ( {fields.length ? (
fields.map(({ nodeData, nodeTemplate, field, fieldTemplate }) => ( fields.map(({ nodeId, fieldName }) => (
<LinearViewField <LinearViewField
key={field.id} key={`${nodeId}-${fieldName}`}
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field}
fieldTemplate={fieldTemplate}
/> />
)) ))
) : ( ) : (

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector'; import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { InputFieldValue, OutputFieldValue } from 'features/nodes/types/types';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useFieldType } from './useNodeData';
const selectIsConnectionInProgress = createSelector( const selectIsConnectionInProgress = createSelector(
stateSelector, stateSelector,
@ -12,23 +12,19 @@ const selectIsConnectionInProgress = createSelector(
nodes.connectionStartParams !== null nodes.connectionStartParams !== null
); );
export type UseConnectionStateProps = export type UseConnectionStateProps = {
| { nodeId: string;
nodeId: string; fieldName: string;
field: InputFieldValue; kind: 'input' | 'output';
kind: 'input'; };
}
| {
nodeId: string;
field: OutputFieldValue;
kind: 'output';
};
export const useConnectionState = ({ export const useConnectionState = ({
nodeId, nodeId,
field, fieldName,
kind, kind,
}: UseConnectionStateProps) => { }: UseConnectionStateProps) => {
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo( const selectIsConnected = useMemo(
() => () =>
createSelector(stateSelector, ({ nodes }) => createSelector(stateSelector, ({ nodes }) =>
@ -37,23 +33,23 @@ export const useConnectionState = ({
return ( return (
(kind === 'input' ? edge.target : edge.source) === nodeId && (kind === 'input' ? edge.target : edge.source) === nodeId &&
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) === (kind === 'input' ? edge.targetHandle : edge.sourceHandle) ===
field.name fieldName
); );
}).length }).length
) )
), ),
[field.name, kind, nodeId] [fieldName, kind, nodeId]
); );
const selectConnectionError = useMemo( const selectConnectionError = useMemo(
() => () =>
makeConnectionErrorSelector( makeConnectionErrorSelector(
nodeId, nodeId,
field.name, fieldName,
kind === 'input' ? 'target' : 'source', kind === 'input' ? 'target' : 'source',
field.type fieldType
), ),
[nodeId, field.name, field.type, kind] [nodeId, fieldName, kind, fieldType]
); );
const selectIsConnectionStartField = useMemo( const selectIsConnectionStartField = useMemo(
@ -61,12 +57,12 @@ export const useConnectionState = ({
createSelector(stateSelector, ({ nodes }) => createSelector(stateSelector, ({ nodes }) =>
Boolean( Boolean(
nodes.connectionStartParams?.nodeId === nodeId && nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === field.name && nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleType === nodes.connectionStartParams?.handleType ===
{ input: 'target', output: 'source' }[kind] { input: 'target', output: 'source' }[kind]
) )
), ),
[field.name, kind, nodeId] [fieldName, kind, nodeId]
); );
const isConnected = useAppSelector(selectIsConnected); const isConnected = useAppSelector(selectIsConnected);

View File

@ -0,0 +1,289 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map, some } from 'lodash-es';
import { useMemo } from 'react';
import {
FOOTER_FIELDS,
IMAGE_FIELDS,
} from '../components/Invocation/NodeFooter';
import { isInvocationNode } from '../types/types';
const KIND_MAP = {
input: 'inputs' as const,
output: 'outputs' as const,
};
export const useNodeTemplate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};
export const useNodeData = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
return node?.data;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeData = useAppSelector(selector);
return nodeData;
};
export const useFieldData = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data.inputs[fieldName];
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const fieldData = useAppSelector(selector);
return fieldData;
};
export const useFieldType = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
return fieldType;
};
export const useFieldNames = (nodeId: string, kind: 'input' | 'output') => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return [];
}
return map(node.data[KIND_MAP[kind]], (field) => field.name).filter(
(fieldName) => fieldName !== 'is_intermediate'
);
},
defaultSelectorOptions
),
[kind, nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames;
};
export const useWithFooter = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
FOOTER_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const withFooter = useAppSelector(selector);
return withFooter;
};
export const useHasImageOutput = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const hasImageOutput = useAppSelector(selector);
return hasImageOutput;
};
export const useIsIntermediate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return Boolean(node.data.inputs.is_intermediate?.value);
},
defaultSelectorOptions
),
[nodeId]
);
const is_intermediate = useAppSelector(selector);
return is_intermediate;
};
export const useNodeLabel = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.label;
},
defaultSelectorOptions
),
[nodeId]
);
const label = useAppSelector(selector);
return label;
};
export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = node
? nodes.nodeTemplates[node.data.type]
: undefined;
return nodeTemplate?.title;
},
defaultSelectorOptions
),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};
export const useFieldTemplate = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return Boolean(node?.data.inputs[fieldName]?.value);
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const doesFieldHaveValue = useAppSelector(selector);
return doesFieldHaveValue;
};

View File

@ -9,9 +9,13 @@ export const makeConnectionErrorSelector = (
nodeId: string, nodeId: string,
fieldName: string, fieldName: string,
handleType: HandleType, handleType: HandleType,
fieldType: FieldType fieldType?: FieldType
) => ) =>
createSelector(stateSelector, (state) => { createSelector(stateSelector, (state) => {
if (!fieldType) {
return 'No field type';
}
const { currentConnectionFieldType, connectionStartParams, nodes, edges } = const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes; state.nodes;

View File

@ -457,12 +457,13 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
}; };
export const isInputFieldValue = ( export const isInputFieldValue = (
field: InputFieldValue | OutputFieldValue field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => field.fieldKind === 'input'; ): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
export const isInputFieldTemplate = ( export const isInputFieldTemplate = (
fieldTemplate: InputFieldTemplate | OutputFieldTemplate fieldTemplate?: InputFieldTemplate | OutputFieldTemplate
): fieldTemplate is InputFieldTemplate => fieldTemplate.fieldKind === 'input'; ): fieldTemplate is InputFieldTemplate =>
Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input');
/** /**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES * JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
@ -632,20 +633,22 @@ export type NodeData =
export const isInvocationNode = ( export const isInvocationNode = (
node?: Node<NodeData> node?: Node<NodeData>
): node is Node<InvocationNodeData> => node?.type === 'invocation'; ): node is Node<InvocationNodeData> =>
Boolean(node && node.type === 'invocation');
export const isInvocationNodeData = ( export const isInvocationNodeData = (
node?: NodeData node?: NodeData
): node is InvocationNodeData => ): node is InvocationNodeData =>
!['notes', 'current_image'].includes(node?.type ?? ''); Boolean(node && !['notes', 'current_image'].includes(node.type));
export const isNotesNode = ( export const isNotesNode = (
node?: Node<NodeData> node?: Node<NodeData>
): node is Node<NotesNodeData> => node?.type === 'notes'; ): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
export const isProgressImageNode = ( export const isProgressImageNode = (
node?: Node<NodeData> node?: Node<NodeData>
): node is Node<CurrentImageNodeData> => node?.type === 'current_image'; ): node is Node<CurrentImageNodeData> =>
Boolean(node && node.type === 'current_image');
export enum NodeStatus { export enum NodeStatus {
PENDING, PENDING,