mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into responsive-ui
This commit is contained in:
commit
e973aeef0d
@ -1,7 +1,7 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import {
|
||||
Tooltip,
|
||||
Menu,
|
||||
@ -10,7 +10,7 @@ import {
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { FaEllipsisV, FaPlus } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { cloneDeep, map } from 'lodash';
|
||||
@ -18,8 +18,10 @@ import { RootState } from 'app/store';
|
||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/hooks/useToastWatcher';
|
||||
import { IAIIconButton } from 'exports';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
|
||||
export const AddNodeMenu = () => {
|
||||
const AddNodeMenu = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const invocationTemplates = useAppSelector(
|
||||
@ -29,7 +31,7 @@ export const AddNodeMenu = () => {
|
||||
const buildInvocation = useBuildInvocation();
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
(nodeType: AnyInvocationType) => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
|
||||
if (!invocation) {
|
||||
@ -47,9 +49,13 @@ export const AddNodeMenu = () => {
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
|
||||
<MenuList>
|
||||
<Menu isLazy>
|
||||
<MenuButton
|
||||
as={IAIIconButton}
|
||||
aria-label="Add Node"
|
||||
icon={<FaEllipsisV />}
|
||||
/>
|
||||
<MenuList overflowY="scroll" height={400}>
|
||||
{map(invocationTemplates, ({ title, description, type }, key) => {
|
||||
return (
|
||||
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||
@ -61,3 +67,5 @@ export const AddNodeMenu = () => {
|
||||
</Menu>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddNodeMenu);
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties, useMemo } from 'react';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import {
|
||||
Handle,
|
||||
Position,
|
||||
@ -19,11 +19,11 @@ const handleBaseStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
const inputHandleStyles: CSSProperties = {
|
||||
left: '-1.7rem',
|
||||
left: '-1rem',
|
||||
};
|
||||
|
||||
const outputHandleStyles: CSSProperties = {
|
||||
right: '-1.7rem',
|
||||
right: '-0.5rem',
|
||||
};
|
||||
|
||||
const requiredConnectionStyles: CSSProperties = {
|
||||
@ -38,13 +38,14 @@ type FieldHandleProps = {
|
||||
styles?: CSSProperties;
|
||||
};
|
||||
|
||||
export const FieldHandle = (props: FieldHandleProps) => {
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
console.log(props);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={type}
|
||||
placement={handleType === 'target' ? 'start' : 'end'}
|
||||
hasArrow
|
||||
@ -67,3 +68,5 @@ export const FieldHandle = (props: FieldHandleProps) => {
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FieldHandle);
|
||||
|
@ -2,8 +2,9 @@ import 'reactflow/dist/style.css';
|
||||
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
|
||||
import { map } from 'lodash';
|
||||
import { FIELDS } from '../types/constants';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const FieldTypeLegend = () => {
|
||||
const FieldTypeLegend = () => {
|
||||
return (
|
||||
<HStack>
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
@ -16,3 +17,5 @@ export const FieldTypeLegend = () => {
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FieldTypeLegend);
|
||||
|
@ -1,15 +1,12 @@
|
||||
import {
|
||||
Background,
|
||||
Controls,
|
||||
MiniMap,
|
||||
OnConnect,
|
||||
OnEdgesChange,
|
||||
OnNodesChange,
|
||||
ReactFlow,
|
||||
ConnectionLineType,
|
||||
OnConnectStart,
|
||||
OnConnectEnd,
|
||||
Panel,
|
||||
} from 'reactflow';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
@ -22,10 +19,10 @@ import {
|
||||
} from '../store/nodesSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { InvocationComponent } from './InvocationComponent';
|
||||
import { AddNodeMenu } from './AddNodeMenu';
|
||||
import { FieldTypeLegend } from './FieldTypeLegend';
|
||||
import { Button } from '@chakra-ui/react';
|
||||
import { nodesGraphBuilt } from 'services/thunks/session';
|
||||
import TopLeftPanel from './panels/TopLeftPanel';
|
||||
import TopRightPanel from './panels/TopRightPanel';
|
||||
import TopCenterPanel from './panels/TopCenterPanel';
|
||||
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
|
||||
|
||||
const nodeTypes = { invocation: InvocationComponent };
|
||||
|
||||
@ -69,10 +66,6 @@ export const Flow = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInvoke = useCallback(() => {
|
||||
dispatch(nodesGraphBuilt());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
@ -87,17 +80,11 @@ export const Flow = () => {
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
>
|
||||
<Panel position="top-left">
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
<Panel position="top-center">
|
||||
<Button onClick={handleInvoke}>Will it blend?</Button>
|
||||
</Panel>
|
||||
<Panel position="top-right">
|
||||
<FieldTypeLegend />
|
||||
</Panel>
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
<TopRightPanel />
|
||||
<BottomLeftPanel />
|
||||
<Background />
|
||||
<Controls />
|
||||
<MiniMap nodeStrokeWidth={3} zoomable pannable />
|
||||
</ReactFlow>
|
||||
);
|
||||
|
@ -0,0 +1,39 @@
|
||||
import { Flex, Heading, Tooltip, Icon } from '@chakra-ui/react';
|
||||
import { InvocationTemplate } from 'features/nodes/types/types';
|
||||
import { memo, MutableRefObject } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
|
||||
interface IAINodeHeaderProps {
|
||||
nodeId: string;
|
||||
template: InvocationTemplate;
|
||||
}
|
||||
|
||||
const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
||||
const { nodeId, template } = props;
|
||||
return (
|
||||
<Flex
|
||||
borderTopRadius="md"
|
||||
justifyContent="space-between"
|
||||
background="base.700"
|
||||
px={2}
|
||||
py={1}
|
||||
alignItems="center"
|
||||
>
|
||||
<Tooltip label={nodeId}>
|
||||
<Heading size="xs" fontWeight={600} color="base.100">
|
||||
{template.title}
|
||||
</Heading>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
label={template.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.300" as={FaInfoCircle} h="min-content" />
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAINodeHeader);
|
@ -0,0 +1,146 @@
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, ReactNode, useCallback } from 'react';
|
||||
import { map } from 'lodash';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Tooltip,
|
||||
Divider,
|
||||
} from '@chakra-ui/react';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import InputFieldComponent from '../InputFieldComponent';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
|
||||
interface IAINodeInputProps {
|
||||
nodeId: string;
|
||||
|
||||
input: InputFieldValue;
|
||||
template?: InputFieldTemplate | undefined;
|
||||
connected: boolean;
|
||||
}
|
||||
|
||||
function IAINodeInput(props: IAINodeInputProps) {
|
||||
const { nodeId, input, template, connected } = props;
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
return (
|
||||
<Box
|
||||
position="relative"
|
||||
borderColor={
|
||||
!template
|
||||
? 'error.400'
|
||||
: !connected &&
|
||||
['always', 'connectionOnly'].includes(
|
||||
String(template?.inputRequirement)
|
||||
) &&
|
||||
input.value === undefined
|
||||
? 'warning.400'
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<FormControl isDisabled={!template ? true : connected} pl={2}>
|
||||
{!template ? (
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>Unknown input: {input.name}</FormLabel>
|
||||
</HStack>
|
||||
) : (
|
||||
<>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<HStack>
|
||||
<Tooltip
|
||||
label={template?.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<FormLabel>{template?.title}</FormLabel>
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
<InputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={input}
|
||||
template={template}
|
||||
/>
|
||||
</HStack>
|
||||
|
||||
{!['never', 'directOnly'].includes(
|
||||
template?.inputRequirement ?? ''
|
||||
) && (
|
||||
<FieldHandle
|
||||
nodeId={nodeId}
|
||||
field={template}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
interface IAINodeInputsProps {
|
||||
nodeId: string;
|
||||
template: InvocationTemplate;
|
||||
inputs: Record<string, InputFieldValue>;
|
||||
}
|
||||
|
||||
const IAINodeInputs = (props: IAINodeInputsProps) => {
|
||||
const { nodeId, template, inputs } = props;
|
||||
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const renderIAINodeInputs = useCallback(() => {
|
||||
const IAINodeInputsToRender: ReactNode[] = [];
|
||||
const inputSockets = map(inputs);
|
||||
|
||||
inputSockets.forEach((inputSocket, index) => {
|
||||
const inputTemplate = template.inputs[inputSocket.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
edges.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.target === nodeId &&
|
||||
connectedInput.targetHandle === inputSocket.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
if (index < inputSockets.length) {
|
||||
IAINodeInputsToRender.push(<Divider />);
|
||||
}
|
||||
|
||||
IAINodeInputsToRender.push(
|
||||
<IAINodeInput
|
||||
key={inputSocket.id}
|
||||
nodeId={nodeId}
|
||||
input={inputSocket}
|
||||
template={inputTemplate}
|
||||
connected={isConnected}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} p={2}>
|
||||
{IAINodeInputsToRender}
|
||||
</Flex>
|
||||
);
|
||||
}, [edges, inputs, nodeId, template.inputs]);
|
||||
|
||||
return renderIAINodeInputs();
|
||||
};
|
||||
|
||||
export default memo(IAINodeInputs);
|
@ -0,0 +1,97 @@
|
||||
import {
|
||||
InvocationTemplate,
|
||||
OutputFieldTemplate,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, ReactNode, useCallback } from 'react';
|
||||
import { map } from 'lodash';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
|
||||
interface IAINodeOutputProps {
|
||||
nodeId: string;
|
||||
output: OutputFieldValue;
|
||||
template?: OutputFieldTemplate | undefined;
|
||||
connected: boolean;
|
||||
}
|
||||
|
||||
function IAINodeOutput(props: IAINodeOutputProps) {
|
||||
const { nodeId, output, template, connected } = props;
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl isDisabled={!template ? true : connected} paddingRight={3}>
|
||||
{!template ? (
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel color="error.400">
|
||||
Unknown Output: {output.name}
|
||||
</FormLabel>
|
||||
</HStack>
|
||||
) : (
|
||||
<>
|
||||
<FormLabel textAlign="end" padding={1}>
|
||||
{template?.title}
|
||||
</FormLabel>
|
||||
<FieldHandle
|
||||
key={output.id}
|
||||
nodeId={nodeId}
|
||||
field={template}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="source"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
interface IAINodeOutputsProps {
|
||||
nodeId: string;
|
||||
template: InvocationTemplate;
|
||||
outputs: Record<string, OutputFieldValue>;
|
||||
}
|
||||
|
||||
const IAINodeOutputs = (props: IAINodeOutputsProps) => {
|
||||
const { nodeId, template, outputs } = props;
|
||||
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const renderIAINodeOutputs = useCallback(() => {
|
||||
const IAINodeOutputsToRender: ReactNode[] = [];
|
||||
const outputSockets = map(outputs);
|
||||
|
||||
outputSockets.forEach((outputSocket) => {
|
||||
const outputTemplate = template.outputs[outputSocket.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
edges.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.source === nodeId &&
|
||||
connectedInput.sourceHandle === outputSocket.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
IAINodeOutputsToRender.push(
|
||||
<IAINodeOutput
|
||||
key={outputSocket.id}
|
||||
nodeId={nodeId}
|
||||
output={outputSocket}
|
||||
template={outputTemplate}
|
||||
connected={isConnected}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
return <Flex flexDir="column">{IAINodeOutputsToRender}</Flex>;
|
||||
}, [edges, nodeId, outputs, template.outputs]);
|
||||
|
||||
return renderIAINodeOutputs();
|
||||
};
|
||||
|
||||
export default memo(IAINodeOutputs);
|
@ -0,0 +1,23 @@
|
||||
import { memo } from 'react';
|
||||
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
|
||||
|
||||
const IAINodeResizer = (props: NodeResizerProps) => {
|
||||
const { ...rest } = props;
|
||||
return (
|
||||
<NodeResizeControl
|
||||
style={{
|
||||
position: 'absolute',
|
||||
border: 'none',
|
||||
background: 'transparent',
|
||||
width: 15,
|
||||
height: 15,
|
||||
bottom: 0,
|
||||
right: 0,
|
||||
}}
|
||||
minWidth={350}
|
||||
{...rest}
|
||||
></NodeResizeControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAINodeResizer);
|
@ -1,13 +1,14 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
import { ArrayInputFieldComponent } from './fields/ArrayInputField.tsx';
|
||||
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
|
||||
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
|
||||
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
|
||||
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
|
||||
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
|
||||
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
|
||||
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
|
||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@ -16,7 +17,7 @@ type InputFieldComponentProps = {
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
export const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
const { nodeId, field, template } = props;
|
||||
const { type, value } = field;
|
||||
|
||||
@ -105,3 +106,5 @@ export const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
|
||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||
};
|
||||
|
||||
export default memo(InputFieldComponent);
|
||||
|
@ -1,242 +1,98 @@
|
||||
import { NodeProps, useReactFlow } from 'reactflow';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Heading,
|
||||
HStack,
|
||||
Tooltip,
|
||||
Icon,
|
||||
Code,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaExclamationCircle, FaInfoCircle } from 'react-icons/fa';
|
||||
import { InvocationValue } from '../types/types';
|
||||
import { InputFieldComponent } from './InputFieldComponent';
|
||||
import { FieldHandle } from './FieldHandle';
|
||||
import { isEqual, map, size } from 'lodash';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
import { useIsValidConnection } from '../hooks/useIsValidConnection';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import { Box, Flex, Icon, useToken } from '@chakra-ui/react';
|
||||
import { FaExclamationCircle } from 'react-icons/fa';
|
||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||
|
||||
const connectedInputFieldsSelector = createSelector(
|
||||
[(state: RootState) => state.nodes.edges],
|
||||
(edges) => {
|
||||
// return edges.map((e) => e.targetHandle);
|
||||
return edges;
|
||||
import { memo, PropsWithChildren, useMemo } from 'react';
|
||||
import IAINodeOutputs from './IAINode/IAINodeOutputs';
|
||||
import IAINodeInputs from './IAINode/IAINodeInputs';
|
||||
import IAINodeHeader from './IAINode/IAINodeHeader';
|
||||
import IAINodeResizer from './IAINode/IAINodeResizer';
|
||||
import { RootState } from 'app/store';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
|
||||
type InvocationComponentWrapperProps = PropsWithChildren & {
|
||||
selected: boolean;
|
||||
};
|
||||
|
||||
const InvocationComponentWrapper = (props: InvocationComponentWrapperProps) => {
|
||||
const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [
|
||||
'nodeSelectedOutline',
|
||||
'dark-lg',
|
||||
]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
borderRadius: 'md',
|
||||
boxShadow: props.selected
|
||||
? `${nodeSelectedOutline}, ${nodeShadow}`
|
||||
: `${nodeShadow}`,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
const makeTemplateSelector = (type: AnyInvocationType) =>
|
||||
createSelector(
|
||||
[(state: RootState) => state.nodes],
|
||||
(nodes) => {
|
||||
const template = nodes.invocationTemplates[type];
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
return template;
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
resultEqualityCheck: (
|
||||
a: InvocationTemplate | undefined,
|
||||
b: InvocationTemplate | undefined
|
||||
) => a !== undefined && b !== undefined && a.type === b.type,
|
||||
},
|
||||
}
|
||||
);
|
||||
);
|
||||
|
||||
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
||||
const { id: nodeId, data, selected } = props;
|
||||
const { type, inputs, outputs } = data;
|
||||
|
||||
const isValidConnection = useIsValidConnection();
|
||||
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
|
||||
|
||||
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
|
||||
const getInvocationTemplate = useGetInvocationTemplate();
|
||||
// TODO: determine if a field/handle is connected and disable the input if so
|
||||
const template = useAppSelector(templateSelector);
|
||||
|
||||
const template = useRef(getInvocationTemplate(type));
|
||||
|
||||
if (!template.current) {
|
||||
if (!template) {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
padding: 4,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'md',
|
||||
boxShadow: 'dark-lg',
|
||||
borderWidth: 2,
|
||||
borderColor: selected ? 'base.400' : 'transparent',
|
||||
}}
|
||||
>
|
||||
<InvocationComponentWrapper selected={selected}>
|
||||
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
|
||||
<Icon color="base.400" boxSize={32} as={FaExclamationCircle}></Icon>
|
||||
<IAINodeResizer />
|
||||
</Flex>
|
||||
</Box>
|
||||
</InvocationComponentWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box
|
||||
<InvocationComponentWrapper selected={selected}>
|
||||
<IAINodeHeader nodeId={nodeId} template={template} />
|
||||
<Flex
|
||||
sx={{
|
||||
padding: 4,
|
||||
flexDirection: 'column',
|
||||
borderBottomRadius: 'md',
|
||||
bg: 'base.800',
|
||||
borderRadius: 'md',
|
||||
boxShadow: 'dark-lg',
|
||||
borderWidth: 2,
|
||||
borderColor: selected ? 'base.400' : 'transparent',
|
||||
py: 2,
|
||||
}}
|
||||
>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<>
|
||||
<Code>{nodeId}</Code>
|
||||
<HStack justifyContent="space-between">
|
||||
<Heading size="sm" fontWeight={500} color="base.100">
|
||||
{template.current.title}
|
||||
</Heading>
|
||||
<Tooltip
|
||||
label={template.current.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.300" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
{map(inputs, (input, i) => {
|
||||
const { id: fieldId } = input;
|
||||
const inputTemplate = template.current?.inputs[input.name];
|
||||
|
||||
if (!inputTemplate) {
|
||||
return (
|
||||
<Box
|
||||
key={fieldId}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
sx={{
|
||||
borderColor: 'error.400',
|
||||
}}
|
||||
>
|
||||
<FormControl isDisabled={true}>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>Unknown input: {input.name}</FormLabel>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
const isConnected = Boolean(
|
||||
connectedInputs.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.target === nodeId &&
|
||||
connectedInput.targetHandle === input.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
return (
|
||||
<Box
|
||||
key={fieldId}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
sx={{
|
||||
borderColor:
|
||||
!isConnected &&
|
||||
['always', 'connectionOnly'].includes(
|
||||
String(inputTemplate?.inputRequirement)
|
||||
) &&
|
||||
input.value === undefined
|
||||
? 'warning.400'
|
||||
: undefined,
|
||||
}}
|
||||
>
|
||||
<FormControl isDisabled={isConnected}>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>{inputTemplate?.title}</FormLabel>
|
||||
<Tooltip
|
||||
label={inputTemplate?.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.400" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
<InputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={input}
|
||||
template={inputTemplate}
|
||||
/>
|
||||
</FormControl>
|
||||
{!['never', 'directOnly'].includes(
|
||||
inputTemplate?.inputRequirement ?? ''
|
||||
) && (
|
||||
<FieldHandle
|
||||
nodeId={nodeId}
|
||||
field={inputTemplate}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
{map(outputs).map((output, i) => {
|
||||
const outputTemplate = template.current?.outputs[output.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
connectedInputs.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.source === nodeId &&
|
||||
connectedInput.sourceHandle === output.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
if (!outputTemplate) {
|
||||
return (
|
||||
<Box
|
||||
key={output.id}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
sx={{
|
||||
borderColor: 'error.400',
|
||||
}}
|
||||
>
|
||||
<FormControl isDisabled={true}>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>Unknown output: {output.name}</FormLabel>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box
|
||||
key={output.id}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
>
|
||||
<FormControl isDisabled={isConnected}>
|
||||
<FormLabel textAlign="end">
|
||||
{outputTemplate?.title} Output
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
<FieldHandle
|
||||
key={output.id}
|
||||
nodeId={nodeId}
|
||||
field={outputTemplate}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="source"
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
<IAINodeOutputs nodeId={nodeId} outputs={outputs} template={template} />
|
||||
<IAINodeInputs nodeId={nodeId} inputs={inputs} template={template} />
|
||||
</Flex>
|
||||
<Flex></Flex>
|
||||
</Box>
|
||||
<IAINodeResizer />
|
||||
</InvocationComponentWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -3,15 +3,9 @@ import { Box } from '@chakra-ui/react';
|
||||
import { ReactFlowProvider } from 'reactflow';
|
||||
|
||||
import { Flow } from './Flow';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
|
||||
import { memo } from 'react';
|
||||
|
||||
const NodeEditor = () => {
|
||||
const state = useAppSelector((state: RootState) => state);
|
||||
|
||||
const graph = buildNodesGraph(state);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@ -25,22 +19,8 @@ const NodeEditor = () => {
|
||||
<ReactFlowProvider>
|
||||
<Flow />
|
||||
</ReactFlowProvider>
|
||||
<Box
|
||||
as="pre"
|
||||
fontFamily="monospace"
|
||||
position="absolute"
|
||||
top={2}
|
||||
left={2}
|
||||
width="full"
|
||||
height="full"
|
||||
userSelect="none"
|
||||
pointerEvents="none"
|
||||
opacity={0.7}
|
||||
>
|
||||
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default NodeEditor;
|
||||
export default memo(NodeEditor);
|
||||
|
@ -0,0 +1,30 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
|
||||
|
||||
const NodeGraphOverlay = () => {
|
||||
const state = useAppSelector((state: RootState) => state);
|
||||
const graph = buildNodesGraph(state);
|
||||
|
||||
return (
|
||||
<Box
|
||||
as="pre"
|
||||
fontFamily="monospace"
|
||||
position="absolute"
|
||||
top={10}
|
||||
right={2}
|
||||
opacity={0.7}
|
||||
background="base.800"
|
||||
p={2}
|
||||
maxHeight={500}
|
||||
overflowY="scroll"
|
||||
borderRadius="md"
|
||||
>
|
||||
{JSON.stringify(graph, null, 2)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NodeGraphOverlay);
|
@ -0,0 +1,59 @@
|
||||
import { ButtonGroup } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { IAIIconButton } from 'exports';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice';
|
||||
|
||||
const ViewportControls = () => {
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldShowGraphOverlay = useAppSelector(
|
||||
(state) => state.nodes.shouldShowGraphOverlay
|
||||
);
|
||||
|
||||
const handleClickedZoomIn = useCallback(() => {
|
||||
zoomIn();
|
||||
}, [zoomIn]);
|
||||
|
||||
const handleClickedZoomOut = useCallback(() => {
|
||||
zoomOut();
|
||||
}, [zoomOut]);
|
||||
|
||||
const handleClickedFitView = useCallback(() => {
|
||||
fitView();
|
||||
}, [fitView]);
|
||||
|
||||
const handleClickedToggleGraphOverlay = useCallback(() => {
|
||||
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
|
||||
}, [shouldShowGraphOverlay, dispatch]);
|
||||
|
||||
return (
|
||||
<ButtonGroup isAttached orientation="vertical">
|
||||
<IAIIconButton
|
||||
onClick={handleClickedZoomIn}
|
||||
aria-label="Zoom In"
|
||||
icon={<FaPlus />}
|
||||
/>
|
||||
<IAIIconButton
|
||||
onClick={handleClickedZoomOut}
|
||||
aria-label="Zoom Out"
|
||||
icon={<FaMinus />}
|
||||
/>
|
||||
<IAIIconButton
|
||||
onClick={handleClickedFitView}
|
||||
aria-label="Fit to Viewport"
|
||||
icon={<FaExpand />}
|
||||
/>
|
||||
<IAIIconButton
|
||||
isChecked={shouldShowGraphOverlay}
|
||||
onClick={handleClickedToggleGraphOverlay}
|
||||
aria-label="Show/Hide Graph"
|
||||
icon={<FaCode />}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ViewportControls);
|
@ -1,14 +1,17 @@
|
||||
import {
|
||||
ArrayInputFieldTemplate,
|
||||
ArrayInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { FaImage, FaList } from 'react-icons/fa';
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FaList } from 'react-icons/fa';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const ArrayInputFieldComponent = (
|
||||
const ArrayInputFieldComponent = (
|
||||
props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return <FaList />;
|
||||
};
|
||||
|
||||
export default memo(ArrayInputFieldComponent);
|
@ -4,11 +4,11 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
BooleanInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { ChangeEvent } from 'react';
|
||||
} from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const BooleanInputFieldComponent = (
|
||||
const BooleanInputFieldComponent = (
|
||||
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -29,3 +29,5 @@ export const BooleanInputFieldComponent = (
|
||||
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BooleanInputFieldComponent);
|
||||
|
@ -4,11 +4,11 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
EnumInputFieldTemplate,
|
||||
EnumInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { ChangeEvent } from 'react';
|
||||
} from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const EnumInputFieldComponent = (
|
||||
const EnumInputFieldComponent = (
|
||||
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field, template } = props;
|
||||
@ -33,3 +33,5 @@ export const EnumInputFieldComponent = (
|
||||
</Select>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(EnumInputFieldComponent);
|
||||
|
@ -8,13 +8,13 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { DragEvent, useCallback, useState } from 'react';
|
||||
} from 'features/nodes/types/types';
|
||||
import { DragEvent, memo, useCallback, useState } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
import { ImageType } from 'services/api';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const ImageInputFieldComponent = (
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -62,3 +62,5 @@ export const ImageInputFieldComponent = (
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageInputFieldComponent);
|
||||
|
@ -1,13 +1,16 @@
|
||||
import {
|
||||
LatentsInputFieldTemplate,
|
||||
LatentsInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const LatentsInputFieldComponent = (
|
||||
const LatentsInputFieldComponent = (
|
||||
props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(LatentsInputFieldComponent);
|
||||
|
@ -6,13 +6,13 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ModelInputFieldTemplate,
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
selectModelsById,
|
||||
selectModelsIds,
|
||||
} from 'features/system/store/modelSlice';
|
||||
import { isEqual, map } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const availableModelsSelector = createSelector(
|
||||
@ -28,7 +28,7 @@ const availableModelsSelector = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
export const ModelInputFieldComponent = (
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -55,3 +55,5 @@ export const ModelInputFieldComponent = (
|
||||
</Select>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelInputFieldComponent);
|
||||
|
@ -12,10 +12,11 @@ import {
|
||||
FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const NumberInputFieldComponent = (
|
||||
const NumberInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
IntegerInputFieldValue | FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate | FloatInputFieldTemplate
|
||||
@ -39,3 +40,5 @@ export const NumberInputFieldComponent = (
|
||||
</NumberInput>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NumberInputFieldComponent);
|
||||
|
@ -4,11 +4,11 @@ import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
StringInputFieldTemplate,
|
||||
StringInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { ChangeEvent } from 'react';
|
||||
} from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const StringInputFieldComponent = (
|
||||
const StringInputFieldComponent = (
|
||||
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
@ -27,3 +27,5 @@ export const StringInputFieldComponent = (
|
||||
|
||||
return <Input onChange={handleValueChanged} value={field.value}></Input>;
|
||||
};
|
||||
|
||||
export default memo(StringInputFieldComponent);
|
||||
|
@ -0,0 +1,11 @@
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import ViewportControls from '../ViewportControls';
|
||||
|
||||
const BottomLeftPanel = () => (
|
||||
<Panel position="bottom-left">
|
||||
<ViewportControls />
|
||||
</Panel>
|
||||
);
|
||||
|
||||
export default memo(BottomLeftPanel);
|
@ -0,0 +1,23 @@
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import { nodesGraphBuilt } from 'services/thunks/session';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleInvoke = useCallback(() => {
|
||||
dispatch(nodesGraphBuilt());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Panel position="top-center">
|
||||
<IAIButton colorScheme="accent" onClick={handleInvoke}>
|
||||
Will it blend?
|
||||
</IAIButton>
|
||||
</Panel>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopCenterPanel);
|
@ -0,0 +1,11 @@
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import AddNodeMenu from '../AddNodeMenu';
|
||||
|
||||
const TopLeftPanel = () => (
|
||||
<Panel position="top-left">
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
);
|
||||
|
||||
export default memo(TopLeftPanel);
|
@ -0,0 +1,21 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import FieldTypeLegend from '../FieldTypeLegend';
|
||||
import NodeGraphOverlay from '../NodeGraphOverlay';
|
||||
|
||||
const TopRightPanel = () => {
|
||||
const shouldShowGraphOverlay = useAppSelector(
|
||||
(state: RootState) => state.nodes.shouldShowGraphOverlay
|
||||
);
|
||||
|
||||
return (
|
||||
<Panel position="top-right">
|
||||
<FieldTypeLegend />
|
||||
{shouldShowGraphOverlay && <NodeGraphOverlay />}
|
||||
</Panel>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopRightPanel);
|
@ -1,7 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { reduce } from 'lodash';
|
||||
import { Node } from 'reactflow';
|
||||
import { useCallback } from 'react';
|
||||
import { Node, useReactFlow } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import {
|
||||
@ -11,12 +13,19 @@ import {
|
||||
} from '../types/types';
|
||||
import { buildInputFieldValue } from '../util/fieldValueBuilders';
|
||||
|
||||
export const useBuildInvocation = () => {
|
||||
const invocationTemplates = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocationTemplates
|
||||
);
|
||||
const templatesSelector = createSelector(
|
||||
[(state: RootState) => state.nodes],
|
||||
(nodes) => nodes.invocationTemplates,
|
||||
{ memoizeOptions: { resultEqualityCheck: (a, b) => true } }
|
||||
);
|
||||
|
||||
return (type: AnyInvocationType) => {
|
||||
export const useBuildInvocation = () => {
|
||||
const invocationTemplates = useAppSelector(templatesSelector);
|
||||
|
||||
const flow = useReactFlow();
|
||||
|
||||
return useCallback(
|
||||
(type: AnyInvocationType) => {
|
||||
const template = invocationTemplates[type];
|
||||
|
||||
if (template === undefined) {
|
||||
@ -61,10 +70,15 @@ export const useBuildInvocation = () => {
|
||||
{} as Record<string, OutputFieldValue>
|
||||
);
|
||||
|
||||
const { x, y } = flow.project({
|
||||
x: window.innerWidth / 2.5,
|
||||
y: window.innerHeight / 8,
|
||||
});
|
||||
|
||||
const invocation: Node<InvocationValue> = {
|
||||
id: nodeId,
|
||||
type: 'invocation',
|
||||
position: { x: 0, y: 0 },
|
||||
position: { x: x, y: y },
|
||||
data: {
|
||||
id: nodeId,
|
||||
type,
|
||||
@ -74,5 +88,7 @@ export const useBuildInvocation = () => {
|
||||
};
|
||||
|
||||
return invocation;
|
||||
};
|
||||
},
|
||||
[invocationTemplates, flow]
|
||||
);
|
||||
};
|
||||
|
@ -1,16 +0,0 @@
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { invocationTemplatesSelector } from '../store/selectors/invocationTemplatesSelector';
|
||||
|
||||
export const useGetInvocationTemplate = () => {
|
||||
const invocationTemplates = useAppSelector(invocationTemplatesSelector);
|
||||
|
||||
return (invocationType: string) => {
|
||||
const template = invocationTemplates[invocationType];
|
||||
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
};
|
@ -24,6 +24,7 @@ export type NodesState = {
|
||||
invocationTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
lastGraph: Graph | null;
|
||||
shouldShowGraphOverlay: boolean;
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
@ -33,6 +34,7 @@ export const initialNodesState: NodesState = {
|
||||
invocationTemplates: {},
|
||||
connectionStartParams: null,
|
||||
lastGraph: null,
|
||||
shouldShowGraphOverlay: false,
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
@ -77,6 +79,9 @@ const nodesSlice = createSlice({
|
||||
state.nodes[nodeIndex].data.inputs[fieldName].value = value;
|
||||
}
|
||||
},
|
||||
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowGraphOverlay = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||
@ -98,6 +103,7 @@ export const {
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
connectionEnded,
|
||||
shouldShowGraphOverlayChanged,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
@ -22,46 +22,55 @@ const getColorTokenCssVariable = (color: string) =>
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
color: 'red',
|
||||
colorCssVar: getColorTokenCssVariable('red'),
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
},
|
||||
float: {
|
||||
color: 'orange',
|
||||
colorCssVar: getColorTokenCssVariable('orange'),
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
},
|
||||
string: {
|
||||
color: 'yellow',
|
||||
colorCssVar: getColorTokenCssVariable('yellow'),
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
},
|
||||
boolean: {
|
||||
color: 'green',
|
||||
colorCssVar: getColorTokenCssVariable('green'),
|
||||
title: 'Boolean',
|
||||
description: 'Booleans are true or false.',
|
||||
},
|
||||
enum: {
|
||||
color: 'blue',
|
||||
colorCssVar: getColorTokenCssVariable('blue'),
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
},
|
||||
image: {
|
||||
color: 'purple',
|
||||
colorCssVar: getColorTokenCssVariable('purple'),
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
},
|
||||
latents: {
|
||||
color: 'pink',
|
||||
colorCssVar: getColorTokenCssVariable('pink'),
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
},
|
||||
model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
color: 'gray',
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
title: 'Array',
|
||||
description: 'TODO: Array type description.',
|
||||
|
@ -39,6 +39,7 @@ export type InvocationTemplate = {
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
color: string;
|
||||
colorCssVar: string;
|
||||
title: string;
|
||||
description: string;
|
||||
|
@ -64,6 +64,7 @@ export const theme: ThemeOverride = {
|
||||
working: `0 0 7px var(--invokeai-colors-working-400)`,
|
||||
error: `0 0 7px var(--invokeai-colors-error-400)`,
|
||||
},
|
||||
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-base-500)`,
|
||||
},
|
||||
colors: {
|
||||
...invokeAIThemeColors,
|
||||
|
Loading…
Reference in New Issue
Block a user