mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): improve nodes performance
This commit is contained in:
parent
44a653925a
commit
4901911c1a
@ -49,7 +49,7 @@ export const AddNodeMenu = () => {
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<Menu isLazy>
|
||||
<MenuButton
|
||||
as={IAIIconButton}
|
||||
aria-label="Add Node"
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties, useMemo } from 'react';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import {
|
||||
Handle,
|
||||
Position,
|
||||
@ -38,13 +38,14 @@ type FieldHandleProps = {
|
||||
styles?: CSSProperties;
|
||||
};
|
||||
|
||||
export const FieldHandle = (props: FieldHandleProps) => {
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
console.log(props);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={type}
|
||||
placement={handleType === 'target' ? 'start' : 'end'}
|
||||
hasArrow
|
||||
@ -67,3 +68,5 @@ export const FieldHandle = (props: FieldHandleProps) => {
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FieldHandle);
|
||||
|
@ -30,6 +30,10 @@ import { IAIIconButton } from 'exports';
|
||||
import { InfoIcon } from '@chakra-ui/icons';
|
||||
import { ViewportControls } from './ViewportControls';
|
||||
import NodeGraphOverlay from './NodeGraphOverlay';
|
||||
import TopLeftPanel from './panels/TopLeftPanel';
|
||||
import TopRightPanel from './panels/TopRightPanel';
|
||||
import TopCenterPanel from './panels/TopCenterPanel';
|
||||
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
|
||||
|
||||
const nodeTypes = { invocation: InvocationComponent };
|
||||
|
||||
@ -37,9 +41,6 @@ export const Flow = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
const shouldShowGraphOverlay = useAppSelector(
|
||||
(state: RootState) => state.nodes.shouldShowGraphOverlay
|
||||
);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(changes) => {
|
||||
@ -76,10 +77,6 @@ export const Flow = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInvoke = useCallback(() => {
|
||||
dispatch(nodesGraphBuilt());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
@ -94,19 +91,10 @@ export const Flow = () => {
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
>
|
||||
<Panel position="top-left">
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
<Panel position="top-center">
|
||||
<Button onClick={handleInvoke}>Will it blend?</Button>
|
||||
</Panel>
|
||||
<Panel position="top-right">
|
||||
<FieldTypeLegend />
|
||||
{shouldShowGraphOverlay && <NodeGraphOverlay />}
|
||||
</Panel>
|
||||
<Panel position="bottom-left">
|
||||
<ViewportControls />
|
||||
</Panel>
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
<TopRightPanel />
|
||||
<BottomLeftPanel />
|
||||
<Background />
|
||||
<MiniMap nodeStrokeWidth={3} zoomable pannable />
|
||||
</ReactFlow>
|
||||
|
@ -5,7 +5,7 @@ import { FaInfoCircle } from 'react-icons/fa';
|
||||
|
||||
interface IAINodeHeaderProps {
|
||||
nodeId: string;
|
||||
template: MutableRefObject<InvocationTemplate | undefined>;
|
||||
template: InvocationTemplate;
|
||||
}
|
||||
|
||||
export default function IAINodeHeader(props: IAINodeHeaderProps) {
|
||||
@ -21,11 +21,11 @@ export default function IAINodeHeader(props: IAINodeHeaderProps) {
|
||||
>
|
||||
<Tooltip label={nodeId}>
|
||||
<Heading size="xs" fontWeight={600} color="base.100">
|
||||
{template.current?.title}
|
||||
{template.title}
|
||||
</Heading>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
label={template.current?.description}
|
||||
label={template.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
|
@ -3,7 +3,7 @@ import {
|
||||
InputFieldValue,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MutableRefObject, ReactNode } from 'react';
|
||||
import { memo, MutableRefObject, ReactNode } from 'react';
|
||||
import { map } from 'lodash';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
@ -17,7 +17,7 @@ import {
|
||||
Icon,
|
||||
Divider,
|
||||
} from '@chakra-ui/react';
|
||||
import { FieldHandle } from '../FieldHandle';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { InputFieldComponent } from '../InputFieldComponent';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
@ -37,7 +37,6 @@ function IAINodeInput(props: IAINodeInputProps) {
|
||||
|
||||
return (
|
||||
<Box
|
||||
key={input.id}
|
||||
position="relative"
|
||||
borderColor={
|
||||
!template
|
||||
@ -96,11 +95,11 @@ function IAINodeInput(props: IAINodeInputProps) {
|
||||
|
||||
interface IAINodeInputsProps {
|
||||
nodeId: string;
|
||||
template: MutableRefObject<InvocationTemplate | undefined>;
|
||||
template: InvocationTemplate;
|
||||
inputs: Record<string, InputFieldValue>;
|
||||
}
|
||||
|
||||
export default function IAINodeInputs(props: IAINodeInputsProps) {
|
||||
const IAINodeInputs = (props: IAINodeInputsProps) => {
|
||||
const { nodeId, template, inputs } = props;
|
||||
|
||||
const connectedInputs = useAppSelector(
|
||||
@ -112,7 +111,7 @@ export default function IAINodeInputs(props: IAINodeInputsProps) {
|
||||
const inputSockets = map(inputs);
|
||||
|
||||
inputSockets.forEach((inputSocket, index) => {
|
||||
const inputTemplate = template.current?.inputs[inputSocket.name];
|
||||
const inputTemplate = template.inputs[inputSocket.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
connectedInputs.filter((connectedInput) => {
|
||||
@ -129,6 +128,7 @@ export default function IAINodeInputs(props: IAINodeInputsProps) {
|
||||
|
||||
IAINodeInputsToRender.push(
|
||||
<IAINodeInput
|
||||
key={inputSocket.id}
|
||||
nodeId={nodeId}
|
||||
input={inputSocket}
|
||||
template={inputTemplate}
|
||||
@ -145,4 +145,6 @@ export default function IAINodeInputs(props: IAINodeInputsProps) {
|
||||
};
|
||||
|
||||
return renderIAINodeInputs();
|
||||
}
|
||||
};
|
||||
|
||||
export default memo(IAINodeInputs);
|
||||
|
@ -3,12 +3,12 @@ import {
|
||||
OutputFieldTemplate,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MutableRefObject, ReactNode } from 'react';
|
||||
import { memo, MutableRefObject, ReactNode } from 'react';
|
||||
import { map } from 'lodash';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
|
||||
import { FieldHandle } from '../FieldHandle';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
|
||||
interface IAINodeOutputProps {
|
||||
@ -23,7 +23,7 @@ function IAINodeOutput(props: IAINodeOutputProps) {
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
return (
|
||||
<Box key={output.id} position="relative">
|
||||
<Box position="relative">
|
||||
<FormControl isDisabled={!template ? true : connected} paddingRight={3}>
|
||||
{!template ? (
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
@ -52,11 +52,11 @@ function IAINodeOutput(props: IAINodeOutputProps) {
|
||||
|
||||
interface IAINodeOutputsProps {
|
||||
nodeId: string;
|
||||
template: MutableRefObject<InvocationTemplate | undefined>;
|
||||
template: InvocationTemplate;
|
||||
outputs: Record<string, OutputFieldValue>;
|
||||
}
|
||||
|
||||
export default function IAINodeOutputs(props: IAINodeOutputsProps) {
|
||||
const IAINodeOutputs = (props: IAINodeOutputsProps) => {
|
||||
const { nodeId, template, outputs } = props;
|
||||
|
||||
const connectedInputs = useAppSelector(
|
||||
@ -68,7 +68,7 @@ export default function IAINodeOutputs(props: IAINodeOutputsProps) {
|
||||
const outputSockets = map(outputs);
|
||||
|
||||
outputSockets.forEach((outputSocket) => {
|
||||
const outputTemplate = template.current?.outputs[outputSocket.name];
|
||||
const outputTemplate = template.outputs[outputSocket.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
connectedInputs.filter((connectedInput) => {
|
||||
@ -81,6 +81,7 @@ export default function IAINodeOutputs(props: IAINodeOutputsProps) {
|
||||
|
||||
IAINodeOutputsToRender.push(
|
||||
<IAINodeOutput
|
||||
key={outputSocket.id}
|
||||
nodeId={nodeId}
|
||||
output={outputSocket}
|
||||
template={outputTemplate}
|
||||
@ -93,4 +94,6 @@ export default function IAINodeOutputs(props: IAINodeOutputsProps) {
|
||||
};
|
||||
|
||||
return renderIAINodeOutputs();
|
||||
}
|
||||
};
|
||||
|
||||
export default memo(IAINodeOutputs);
|
||||
|
@ -1,15 +1,19 @@
|
||||
import { NodeProps, NodeResizeControl } from 'reactflow';
|
||||
import { Box, Flex, Icon, useToken } from '@chakra-ui/react';
|
||||
import { FaExclamationCircle } from 'react-icons/fa';
|
||||
import { InvocationValue } from '../types/types';
|
||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||
|
||||
import { memo, PropsWithChildren, useRef } from 'react';
|
||||
import { memo, PropsWithChildren, useMemo, useRef } from 'react';
|
||||
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
|
||||
import IAINodeOutputs from './IAINode/IAINodeOutputs';
|
||||
import IAINodeInputs from './IAINode/IAINodeInputs';
|
||||
import IAINodeHeader from './IAINode/IAINodeHeader';
|
||||
import { IoResize } from 'react-icons/io5';
|
||||
import IAINodeResizer from './IAINode/IAINodeResizer';
|
||||
import { RootState } from 'app/store';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
|
||||
type InvocationComponentWrapperProps = PropsWithChildren & {
|
||||
selected: boolean;
|
||||
@ -36,16 +40,35 @@ const InvocationComponentWrapper = (props: InvocationComponentWrapperProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
const makeTemplateSelector = (type: AnyInvocationType) =>
|
||||
createSelector(
|
||||
[(state: RootState) => state.nodes],
|
||||
(nodes) => {
|
||||
const template = nodes.invocationTemplates[type];
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
return template;
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: (
|
||||
a: InvocationTemplate | undefined,
|
||||
b: InvocationTemplate | undefined
|
||||
) => a !== undefined && b !== undefined && a.type === b.type,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
||||
const { id: nodeId, data, selected } = props;
|
||||
const { type, inputs, outputs } = data;
|
||||
|
||||
const getInvocationTemplate = useGetInvocationTemplate();
|
||||
// TODO: determine if a field/handle is connected and disable the input if so
|
||||
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
|
||||
|
||||
const template = useRef(getInvocationTemplate(type));
|
||||
const template = useAppSelector(templateSelector);
|
||||
|
||||
if (!template.current) {
|
||||
if (!template) {
|
||||
return (
|
||||
<InvocationComponentWrapper selected={selected}>
|
||||
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
|
||||
|
@ -0,0 +1,11 @@
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import { ViewportControls } from '../ViewportControls';
|
||||
|
||||
const BottomLeftPanel = () => (
|
||||
<Panel position="bottom-left">
|
||||
<ViewportControls />
|
||||
</Panel>
|
||||
);
|
||||
|
||||
export default memo(BottomLeftPanel);
|
@ -0,0 +1,23 @@
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import { nodesGraphBuilt } from 'services/thunks/session';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleInvoke = useCallback(() => {
|
||||
dispatch(nodesGraphBuilt());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Panel position="top-center">
|
||||
<IAIButton colorScheme="accent" onClick={handleInvoke}>
|
||||
Will it blend?
|
||||
</IAIButton>
|
||||
</Panel>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopCenterPanel);
|
@ -0,0 +1,11 @@
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import { AddNodeMenu } from '../AddNodeMenu';
|
||||
|
||||
const TopLeftPanel = () => (
|
||||
<Panel position="top-left">
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
);
|
||||
|
||||
export default memo(TopLeftPanel);
|
@ -0,0 +1,21 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import { FieldTypeLegend } from '../FieldTypeLegend';
|
||||
import NodeGraphOverlay from '../NodeGraphOverlay';
|
||||
|
||||
const TopRightPanel = () => {
|
||||
const shouldShowGraphOverlay = useAppSelector(
|
||||
(state: RootState) => state.nodes.shouldShowGraphOverlay
|
||||
);
|
||||
|
||||
return (
|
||||
<Panel position="top-right">
|
||||
<FieldTypeLegend />
|
||||
{shouldShowGraphOverlay && <NodeGraphOverlay />}
|
||||
</Panel>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopRightPanel);
|
@ -1,6 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { reduce } from 'lodash';
|
||||
import { useCallback } from 'react';
|
||||
import { Node, useReactFlow } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
@ -11,75 +13,82 @@ import {
|
||||
} from '../types/types';
|
||||
import { buildInputFieldValue } from '../util/fieldValueBuilders';
|
||||
|
||||
const templatesSelector = createSelector(
|
||||
[(state: RootState) => state.nodes],
|
||||
(nodes) => nodes.invocationTemplates,
|
||||
{ memoizeOptions: { resultEqualityCheck: (a, b) => true } }
|
||||
);
|
||||
|
||||
export const useBuildInvocation = () => {
|
||||
const invocationTemplates = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocationTemplates
|
||||
);
|
||||
const invocationTemplates = useAppSelector(templatesSelector);
|
||||
|
||||
const reactflow = useReactFlow();
|
||||
const flow = useReactFlow();
|
||||
|
||||
return (type: AnyInvocationType) => {
|
||||
const template = invocationTemplates[type];
|
||||
return useCallback(
|
||||
(type: AnyInvocationType) => {
|
||||
const template = invocationTemplates[type];
|
||||
|
||||
if (template === undefined) {
|
||||
console.error(`Unable to find template ${type}.`);
|
||||
return;
|
||||
}
|
||||
if (template === undefined) {
|
||||
console.error(`Unable to find template ${type}.`);
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeId = uuidv4();
|
||||
const nodeId = uuidv4();
|
||||
|
||||
const inputs = reduce(
|
||||
template.inputs,
|
||||
(inputsAccumulator, inputTemplate, inputName) => {
|
||||
const fieldId = uuidv4();
|
||||
const inputs = reduce(
|
||||
template.inputs,
|
||||
(inputsAccumulator, inputTemplate, inputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const inputFieldValue: InputFieldValue = buildInputFieldValue(
|
||||
fieldId,
|
||||
inputTemplate
|
||||
);
|
||||
const inputFieldValue: InputFieldValue = buildInputFieldValue(
|
||||
fieldId,
|
||||
inputTemplate
|
||||
);
|
||||
|
||||
inputsAccumulator[inputName] = inputFieldValue;
|
||||
inputsAccumulator[inputName] = inputFieldValue;
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldValue>
|
||||
);
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldValue>
|
||||
);
|
||||
|
||||
const outputs = reduce(
|
||||
template.outputs,
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
const outputs = reduce(
|
||||
template.outputs,
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const outputFieldValue: OutputFieldValue = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
};
|
||||
const outputFieldValue: OutputFieldValue = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
};
|
||||
|
||||
outputsAccumulator[outputName] = outputFieldValue;
|
||||
outputsAccumulator[outputName] = outputFieldValue;
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldValue>
|
||||
);
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldValue>
|
||||
);
|
||||
|
||||
const { x, y } = reactflow.project({
|
||||
x: window.innerWidth / 2.5,
|
||||
y: window.innerHeight / 8,
|
||||
});
|
||||
const { x, y } = flow.project({
|
||||
x: window.innerWidth / 2.5,
|
||||
y: window.innerHeight / 8,
|
||||
});
|
||||
|
||||
const invocation: Node<InvocationValue> = {
|
||||
id: nodeId,
|
||||
type: 'invocation',
|
||||
position: { x: x, y: y },
|
||||
data: {
|
||||
const invocation: Node<InvocationValue> = {
|
||||
id: nodeId,
|
||||
type,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
type: 'invocation',
|
||||
position: { x: x, y: y },
|
||||
data: {
|
||||
id: nodeId,
|
||||
type,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
|
||||
return invocation;
|
||||
};
|
||||
return invocation;
|
||||
},
|
||||
[invocationTemplates, flow]
|
||||
);
|
||||
};
|
||||
|
@ -1,16 +0,0 @@
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { invocationTemplatesSelector } from '../store/selectors/invocationTemplatesSelector';
|
||||
|
||||
export const useGetInvocationTemplate = () => {
|
||||
const invocationTemplates = useAppSelector(invocationTemplatesSelector);
|
||||
|
||||
return (invocationType: string) => {
|
||||
const template = invocationTemplates[invocationType];
|
||||
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
};
|
Loading…
Reference in New Issue
Block a user