Merge branch 'main' into feat/batch-graphs

This commit is contained in:
Brandon 2023-08-16 14:03:34 -04:00 committed by GitHub
commit ef8dc2e8c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 1318 additions and 951 deletions

View File

@ -48,7 +48,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
) )
@title("Boolean") @title("Boolean Primitive")
@tags("primitives", "boolean") @tags("primitives", "boolean")
class BooleanInvocation(BaseInvocation): class BooleanInvocation(BaseInvocation):
"""A boolean primitive value""" """A boolean primitive value"""
@ -62,7 +62,7 @@ class BooleanInvocation(BaseInvocation):
return BooleanOutput(a=self.a) return BooleanOutput(a=self.a)
@title("Boolean Collection") @title("Boolean Primitive Collection")
@tags("primitives", "boolean", "collection") @tags("primitives", "boolean", "collection")
class BooleanCollectionInvocation(BaseInvocation): class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values""" """A collection of boolean primitive values"""
@ -101,7 +101,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
) )
@title("Integer") @title("Integer Primitive")
@tags("primitives", "integer") @tags("primitives", "integer")
class IntegerInvocation(BaseInvocation): class IntegerInvocation(BaseInvocation):
"""An integer primitive value""" """An integer primitive value"""
@ -115,7 +115,7 @@ class IntegerInvocation(BaseInvocation):
return IntegerOutput(a=self.a) return IntegerOutput(a=self.a)
@title("Integer Collection") @title("Integer Primitive Collection")
@tags("primitives", "integer", "collection") @tags("primitives", "integer", "collection")
class IntegerCollectionInvocation(BaseInvocation): class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values""" """A collection of integer primitive values"""
@ -154,7 +154,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
) )
@title("Float") @title("Float Primitive")
@tags("primitives", "float") @tags("primitives", "float")
class FloatInvocation(BaseInvocation): class FloatInvocation(BaseInvocation):
"""A float primitive value""" """A float primitive value"""
@ -168,7 +168,7 @@ class FloatInvocation(BaseInvocation):
return FloatOutput(a=self.param) return FloatOutput(a=self.param)
@title("Float Collection") @title("Float Primitive Collection")
@tags("primitives", "float", "collection") @tags("primitives", "float", "collection")
class FloatCollectionInvocation(BaseInvocation): class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values""" """A collection of float primitive values"""
@ -207,7 +207,7 @@ class StringCollectionOutput(BaseInvocationOutput):
) )
@title("String") @title("String Primitive")
@tags("primitives", "string") @tags("primitives", "string")
class StringInvocation(BaseInvocation): class StringInvocation(BaseInvocation):
"""A string primitive value""" """A string primitive value"""
@ -221,7 +221,7 @@ class StringInvocation(BaseInvocation):
return StringOutput(text=self.text) return StringOutput(text=self.text)
@title("String Collection") @title("String Primitive Collection")
@tags("primitives", "string", "collection") @tags("primitives", "string", "collection")
class StringCollectionInvocation(BaseInvocation): class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values""" """A collection of string primitive values"""
@ -289,7 +289,7 @@ class ImageInvocation(BaseInvocation):
) )
@title("Image Collection") @title("Image Primitive Collection")
@tags("primitives", "image", "collection") @tags("primitives", "image", "collection")
class ImageCollectionInvocation(BaseInvocation): class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values""" """A collection of image primitive values"""
@ -357,7 +357,7 @@ class LatentsInvocation(BaseInvocation):
return build_latents_output(self.latents.latents_name, latents) return build_latents_output(self.latents.latents_name, latents)
@title("Latents Collection") @title("Latents Primitive Collection")
@tags("primitives", "latents", "collection") @tags("primitives", "latents", "collection")
class LatentsCollectionInvocation(BaseInvocation): class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values""" """A collection of latents tensor primitive values"""
@ -475,7 +475,7 @@ class ConditioningInvocation(BaseInvocation):
return ConditioningOutput(conditioning=self.conditioning) return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Collection") @title("Conditioning Primitive Collection")
@tags("primitives", "conditioning", "collection") @tags("primitives", "conditioning", "collection")
class ConditioningCollectionInvocation(BaseInvocation): class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values""" """A collection of conditioning tensor primitive values"""

View File

@ -0,0 +1,126 @@
/**
* This is a copy-paste of https://github.com/lukasbach/chakra-ui-contextmenu with a small change.
*
* The reactflow background element somehow prevents the chakra `useOutsideClick()` hook from working.
* With a menu open, clicking on the reactflow background element doesn't close the menu.
*
* Reactflow does provide an `onPaneClick` to handle clicks on the background element, but it is not
* straightforward to programatically close the menu.
*
* As a (hopefully temporary) workaround, we will use a dirty hack:
* - create `globalContextMenuCloseTrigger: number` in `ui` slice
* - increment it in `onPaneClick`
* - `useEffect()` to close the menu when `globalContextMenuCloseTrigger` changes
*/
import {
Menu,
MenuButton,
MenuButtonProps,
MenuProps,
Portal,
PortalProps,
useEventListener,
} from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import * as React from 'react';
import {
MutableRefObject,
useCallback,
useEffect,
useRef,
useState,
} from 'react';
export interface IAIContextMenuProps<T extends HTMLElement> {
renderMenu: () => JSX.Element | null;
children: (ref: MutableRefObject<T | null>) => JSX.Element | null;
menuProps?: Omit<MenuProps, 'children'> & { children?: React.ReactNode };
portalProps?: Omit<PortalProps, 'children'> & { children?: React.ReactNode };
menuButtonProps?: MenuButtonProps;
}
export function IAIContextMenu<T extends HTMLElement = HTMLElement>(
props: IAIContextMenuProps<T>
) {
const [isOpen, setIsOpen] = useState(false);
const [isRendered, setIsRendered] = useState(false);
const [isDeferredOpen, setIsDeferredOpen] = useState(false);
const [position, setPosition] = useState<[number, number]>([0, 0]);
const targetRef = useRef<T>(null);
const globalContextMenuCloseTrigger = useAppSelector(
(state) => state.ui.globalContextMenuCloseTrigger
);
useEffect(() => {
if (isOpen) {
setTimeout(() => {
setIsRendered(true);
setTimeout(() => {
setIsDeferredOpen(true);
});
});
} else {
setIsDeferredOpen(false);
const timeout = setTimeout(() => {
setIsRendered(isOpen);
}, 1000);
return () => clearTimeout(timeout);
}
}, [isOpen]);
useEffect(() => {
setIsOpen(false);
setIsDeferredOpen(false);
setIsRendered(false);
}, [globalContextMenuCloseTrigger]);
useEventListener('contextmenu', (e) => {
if (
targetRef.current?.contains(e.target as HTMLElement) ||
e.target === targetRef.current
) {
e.preventDefault();
setIsOpen(true);
setPosition([e.pageX, e.pageY]);
} else {
setIsOpen(false);
}
});
const onCloseHandler = useCallback(() => {
props.menuProps?.onClose?.();
setIsOpen(false);
}, [props.menuProps]);
return (
<>
{props.children(targetRef)}
{isRendered && (
<Portal {...props.portalProps}>
<Menu
isOpen={isDeferredOpen}
gutter={0}
{...props.menuProps}
onClose={onCloseHandler}
>
<MenuButton
aria-hidden={true}
w={1}
h={1}
style={{
position: 'absolute',
left: position[0],
top: position[1],
cursor: 'default',
}}
{...props.menuButtonProps}
/>
{props.renderMenu()}
</Menu>
</Portal>
)}
</>
);
}

View File

@ -16,6 +16,7 @@ import ImageContextMenu from 'features/gallery/components/ImageContextMenu/Image
import { import {
MouseEvent, MouseEvent,
ReactElement, ReactElement,
ReactNode,
SyntheticEvent, SyntheticEvent,
memo, memo,
useCallback, useCallback,
@ -32,6 +33,17 @@ import {
TypesafeDroppableData, TypesafeDroppableData,
} from 'features/dnd/types'; } from 'features/dnd/types';
const defaultUploadElement = (
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
);
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
type IAIDndImageProps = FlexProps & { type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void; onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
@ -47,13 +59,14 @@ type IAIDndImageProps = FlexProps & {
fitContainer?: boolean; fitContainer?: boolean;
droppableData?: TypesafeDroppableData; droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData; draggableData?: TypesafeDraggableData;
dropLabel?: string; dropLabel?: ReactNode;
isSelected?: boolean; isSelected?: boolean;
thumbnail?: boolean; thumbnail?: boolean;
noContentFallback?: ReactElement; noContentFallback?: ReactElement;
useThumbailFallback?: boolean; useThumbailFallback?: boolean;
withHoverOverlay?: boolean; withHoverOverlay?: boolean;
children?: JSX.Element; children?: JSX.Element;
uploadElement?: ReactNode;
}; };
const IAIDndImage = (props: IAIDndImageProps) => { const IAIDndImage = (props: IAIDndImageProps) => {
@ -74,7 +87,8 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel, dropLabel,
isSelected = false, isSelected = false,
thumbnail = false, thumbnail = false,
noContentFallback = <IAINoContentFallback icon={FaImage} />, noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
useThumbailFallback, useThumbailFallback,
withHoverOverlay = false, withHoverOverlay = false,
children, children,
@ -193,12 +207,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
{...getUploadButtonProps()} {...getUploadButtonProps()}
> >
<input {...getUploadInputProps()} /> <input {...getUploadInputProps()} />
<Icon {uploadElement}
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex> </Flex>
</> </>
)} )}
@ -210,6 +219,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick} onClick={onClick}
/> />
)} )}
{children}
{!isDropDisabled && ( {!isDropDisabled && (
<IAIDroppable <IAIDroppable
data={droppableData} data={droppableData}
@ -217,7 +227,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel={dropLabel} dropLabel={dropLabel}
/> />
)} )}
{children}
</Flex> </Flex>
)} )}
</ImageContextMenu> </ImageContextMenu>

View File

@ -1,5 +1,8 @@
import { MenuList } from '@chakra-ui/react'; import { MenuList } from '@chakra-ui/react';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; import {
IAIContextMenu,
IAIContextMenuProps,
} from 'common/components/IAIContextMenu';
import { MouseEvent, memo, useCallback } from 'react'; import { MouseEvent, memo, useCallback } from 'react';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu'; import { menuListMotionProps } from 'theme/components/menu';
@ -12,7 +15,7 @@ import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
type Props = { type Props = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children']; children: IAIContextMenuProps<HTMLDivElement>['children'];
}; };
const selector = createSelector( const selector = createSelector(
@ -33,7 +36,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}, []); }, []);
return ( return (
<ContextMenu<HTMLDivElement> <IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
menuButtonProps={{ menuButtonProps={{
bg: 'transparent', bg: 'transparent',
@ -68,7 +71,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}} }}
> >
{children} {children}
</ContextMenu> </IAIContextMenu>
); );
}; };

View File

@ -2,8 +2,9 @@ import { Badge, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useMemo } from 'react'; import { memo, useMemo } from 'react';
import { import {
BaseEdge, BaseEdge,
EdgeLabelRenderer, EdgeLabelRenderer,
@ -20,78 +21,165 @@ const makeEdgeSelector = (
targetHandleId: string | null | undefined, targetHandleId: string | null | undefined,
selected?: boolean selected?: boolean
) => ) =>
createSelector(stateSelector, ({ nodes }) => { createSelector(
const sourceNode = nodes.nodes.find((node) => node.id === source); stateSelector,
const targetNode = nodes.nodes.find((node) => node.id === target); ({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge = const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode); isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = sourceNode?.selected || targetNode?.selected || selected; const isSelected =
const sourceType = isInvocationToInvocationEdge sourceNode?.selected || targetNode?.selected || selected;
? sourceNode?.data?.outputs[sourceHandleId || '']?.type const sourceType = isInvocationToInvocationEdge
: undefined; ? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const stroke = const stroke =
sourceType && nodes.shouldColorEdges sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color) ? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500'); : colorTokenToCssVar('base.500');
return { return {
isSelected, isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected, shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke, stroke,
}; };
}); },
defaultSelectorOptions
const CollapsedEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
); );
const { isSelected, shouldAnimate } = useAppSelector(selector); const CollapsedEdge = memo(
({
const [edgePath, labelX, labelY] = getBezierPath({
sourceX, sourceX,
sourceY, sourceY,
sourcePosition,
targetX, targetX,
targetY, targetY,
sourcePosition,
targetPosition, targetPosition,
}); markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
);
const { base500 } = useChakraThemeTokens(); const { isSelected, shouldAnimate } = useAppSelector(selector);
return ( const [edgePath, labelX, labelY] = getBezierPath({
<> sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
const { base500 } = useChakraThemeTokens();
return (
<>
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
: undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
}
);
CollapsedEdge.displayName = 'CollapsedEdge';
const DefaultEdge = memo(
({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge <BaseEdge
path={edgePath} path={edgePath}
markerEnd={markerEnd} markerEnd={markerEnd}
style={{ style={{
strokeWidth: isSelected ? 3 : 2, strokeWidth: isSelected ? 3 : 2,
stroke: base500, stroke,
opacity: isSelected ? 0.8 : 0.5, opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate animation: shouldAnimate
? 'dashdraw 0.5s linear infinite' ? 'dashdraw 0.5s linear infinite'
@ -99,83 +187,11 @@ const CollapsedEdge = ({
strokeDasharray: shouldAnimate ? 5 : 'none', strokeDasharray: shouldAnimate ? 5 : 'none',
}} }}
/> />
{data?.count && data.count > 1 && ( );
<EdgeLabelRenderer> }
<Flex );
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
const DefaultEdge = ({ DefaultEdge.displayName = 'DefaultEdge';
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
);
};
export const edgeTypes = { export const edgeTypes = {
collapsed: CollapsedEdge, collapsed: CollapsedEdge,

View File

@ -1,5 +1,6 @@
import { useToken } from '@chakra-ui/react'; import { useToken } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { import {
Background, Background,
@ -114,6 +115,10 @@ export const Flow = () => {
[dispatch] [dispatch]
); );
const handlePaneClick = useCallback(() => {
dispatch(contextMenusClosed());
}, [dispatch]);
return ( return (
<ReactFlow <ReactFlow
defaultViewport={viewport} defaultViewport={viewport}
@ -132,12 +137,13 @@ export const Flow = () => {
connectionLineComponent={CustomConnectionLine} connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange} onSelectionChange={handleSelectionChange}
isValidConnection={isValidConnection} isValidConnection={isValidConnection}
minZoom={0.2} minZoom={0.1}
snapToGrid={shouldSnapToGrid} snapToGrid={shouldSnapToGrid}
snapGrid={[25, 25]} snapGrid={[25, 25]}
connectionRadius={30} connectionRadius={30}
proOptions={proOptions} proOptions={proOptions}
style={{ borderRadius }} style={{ borderRadius }}
onPaneClick={handlePaneClick}
> >
<TopLeftPanel /> <TopLeftPanel />
<TopCenterPanel /> <TopCenterPanel />

View File

@ -1,40 +1,34 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { import { useFieldNames, useWithFooter } from 'features/nodes/hooks/useNodeData';
InvocationNodeData, import { memo } from 'react';
InvocationTemplate,
} from 'features/nodes/types/types';
import { map, some } from 'lodash-es';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import InputField from '../fields/InputField'; import InputField from '../fields/InputField';
import OutputField from '../fields/OutputField'; import OutputField from '../fields/OutputField';
import NodeFooter, { FOOTER_FIELDS } from './NodeFooter'; import NodeFooter from './NodeFooter';
import NodeHeader from './NodeHeader'; import NodeHeader from './NodeHeader';
import NodeWrapper from './NodeWrapper'; import NodeWrapper from './NodeWrapper';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; isOpen: boolean;
label: string;
type: string;
selected: boolean;
}; };
const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => { const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const { id: nodeId, data } = nodeProps; const inputFieldNames = useFieldNames(nodeId, 'input');
const { inputs, outputs, isOpen } = data; const outputFieldNames = useFieldNames(nodeId, 'output');
const withFooter = useWithFooter(nodeId);
const inputFields = useMemo(
() => map(inputs).filter((i) => i.name !== 'is_intermediate'),
[inputs]
);
const outputFields = useMemo(() => map(outputs), [outputs]);
const withFooter = useMemo(
() => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)),
[outputs]
);
return ( return (
<NodeWrapper nodeProps={nodeProps}> <NodeWrapper nodeId={nodeId} selected={selected}>
<NodeHeader nodeProps={nodeProps} nodeTemplate={nodeTemplate} /> <NodeHeader
nodeId={nodeId}
isOpen={isOpen}
label={label}
selected={selected}
type={type}
/>
{isOpen && ( {isOpen && (
<> <>
<Flex <Flex
@ -54,27 +48,23 @@ const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
className="nopan" className="nopan"
sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }} sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}
> >
{outputFields.map((field) => ( {outputFieldNames.map((fieldName) => (
<OutputField <OutputField
key={`${nodeId}.${field.id}.input-field`} key={`${nodeId}.${fieldName}.output-field`}
nodeProps={nodeProps} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field}
/> />
))} ))}
{inputFields.map((field) => ( {inputFieldNames.map((fieldName) => (
<InputField <InputField
key={`${nodeId}.${field.id}.input-field`} key={`${nodeId}.${fieldName}.input-field`}
nodeProps={nodeProps} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field}
/> />
))} ))}
</Flex> </Flex>
</Flex> </Flex>
{withFooter && ( {withFooter && <NodeFooter nodeId={nodeId} />}
<NodeFooter nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
)}
</> </>
)} )}
</NodeWrapper> </NodeWrapper>

View File

@ -2,16 +2,15 @@ import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice'; import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { NodeProps, useUpdateNodeInternals } from 'reactflow'; import { useUpdateNodeInternals } from 'reactflow';
interface Props { interface Props {
nodeProps: NodeProps<NodeData>; nodeId: string;
isOpen: boolean;
} }
const NodeCollapseButton = (props: Props) => { const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
const { id: nodeId, isOpen } = props.nodeProps.data;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals(); const updateNodeInternals = useUpdateNodeInternals();

View File

@ -1,20 +1,17 @@
import { useColorModeValue } from '@chakra-ui/react'; import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { import { useNodeData } from 'features/nodes/hooks/useNodeData';
InvocationNodeData, import { isInvocationNodeData } from 'features/nodes/types/types';
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react'; import { CSSProperties, memo, useMemo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow'; import { Handle, Position } from 'reactflow';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate;
} }
const NodeCollapsedHandles = (props: Props) => { const NodeCollapsedHandles = ({ nodeId }: Props) => {
const { data } = props.nodeProps; const data = useNodeData(nodeId);
const { base400, base600 } = useChakraThemeTokens(); const { base400, base600 } = useChakraThemeTokens();
const backgroundColor = useColorModeValue(base400, base600); const backgroundColor = useColorModeValue(base400, base600);
@ -30,6 +27,10 @@ const NodeCollapsedHandles = (props: Props) => {
[backgroundColor] [backgroundColor]
); );
if (!isInvocationNodeData(data)) {
return null;
}
return ( return (
<> <>
<Handle <Handle
@ -44,7 +45,7 @@ const NodeCollapsedHandles = (props: Props) => {
key={`${data.id}-${input.name}-collapsed-input-handle`} key={`${data.id}-${input.name}-collapsed-input-handle`}
type="target" type="target"
id={input.name} id={input.name}
isValidConnection={() => false} isConnectable={false}
position={Position.Left} position={Position.Left}
style={{ visibility: 'hidden' }} style={{ visibility: 'hidden' }}
/> />
@ -52,7 +53,6 @@ const NodeCollapsedHandles = (props: Props) => {
<Handle <Handle
type="source" type="source"
id={`${data.id}-collapsed-source`} id={`${data.id}-collapsed-source`}
isValidConnection={() => false}
isConnectable={false} isConnectable={false}
position={Position.Right} position={Position.Right}
style={{ ...dummyHandleStyles, right: '-0.5rem' }} style={{ ...dummyHandleStyles, right: '-0.5rem' }}
@ -62,7 +62,7 @@ const NodeCollapsedHandles = (props: Props) => {
key={`${data.id}-${output.name}-collapsed-output-handle`} key={`${data.id}-${output.name}-collapsed-output-handle`}
type="source" type="source"
id={output.name} id={output.name}
isValidConnection={() => false} isConnectable={false}
position={Position.Right} position={Position.Right}
style={{ visibility: 'hidden' }} style={{ visibility: 'hidden' }}
/> />

View File

@ -6,49 +6,19 @@ import {
Spacer, Spacer,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import {
useHasImageOutput,
useIsIntermediate,
} from 'features/nodes/hooks/useNodeData';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { import { ChangeEvent, memo, useCallback } from 'react';
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate;
}; };
const NodeFooter = (props: Props) => { const NodeFooter = ({ nodeId }: Props) => {
const { nodeProps, nodeTemplate } = props;
const dispatch = useAppDispatch();
const hasImageOutput = useMemo(
() =>
some(nodeTemplate?.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
),
[nodeTemplate?.outputs]
);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: nodeProps.data.id,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeProps.data.id]
);
return ( return (
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
@ -62,19 +32,45 @@ const NodeFooter = (props: Props) => {
}} }}
> >
<Spacer /> <Spacer />
{hasImageOutput && ( <SaveImageCheckbox nodeId={nodeId} />
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
/>
</FormControl>
)}
</Flex> </Flex>
); );
}; };
export default memo(NodeFooter); export default memo(NodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -1,10 +1,5 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton'; import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles'; import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
import NodeNotesEdit from '../Invocation/NodeNotesEdit'; import NodeNotesEdit from '../Invocation/NodeNotesEdit';
@ -12,14 +7,14 @@ import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
import NodeTitle from '../Invocation/NodeTitle'; import NodeTitle from '../Invocation/NodeTitle';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; isOpen: boolean;
label: string;
type: string;
selected: boolean;
}; };
const NodeHeader = (props: Props) => { const NodeHeader = ({ nodeId, isOpen }: Props) => {
const { nodeProps, nodeTemplate } = props;
const { isOpen } = nodeProps.data;
return ( return (
<Flex <Flex
layerStyle="nodeHeader" layerStyle="nodeHeader"
@ -35,18 +30,13 @@ const NodeHeader = (props: Props) => {
_dark: { color: 'base.200' }, _dark: { color: 'base.200' },
}} }}
> >
<NodeCollapseButton nodeProps={nodeProps} /> <NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeData={nodeProps.data} title={nodeTemplate.title} /> <NodeTitle nodeId={nodeId} />
<Flex alignItems="center"> <Flex alignItems="center">
<NodeStatusIndicator nodeProps={nodeProps} /> <NodeStatusIndicator nodeId={nodeId} />
<NodeNotesEdit nodeProps={nodeProps} nodeTemplate={nodeTemplate} /> <NodeNotesEdit nodeId={nodeId} />
</Flex> </Flex>
{!isOpen && ( {!isOpen && <NodeCollapsedHandles nodeId={nodeId} />}
<NodeCollapsedHandles
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
/>
)}
</Flex> </Flex>
); );
}; };

View File

@ -16,41 +16,31 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea'; import IAITextarea from 'common/components/IAITextarea';
import {
useNodeData,
useNodeLabel,
useNodeTemplate,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice'; import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { import { isInvocationNodeData } from 'features/nodes/types/types';
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa'; import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate;
} }
const NodeNotesEdit = (props: Props) => { const NodeNotesEdit = ({ nodeId }: Props) => {
const { nodeProps, nodeTemplate } = props;
const { data } = nodeProps;
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch(); const label = useNodeLabel(nodeId);
const handleNotesChanged = useCallback( const title = useNodeTemplateTitle(nodeId);
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
},
[data.id, dispatch]
);
return ( return (
<> <>
<Tooltip <Tooltip
label={ label={<TooltipContent nodeId={nodeId} />}
nodeTemplate ? (
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
) : undefined
}
placement="top" placement="top"
shouldWrapChildren shouldWrapChildren
> >
@ -75,19 +65,10 @@ const NodeNotesEdit = (props: Props) => {
<Modal isOpen={isOpen} onClose={onClose} isCentered> <Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay /> <ModalOverlay />
<ModalContent> <ModalContent>
<ModalHeader> <ModalHeader>{label || title || 'Unknown Node'}</ModalHeader>
{data.label || nodeTemplate?.title || 'Unknown Node'}
</ModalHeader>
<ModalCloseButton /> <ModalCloseButton />
<ModalBody> <ModalBody>
<FormControl> <NotesTextarea nodeId={nodeId} />
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
</ModalBody> </ModalBody>
<ModalFooter /> <ModalFooter />
</ModalContent> </ModalContent>
@ -98,16 +79,49 @@ const NodeNotesEdit = (props: Props) => {
export default memo(NodeNotesEdit); export default memo(NodeNotesEdit);
type TooltipContentProps = Props; const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
}
const TooltipContent = (props: TooltipContentProps) => {
return ( return (
<Flex sx={{ flexDir: 'column' }}> <Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text> <Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}> <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{props.nodeTemplate?.description} {nodeTemplate?.description}
</Text> </Text>
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>} {data?.notes && <Text>{data.notes}</Text>}
</Flex> </Flex>
); );
}; });
TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';

View File

@ -11,17 +11,12 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
InvocationNodeData,
NodeExecutionState,
NodeStatus,
} from 'features/nodes/types/types';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa'; import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
}; };
const iconBoxSize = 3; const iconBoxSize = 3;
@ -33,8 +28,7 @@ const circleStyles = {
'.chakra-progress__track': { stroke: 'transparent' }, '.chakra-progress__track': { stroke: 'transparent' },
}; };
const NodeStatusIndicator = (props: Props) => { const NodeStatusIndicator = ({ nodeId }: Props) => {
const nodeId = props.nodeProps.data.id;
const selectNodeExecutionState = useMemo( const selectNodeExecutionState = useMemo(
() => () =>
createSelector( createSelector(
@ -76,7 +70,7 @@ type TooltipLabelProps = {
nodeExecutionState: NodeExecutionState; nodeExecutionState: NodeExecutionState;
}; };
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => { const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState; const { status, progress, progressImage } = nodeExecutionState;
if (status === NodeStatus.PENDING) { if (status === NodeStatus.PENDING) {
return <Text>Pending</Text>; return <Text>Pending</Text>;
@ -118,13 +112,15 @@ const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
} }
return null; return null;
}; });
TooltipLabel.displayName = 'TooltipLabel';
type StatusIconProps = { type StatusIconProps = {
nodeExecutionState: NodeExecutionState; nodeExecutionState: NodeExecutionState;
}; };
const StatusIcon = (props: StatusIconProps) => { const StatusIcon = memo((props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState; const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) { if (status === NodeStatus.PENDING) {
return ( return (
@ -182,4 +178,6 @@ const StatusIcon = (props: StatusIconProps) => {
); );
} }
return null; return null;
}; });
StatusIcon.displayName = 'StatusIcon';

View File

@ -7,26 +7,29 @@ import {
useEditableControls, useEditableControls,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import {
useNodeLabel,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice'; import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeData } from 'features/nodes/types/types';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react'; import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
type Props = { type Props = {
nodeData: NodeData; nodeId: string;
title: string; title?: string;
}; };
const NodeTitle = (props: Props) => { const NodeTitle = ({ nodeId, title }: Props) => {
const { title } = props;
const { id: nodeId, label } = props.nodeData;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title); const label = useNodeLabel(nodeId);
const templateTitle = useNodeTemplateTitle(nodeId);
const [localTitle, setLocalTitle] = useState('');
const handleSubmit = useCallback( const handleSubmit = useCallback(
async (newTitle: string) => { async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle })); dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title); setLocalTitle(newTitle || title || 'Problem Setting Title');
}, },
[nodeId, dispatch, title] [nodeId, dispatch, title]
); );
@ -37,8 +40,8 @@ const NodeTitle = (props: Props) => {
useEffect(() => { useEffect(() => {
// Another component may change the title; sync local title with global state // Another component may change the title; sync local title with global state
setLocalTitle(label || title); setLocalTitle(label || title || templateTitle || 'Problem Setting Title');
}, [label, title]); }, [label, templateTitle, title]);
return ( return (
<Flex <Flex

View File

@ -6,10 +6,14 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeClicked } from 'features/nodes/store/nodesSlice'; import { nodeClicked } from 'features/nodes/store/nodesSlice';
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react'; import {
MouseEvent,
PropsWithChildren,
memo,
useCallback,
useMemo,
} from 'react';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
import { NodeData } from 'features/nodes/types/types';
import { NodeProps } from 'reactflow';
const useNodeSelect = (nodeId: string) => { const useNodeSelect = (nodeId: string) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -25,14 +29,13 @@ const useNodeSelect = (nodeId: string) => {
}; };
type NodeWrapperProps = PropsWithChildren & { type NodeWrapperProps = PropsWithChildren & {
nodeProps: NodeProps<NodeData>; nodeId: string;
selected: boolean;
width?: NonNullable<ChakraProps['sx']>['w']; width?: NonNullable<ChakraProps['sx']>['w'];
}; };
const NodeWrapper = (props: NodeWrapperProps) => { const NodeWrapper = (props: NodeWrapperProps) => {
const { width, children, nodeProps } = props; const { width, children, nodeId, selected } = props;
const { data, selected } = nodeProps;
const nodeId = data.id;
const [ const [
nodeSelectedOutlineLight, nodeSelectedOutlineLight,
@ -93,4 +96,4 @@ const NodeWrapper = (props: NodeWrapperProps) => {
); );
}; };
export default NodeWrapper; export default memo(NodeWrapper);

View File

@ -1,20 +1,26 @@
import { Box, Flex, Text } from '@chakra-ui/react'; import { Box, Flex, Text } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { InvocationNodeData } from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton'; import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeWrapper from '../Invocation/NodeWrapper'; import NodeWrapper from '../Invocation/NodeWrapper';
type Props = { type Props = {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
}; };
const UnknownNodeFallback = ({ nodeProps }: Props) => { const UnknownNodeFallback = ({
const { data } = nodeProps; nodeId,
const { isOpen, label, type } = data; isOpen,
label,
type,
selected,
}: Props) => {
return ( return (
<NodeWrapper nodeProps={nodeProps}> <NodeWrapper nodeId={nodeId} selected={selected}>
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeHeader" layerStyle="nodeHeader"
@ -27,7 +33,7 @@ const UnknownNodeFallback = ({ nodeProps }: Props) => {
fontSize: 'sm', fontSize: 'sm',
}} }}
> >
<NodeCollapseButton nodeProps={nodeProps} /> <NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<Text <Text
sx={{ sx={{
w: 'full', w: 'full',

View File

@ -46,7 +46,6 @@ const NodeEditor = () => {
<AnimatePresence> <AnimatePresence>
{isReady && ( {isReady && (
<motion.div <motion.div
layoutId="node-editor-flow"
initial={{ initial={{
opacity: 0, opacity: 0,
}} }}
@ -67,7 +66,6 @@ const NodeEditor = () => {
<AnimatePresence> <AnimatePresence>
{!isReady && ( {!isReady && (
<motion.div <motion.div
layoutId="node-editor-loading"
initial={{ initial={{
opacity: 0, opacity: 0,
}} }}

View File

@ -15,7 +15,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { FaCog } from 'react-icons/fa'; import { FaCog } from 'react-icons/fa';
import { import {
shouldAnimateEdgesChanged, shouldAnimateEdgesChanged,
@ -23,21 +23,26 @@ import {
shouldSnapToGridChanged, shouldSnapToGridChanged,
shouldValidateGraphChanged, shouldValidateGraphChanged,
} from '../store/nodesSlice'; } from '../store/nodesSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
const selector = createSelector(stateSelector, ({ nodes }) => { const selector = createSelector(
const { stateSelector,
shouldAnimateEdges, ({ nodes }) => {
shouldValidateGraph, const {
shouldSnapToGrid, shouldAnimateEdges,
shouldColorEdges, shouldValidateGraph,
} = nodes; shouldSnapToGrid,
return { shouldColorEdges,
shouldAnimateEdges, } = nodes;
shouldValidateGraph, return {
shouldSnapToGrid, shouldAnimateEdges,
shouldColorEdges, shouldValidateGraph,
}; shouldSnapToGrid,
}); shouldColorEdges,
};
},
defaultSelectorOptions
);
const NodeEditorSettings = () => { const NodeEditorSettings = () => {
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
@ -136,4 +141,4 @@ const NodeEditorSettings = () => {
); );
}; };
export default NodeEditorSettings; export default memo(NodeEditorSettings);

View File

@ -1,19 +1,12 @@
import { Tooltip } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react'; import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, NodeProps, Position } from 'reactflow'; import { Handle, HandleType, Position } from 'reactflow';
import { import {
FIELDS, FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY, HANDLE_TOOLTIP_OPEN_DELAY,
colorTokenToCssVar, colorTokenToCssVar,
} from '../../types/constants'; } from '../../types/constants';
import { import { InputFieldTemplate, OutputFieldTemplate } from '../../types/types';
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from '../../types/types';
export const handleBaseStyles: CSSProperties = { export const handleBaseStyles: CSSProperties = {
position: 'absolute', position: 'absolute',
@ -32,9 +25,6 @@ export const outputHandleStyles: CSSProperties = {
}; };
type FieldHandleProps = { type FieldHandleProps = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate; fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
handleType: HandleType; handleType: HandleType;
isConnectionInProgress: boolean; isConnectionInProgress: boolean;

View File

@ -8,13 +8,11 @@ import {
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIDraggable from 'common/components/IAIDraggable'; import IAIDraggable from 'common/components/IAIDraggable';
import { NodeFieldDraggableData } from 'features/dnd/types'; import { NodeFieldDraggableData } from 'features/dnd/types';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { import {
InputFieldTemplate, useFieldData,
InputFieldValue, useFieldTemplate,
InvocationNodeData, } from 'features/nodes/hooks/useNodeData';
InvocationTemplate, import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
} from 'features/nodes/types/types';
import { import {
MouseEvent, MouseEvent,
memo, memo,
@ -25,41 +23,43 @@ import {
} from 'react'; } from 'react';
interface Props { interface Props {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
isDraggable?: boolean; isDraggable?: boolean;
kind: 'input' | 'output';
} }
const FieldTitle = (props: Props) => { const FieldTitle = (props: Props) => {
const { nodeData, field, fieldTemplate, isDraggable = false } = props; const { nodeId, fieldName, isDraggable = false, kind } = props;
const { label } = field; const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const { title, input } = fieldTemplate; const field = useFieldData(nodeId, fieldName);
const { id: nodeId } = nodeData;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title); const [localTitle, setLocalTitle] = useState(
field?.label || fieldTemplate?.title || 'Unknown Field'
);
const draggableData: NodeFieldDraggableData | undefined = useMemo( const draggableData: NodeFieldDraggableData | undefined = useMemo(
() => () =>
input !== 'connection' && isDraggable field &&
fieldTemplate?.fieldKind === 'input' &&
fieldTemplate?.input !== 'connection' &&
isDraggable
? { ? {
id: `${nodeId}-${field.name}`, id: `${nodeId}-${fieldName}`,
payloadType: 'NODE_FIELD', payloadType: 'NODE_FIELD',
payload: { nodeId, field, fieldTemplate }, payload: { nodeId, field, fieldTemplate },
} }
: undefined, : undefined,
[field, fieldTemplate, input, isDraggable, nodeId] [field, fieldName, fieldTemplate, isDraggable, nodeId]
); );
const handleSubmit = useCallback( const handleSubmit = useCallback(
async (newTitle: string) => { async (newTitle: string) => {
dispatch( dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle }) setLocalTitle(newTitle || fieldTemplate?.title || 'Unknown Field');
);
setLocalTitle(newTitle || title);
}, },
[dispatch, nodeId, field.name, title] [dispatch, nodeId, fieldName, fieldTemplate?.title]
); );
const handleChange = useCallback((newTitle: string) => { const handleChange = useCallback((newTitle: string) => {
@ -68,8 +68,8 @@ const FieldTitle = (props: Props) => {
useEffect(() => { useEffect(() => {
// Another component may change the title; sync local title with global state // Another component may change the title; sync local title with global state
setLocalTitle(label || title); setLocalTitle(field?.label || fieldTemplate?.title || 'Unknown Field');
}, [label, title]); }, [field?.label, fieldTemplate?.title]);
return ( return (
<Flex <Flex
@ -120,7 +120,7 @@ type EditableControlsProps = {
draggableData?: NodeFieldDraggableData; draggableData?: NodeFieldDraggableData;
}; };
function EditableControls(props: EditableControlsProps) { const EditableControls = memo((props: EditableControlsProps) => {
const { isEditing, getEditButtonProps } = useEditableControls(); const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback( const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => { (e: MouseEvent<HTMLDivElement>) => {
@ -158,4 +158,6 @@ function EditableControls(props: EditableControlsProps) {
cursor="text" cursor="text"
/> />
); );
} });
EditableControls.displayName = 'EditableControls';

View File

@ -1,38 +1,53 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import {
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { FIELDS } from 'features/nodes/types/constants'; import { FIELDS } from 'features/nodes/types/constants';
import { import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
isInputFieldTemplate, isInputFieldTemplate,
isInputFieldValue, isInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { startCase } from 'lodash-es'; import { startCase } from 'lodash-es';
import { useMemo } from 'react';
interface Props { interface Props {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue | OutputFieldValue; kind: 'input' | 'output';
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
} }
const FieldTooltipContent = ({ field, fieldTemplate }: Props) => { const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const isInputTemplate = isInputFieldTemplate(fieldTemplate); const isInputTemplate = isInputFieldTemplate(fieldTemplate);
const fieldTitle = useMemo(() => {
if (isInputFieldValue(field)) {
if (field.label && fieldTemplate) {
return `${field.label} (${fieldTemplate.title})`;
}
if (field.label && !fieldTemplate) {
return field.label;
}
if (!field.label && fieldTemplate) {
return fieldTemplate.title;
}
return 'Unknown Field';
}
}, [field, fieldTemplate]);
return ( return (
<Flex sx={{ flexDir: 'column' }}> <Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}> <Text sx={{ fontWeight: 600 }}>{fieldTitle}</Text>
{isInputFieldValue(field) && field.label {fieldTemplate && (
? `${field.label} (${fieldTemplate.title})` <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
: fieldTemplate.title} {fieldTemplate.description}
</Text> </Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}> )}
{fieldTemplate.description} {fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
</Text>
<Text>Type: {FIELDS[fieldTemplate.type].title}</Text>
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>} {isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex> </Flex>
); );

View File

@ -1,27 +1,24 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { import {
InputFieldValue, useDoesInputHaveValue,
InvocationNodeData, useFieldTemplate,
InvocationTemplate, } from 'features/nodes/hooks/useNodeData';
} from 'features/nodes/types/types'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, useMemo } from 'react'; import { PropsWithChildren, memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle'; import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle'; import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer'; import InputFieldRenderer from './InputFieldRenderer';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
} }
const InputField = (props: Props) => { const InputField = ({ nodeId, fieldName }: Props) => {
const { nodeProps, nodeTemplate, field } = props; const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const { id: nodeId } = nodeProps.data; const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const { const {
isConnected, isConnected,
@ -29,15 +26,10 @@ const InputField = (props: Props) => {
isConnectionStartField, isConnectionStartField,
connectionError, connectionError,
shouldDim, shouldDim,
} = useConnectionState({ nodeId, field, kind: 'input' }); } = useConnectionState({ nodeId, fieldName, kind: 'input' });
const fieldTemplate = useMemo(
() => nodeTemplate.inputs[field.name],
[field.name, nodeTemplate.inputs]
);
const isMissingInput = useMemo(() => { const isMissingInput = useMemo(() => {
if (!fieldTemplate) { if (fieldTemplate?.fieldKind !== 'input') {
return false; return false;
} }
@ -49,18 +41,18 @@ const InputField = (props: Props) => {
return true; return true;
} }
if (!field.value && !isConnected && fieldTemplate.input === 'any') { if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') {
return true; return true;
} }
}, [fieldTemplate, isConnected, field.value]); }, [fieldTemplate, isConnected, doesFieldHaveValue]);
if (!fieldTemplate) { if (fieldTemplate?.fieldKind !== 'input') {
return ( return (
<InputFieldWrapper shouldDim={shouldDim}> <InputFieldWrapper shouldDim={shouldDim}>
<FormControl <FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }} sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
> >
Unknown input: {field.name} Unknown input: {fieldName}
</FormControl> </FormControl>
</InputFieldWrapper> </InputFieldWrapper>
); );
@ -82,10 +74,9 @@ const InputField = (props: Props) => {
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
nodeData={nodeProps.data} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="input"
fieldTemplate={fieldTemplate}
/> />
} }
openDelay={HANDLE_TOOLTIP_OPEN_DELAY} openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -95,27 +86,18 @@ const InputField = (props: Props) => {
> >
<FormLabel sx={{ mb: 0 }}> <FormLabel sx={{ mb: 0 }}>
<FieldTitle <FieldTitle
nodeData={nodeProps.data} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="input"
fieldTemplate={fieldTemplate}
isDraggable isDraggable
/> />
</FormLabel> </FormLabel>
</Tooltip> </Tooltip>
<InputFieldRenderer <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl> </FormControl>
{fieldTemplate.input !== 'direct' && ( {fieldTemplate.input !== 'direct' && (
<FieldHandle <FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
handleType="target" handleType="target"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
@ -133,21 +115,25 @@ type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean; shouldDim: boolean;
}>; }>;
const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => ( const InputFieldWrapper = memo(
<Flex ({ shouldDim, children }: InputFieldWrapperProps) => (
className="nopan" <Flex
sx={{ className="nopan"
position: 'relative', sx={{
minH: 8, position: 'relative',
py: 0.5, minH: 8,
alignItems: 'center', py: 0.5,
opacity: shouldDim ? 0.5 : 1, alignItems: 'center',
transitionProperty: 'opacity', opacity: shouldDim ? 0.5 : 1,
transitionDuration: '0.1s', transitionProperty: 'opacity',
w: 'full', transitionDuration: '0.1s',
h: 'full', w: 'full',
}} h: 'full',
> }}
{children} >
</Flex> {children}
</Flex>
)
); );
InputFieldWrapper.displayName = 'InputFieldWrapper';

View File

@ -1,11 +1,9 @@
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { memo } from 'react';
import { import {
InputFieldTemplate, useFieldData,
InputFieldValue, useFieldTemplate,
InvocationNodeData, } from 'features/nodes/hooks/useNodeData';
InvocationTemplate, import { memo } from 'react';
} from '../../types/types';
import BooleanInputField from './fieldTypes/BooleanInputField'; import BooleanInputField from './fieldTypes/BooleanInputField';
import ClipInputField from './fieldTypes/ClipInputField'; import ClipInputField from './fieldTypes/ClipInputField';
import CollectionInputField from './fieldTypes/CollectionInputField'; import CollectionInputField from './fieldTypes/CollectionInputField';
@ -29,33 +27,33 @@ import VaeInputField from './fieldTypes/VaeInputField';
import VaeModelInputField from './fieldTypes/VaeModelInputField'; import VaeModelInputField from './fieldTypes/VaeModelInputField';
type InputFieldProps = { type InputFieldProps = {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}; };
// build an individual input element based on the schema // build an individual input element based on the schema
const InputFieldRenderer = (props: InputFieldProps) => { const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const { nodeData, nodeTemplate, field, fieldTemplate } = props; const field = useFieldData(nodeId, fieldName);
const { type } = field; const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
if (type === 'string' && fieldTemplate.type === 'string') { if (fieldTemplate?.fieldKind === 'output') {
return <Box p={2}>Output field in input: {field?.type}</Box>;
}
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
return ( return (
<StringInputField <StringInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'boolean' && fieldTemplate.type === 'boolean') { if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
return ( return (
<BooleanInputField <BooleanInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -63,46 +61,32 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
(type === 'integer' && fieldTemplate.type === 'integer') || (field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(type === 'float' && fieldTemplate.type === 'float') (field?.type === 'float' && fieldTemplate?.type === 'float')
) { ) {
return ( return (
<NumberInputField <NumberInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'enum' && fieldTemplate.type === 'enum') { if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
return ( return (
<EnumInputField <EnumInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'ImageField' && fieldTemplate.type === 'ImageField') { if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
return ( return (
<ImageInputField <ImageInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
return (
<LatentsInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -110,68 +94,55 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'ConditioningField' && field?.type === 'LatentsField' &&
fieldTemplate.type === 'ConditioningField' fieldTemplate?.type === 'LatentsField'
) {
return (
<LatentsInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
) { ) {
return ( return (
<ConditioningInputField <ConditioningInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'UNetField' && fieldTemplate.type === 'UNetField') { if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
return ( return (
<UnetInputField <UnetInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'ClipField' && fieldTemplate.type === 'ClipField') { if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
return ( return (
<ClipInputField <ClipInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'VaeField' && fieldTemplate.type === 'VaeField') { if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
return ( return (
<VaeInputField <VaeInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
return (
<ControlInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
return (
<MainModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -179,35 +150,38 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'SDXLRefinerModelField' && field?.type === 'ControlField' &&
fieldTemplate.type === 'SDXLRefinerModelField' fieldTemplate?.type === 'ControlField'
) {
return (
<ControlInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
) {
return (
<MainModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLRefinerModelField' &&
fieldTemplate?.type === 'SDXLRefinerModelField'
) { ) {
return ( return (
<RefinerModelInputField <RefinerModelInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
return (
<VaeModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
return (
<LoRAModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -215,57 +189,48 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'ControlNetModelField' && field?.type === 'VaeModelField' &&
fieldTemplate.type === 'ControlNetModelField' fieldTemplate?.type === 'VaeModelField'
) {
return (
<VaeModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LoRAModelField' &&
fieldTemplate?.type === 'LoRAModelField'
) {
return (
<LoRAModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlNetModelField' &&
fieldTemplate?.type === 'ControlNetModelField'
) { ) {
return ( return (
<ControlNetModelInputField <ControlNetModelInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
if (type === 'Collection' && fieldTemplate.type === 'Collection') { if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
return ( return (
<CollectionInputField <CollectionInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
return (
<CollectionItemInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
return (
<ColorInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
return (
<ImageCollectionInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
@ -273,20 +238,55 @@ const InputFieldRenderer = (props: InputFieldProps) => {
} }
if ( if (
type === 'SDXLMainModelField' && field?.type === 'CollectionItem' &&
fieldTemplate.type === 'SDXLMainModelField' fieldTemplate?.type === 'CollectionItem'
) { ) {
return ( return (
<SDXLMainModelInputField <CollectionItemInputField
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
/> />
); );
} }
return <Box p={2}>Unknown field type: {type}</Box>; if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
return (
<ImageCollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
) {
return (
<SDXLMainModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return <Box p={2}>Unknown field type: {field?.type}</Box>;
}; };
export default memo(InputFieldRenderer); export default memo(InputFieldRenderer);

View File

@ -1,39 +1,16 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import FieldTitle from './FieldTitle'; import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer'; import InputFieldRenderer from './InputFieldRenderer';
type Props = { type Props = {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}; };
const LinearViewField = ({ const LinearViewField = ({ nodeId, fieldName }: Props) => {
nodeData,
nodeTemplate,
field,
fieldTemplate,
}: Props) => {
// const dispatch = useAppDispatch();
// const handleRemoveField = useCallback(() => {
// dispatch(
// workflowExposedFieldRemoved({
// nodeId: nodeData.id,
// fieldName: field.name,
// })
// );
// }, [dispatch, field.name, nodeData.id]);
return ( return (
<Flex <Flex
layerStyle="second" layerStyle="second"
@ -48,10 +25,9 @@ const LinearViewField = ({
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="input"
fieldTemplate={fieldTemplate}
/> />
} }
openDelay={HANDLE_TOOLTIP_OPEN_DELAY} openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -66,20 +42,10 @@ const LinearViewField = ({
mb: 0, mb: 0,
}} }}
> >
<FieldTitle <FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormLabel> </FormLabel>
</Tooltip> </Tooltip>
<InputFieldRenderer <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl> </FormControl>
</Flex> </Flex>
); );

View File

@ -6,25 +6,19 @@ import {
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldTemplate } from 'features/nodes/hooks/useNodeData';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { import { PropsWithChildren, memo } from 'react';
InvocationNodeData,
InvocationTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle'; import FieldHandle from './FieldHandle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
interface Props { interface Props {
nodeProps: NodeProps<InvocationNodeData>; nodeId: string;
nodeTemplate: InvocationTemplate; fieldName: string;
field: OutputFieldValue;
} }
const OutputField = (props: Props) => { const OutputField = ({ nodeId, fieldName }: Props) => {
const { nodeTemplate, nodeProps, field } = props; const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'output');
const { const {
isConnected, isConnected,
@ -32,20 +26,15 @@ const OutputField = (props: Props) => {
isConnectionStartField, isConnectionStartField,
connectionError, connectionError,
shouldDim, shouldDim,
} = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' }); } = useConnectionState({ nodeId, fieldName, kind: 'output' });
const fieldTemplate = useMemo( if (fieldTemplate?.fieldKind !== 'output') {
() => nodeTemplate.outputs[field.name],
[field.name, nodeTemplate]
);
if (!fieldTemplate) {
return ( return (
<OutputFieldWrapper shouldDim={shouldDim}> <OutputFieldWrapper shouldDim={shouldDim}>
<FormControl <FormControl
sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }} sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }}
> >
Unknown output: {field.name} Unknown output: {fieldName}
</FormControl> </FormControl>
</OutputFieldWrapper> </OutputFieldWrapper>
); );
@ -57,10 +46,9 @@ const OutputField = (props: Props) => {
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
nodeData={nodeProps.data} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field} kind="output"
fieldTemplate={fieldTemplate}
/> />
} }
openDelay={HANDLE_TOOLTIP_OPEN_DELAY} openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -75,9 +63,6 @@ const OutputField = (props: Props) => {
</FormControl> </FormControl>
</Tooltip> </Tooltip>
<FieldHandle <FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}
handleType="source" handleType="source"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
@ -88,27 +73,28 @@ const OutputField = (props: Props) => {
); );
}; };
export default OutputField; export default memo(OutputField);
type OutputFieldWrapperProps = PropsWithChildren<{ type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean; shouldDim: boolean;
}>; }>;
const OutputFieldWrapper = ({ const OutputFieldWrapper = memo(
shouldDim, ({ shouldDim, children }: OutputFieldWrapperProps) => (
children, <Flex
}: OutputFieldWrapperProps) => ( sx={{
<Flex position: 'relative',
sx={{ minH: 8,
position: 'relative', py: 0.5,
minH: 8, alignItems: 'center',
py: 0.5, opacity: shouldDim ? 0.5 : 1,
alignItems: 'center', transitionProperty: 'opacity',
opacity: shouldDim ? 0.5 : 1, transitionDuration: '0.1s',
transitionProperty: 'opacity', }}
transitionDuration: '0.1s', >
}} {children}
> </Flex>
{children} )
</Flex>
); );
OutputFieldWrapper.displayName = 'OutputFieldWrapper';

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const BooleanInputFieldComponent = ( const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate> props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const ColorInputFieldComponent = ( const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate> props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -19,8 +19,7 @@ const ControlNetModelInputFieldComponent = (
ControlNetModelInputFieldTemplate ControlNetModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const controlNetModel = field.value; const controlNetModel = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const EnumInputFieldComponent = ( const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate> props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => { ) => {
const { nodeData, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -19,8 +19,7 @@ const ImageCollectionInputFieldComponent = (
ImageCollectionInputFieldTemplate ImageCollectionInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
// const dispatch = useAppDispatch(); // const dispatch = useAppDispatch();

View File

@ -21,8 +21,7 @@ import { FieldComponentProps } from './types';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate> props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery( const { currentData: imageDTO } = useGetImageDTOQuery(

View File

@ -21,8 +21,7 @@ const LoRAModelInputFieldComponent = (
LoRAModelInputFieldTemplate LoRAModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const lora = field.value; const lora = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: loraModels } = useGetLoRAModelsQuery(); const { data: loraModels } = useGetLoRAModelsQuery();

View File

@ -26,8 +26,7 @@ const MainModelInputFieldComponent = (
MainModelInputFieldTemplate MainModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -23,8 +23,7 @@ const NumberInputFieldComponent = (
IntegerInputFieldTemplate | FloatInputFieldTemplate IntegerInputFieldTemplate | FloatInputFieldTemplate
> >
) => { ) => {
const { nodeData, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState<string>( const [valueAsString, setValueAsString] = useState<string>(
String(field.value) String(field.value)

View File

@ -24,8 +24,7 @@ const RefinerModelInputFieldComponent = (
SDXLRefinerModelInputFieldTemplate SDXLRefinerModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -27,8 +27,7 @@ const ModelInputFieldComponent = (
SDXLMainModelInputFieldTemplate SDXLMainModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -12,8 +12,7 @@ import { FieldComponentProps } from './types';
const StringInputFieldComponent = ( const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate> props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => { ) => {
const { nodeData, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleValueChanged = useCallback( const handleValueChanged = useCallback(

View File

@ -20,8 +20,7 @@ const VaeModelInputFieldComponent = (
VaeModelInputFieldTemplate VaeModelInputFieldTemplate
> >
) => { ) => {
const { nodeData, field } = props; const { nodeId, field } = props;
const nodeId = nodeData.id;
const vae = field.value; const vae = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: vaeModels } = useGetVaeModelsQuery(); const { data: vaeModels } = useGetVaeModelsQuery();

View File

@ -1,16 +1,13 @@
import { import {
InputFieldTemplate, InputFieldTemplate,
InputFieldValue, InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
export type FieldComponentProps< export type FieldComponentProps<
V extends InputFieldValue, V extends InputFieldValue,
T extends InputFieldTemplate T extends InputFieldTemplate
> = { > = {
nodeData: InvocationNodeData; nodeId: string;
nodeTemplate: InvocationTemplate;
field: V; field: V;
fieldTemplate: T; fieldTemplate: T;
}; };

View File

@ -55,7 +55,11 @@ const CurrentImageNode = (props: NodeProps) => {
export default memo(CurrentImageNode); export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => ( const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
<NodeWrapper nodeProps={props.nodeProps} width={384}> <NodeWrapper
nodeId={props.nodeProps.data.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
sx={{ sx={{

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { makeTemplateSelector } from 'features/nodes/store/util/makeTemplateSelector';
import { InvocationNodeData } from 'features/nodes/types/types'; import { InvocationNodeData } from 'features/nodes/types/types';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow'; import { NodeProps } from 'reactflow';
@ -7,18 +8,40 @@ import InvocationNode from '../Invocation/InvocationNode';
import UnknownNodeFallback from '../Invocation/UnknownNodeFallback'; import UnknownNodeFallback from '../Invocation/UnknownNodeFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => { const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data } = props; const { data, selected } = props;
const { type } = data; const { id: nodeId, type, isOpen, label } = data;
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]); const hasTemplateSelector = useMemo(
() =>
createSelector(stateSelector, ({ nodes }) =>
Boolean(nodes.nodeTemplates[type])
),
[type]
);
const nodeTemplate = useAppSelector(templateSelector); const nodeTemplate = useAppSelector(hasTemplateSelector);
if (!nodeTemplate) { if (!nodeTemplate) {
return <UnknownNodeFallback nodeProps={props} />; return (
<UnknownNodeFallback
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
} }
return <InvocationNode nodeProps={props} nodeTemplate={nodeTemplate} />; return (
<InvocationNode
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
}; };
export default memo(InvocationNodeWrapper); export default memo(InvocationNodeWrapper);

View File

@ -10,7 +10,7 @@ import NodeTitle from '../Invocation/NodeTitle';
import NodeWrapper from '../Invocation/NodeWrapper'; import NodeWrapper from '../Invocation/NodeWrapper';
const NotesNode = (props: NodeProps<NotesNodeData>) => { const NotesNode = (props: NodeProps<NotesNodeData>) => {
const { id: nodeId, data } = props; const { id: nodeId, data, selected } = props;
const { notes, isOpen } = data; const { notes, isOpen } = data;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback( const handleChange = useCallback(
@ -21,7 +21,7 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
); );
return ( return (
<NodeWrapper nodeProps={props}> <NodeWrapper nodeId={nodeId} selected={selected}>
<Flex <Flex
layerStyle="nodeHeader" layerStyle="nodeHeader"
sx={{ sx={{
@ -32,8 +32,8 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
h: 8, h: 8,
}} }}
> >
<NodeCollapseButton nodeProps={props} /> <NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeData={props.data} title="Notes" /> <NodeTitle nodeId={nodeId} title="Notes" />
<Box minW={8} /> <Box minW={8} />
</Flex> </Flex>
{isOpen && ( {isOpen && (

View File

@ -6,39 +6,11 @@ import {
TabPanels, TabPanels,
Tabs, Tabs,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react'; import { memo } from 'react';
import NodeDataInspector from './NodeDataInspector';
const selector = createSelector( import NodeTemplateInspector from './NodeTemplateInspector';
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
node: lastSelectedNode,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const InspectorPanel = () => { const InspectorPanel = () => {
const { node, template } = useAppSelector(selector);
return ( return (
<Flex <Flex
layerStyle="first" layerStyle="first"
@ -60,37 +32,10 @@ const InspectorPanel = () => {
<TabPanels> <TabPanels>
<TabPanel> <TabPanel>
{template ? ( <NodeTemplateInspector />
<Flex
sx={{
flexDir: 'column',
alignItems: 'flex-start',
gap: 2,
h: 'full',
}}
>
<ImageMetadataJSON
jsonObject={template}
label="Node Template"
/>
</Flex>
) : (
<IAINoContentFallback
label={
node
? 'No template found for selected node'
: 'No node selected'
}
icon={null}
/>
)}
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
{node ? ( <NodeDataInspector />
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
) : (
<IAINoContentFallback label="No node selected" icon={null} />
)}
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</Tabs> </Tabs>

View File

@ -17,20 +17,20 @@ const selector = createSelector(
); );
return { return {
node: lastSelectedNode, data: lastSelectedNode?.data,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const NodeDataInspector = () => { const NodeDataInspector = () => {
const { node } = useAppSelector(selector); const { data } = useAppSelector(selector);
return node ? ( if (!data) {
<ImageMetadataJSON jsonObject={node.data} label="Node Data" /> return <IAINoContentFallback label="No node selected" icon={null} />;
) : ( }
<IAINoContentFallback label="No node data" icon={null} />
); return <ImageMetadataJSON jsonObject={data} label="Node Data" />;
}; };
export default memo(NodeDataInspector); export default memo(NodeDataInspector);

View File

@ -0,0 +1,40 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
if (!template) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <ImageMetadataJSON jsonObject={template} label="Node Template" />;
};
export default memo(NodeTemplateInspector);

View File

@ -6,14 +6,6 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable'; import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AddFieldToLinearViewDropData } from 'features/dnd/types'; import { AddFieldToLinearViewDropData } from 'features/dnd/types';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/types';
import { forEach } from 'lodash-es';
import { memo } from 'react'; import { memo } from 'react';
import LinearViewField from '../../fields/LinearViewField'; import LinearViewField from '../../fields/LinearViewField';
import ScrollableContent from '../ScrollableContent'; import ScrollableContent from '../ScrollableContent';
@ -21,41 +13,8 @@ import ScrollableContent from '../ScrollableContent';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ nodes }) => { ({ nodes }) => {
const fields: {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}[] = [];
const { exposedFields } = nodes.workflow;
nodes.nodes.filter(isInvocationNode).forEach((node) => {
const nodeTemplate = nodes.nodeTemplates[node.data.type];
if (!nodeTemplate) {
return;
}
forEach(node.data.inputs, (field) => {
if (
!exposedFields.some(
(f) => f.nodeId === node.id && f.fieldName === field.name
)
) {
return;
}
const fieldTemplate = nodeTemplate.inputs[field.name];
if (!fieldTemplate) {
return;
}
fields.push({
nodeData: node.data,
nodeTemplate,
field,
fieldTemplate,
});
});
});
return { return {
fields, fields: nodes.workflow.exposedFields,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -89,13 +48,11 @@ const LinearTabContent = () => {
}} }}
> >
{fields.length ? ( {fields.length ? (
fields.map(({ nodeData, nodeTemplate, field, fieldTemplate }) => ( fields.map(({ nodeId, fieldName }) => (
<LinearViewField <LinearViewField
key={field.id} key={`${nodeId}-${fieldName}`}
nodeData={nodeData} nodeId={nodeId}
nodeTemplate={nodeTemplate} fieldName={fieldName}
field={field}
fieldTemplate={fieldTemplate}
/> />
)) ))
) : ( ) : (

View File

@ -25,7 +25,9 @@ const ClearGraphButton = () => {
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null); const cancelRef = useRef<HTMLButtonElement | null>(null);
const nodes = useAppSelector((state: RootState) => state.nodes.nodes); const nodesCount = useAppSelector(
(state: RootState) => state.nodes.nodes.length
);
const handleConfirmClear = useCallback(() => { const handleConfirmClear = useCallback(() => {
dispatch(nodeEditorReset()); dispatch(nodeEditorReset());
@ -49,7 +51,7 @@ const ClearGraphButton = () => {
tooltip={t('nodes.clearGraph')} tooltip={t('nodes.clearGraph')}
aria-label={t('nodes.clearGraph')} aria-label={t('nodes.clearGraph')}
onClick={onOpen} onClick={onOpen}
isDisabled={nodes.length === 0} isDisabled={!nodesCount}
/> />
<AlertDialog <AlertDialog

View File

@ -8,7 +8,7 @@ import IAIIconButton, {
import { selectIsReadyNodes } from 'features/nodes/store/selectors'; import { selectIsReadyNodes } from 'features/nodes/store/selectors';
import ProgressBar from 'features/system/components/ProgressBar'; import ProgressBar from 'features/system/components/ProgressBar';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa'; import { FaPlay } from 'react-icons/fa';
@ -18,7 +18,7 @@ interface InvokeButton
iconButton?: boolean; iconButton?: boolean;
} }
export default function NodeInvokeButton(props: InvokeButton) { const NodeInvokeButton = (props: InvokeButton) => {
const { iconButton = false, ...rest } = props; const { iconButton = false, ...rest } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
@ -92,4 +92,6 @@ export default function NodeInvokeButton(props: InvokeButton) {
</Box> </Box>
</Box> </Box>
); );
} };
export default memo(NodeInvokeButton);

View File

@ -1,11 +1,11 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaSyncAlt } from 'react-icons/fa'; import { FaSyncAlt } from 'react-icons/fa';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
export default function ReloadSchemaButton() { const ReloadSchemaButton = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -21,4 +21,6 @@ export default function ReloadSchemaButton() {
onClick={handleReloadSchema} onClick={handleReloadSchema}
/> />
); );
} };
export default memo(ReloadSchemaButton);

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector'; import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { InputFieldValue, OutputFieldValue } from 'features/nodes/types/types';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useFieldType } from './useNodeData';
const selectIsConnectionInProgress = createSelector( const selectIsConnectionInProgress = createSelector(
stateSelector, stateSelector,
@ -12,23 +12,19 @@ const selectIsConnectionInProgress = createSelector(
nodes.connectionStartParams !== null nodes.connectionStartParams !== null
); );
export type UseConnectionStateProps = export type UseConnectionStateProps = {
| { nodeId: string;
nodeId: string; fieldName: string;
field: InputFieldValue; kind: 'input' | 'output';
kind: 'input'; };
}
| {
nodeId: string;
field: OutputFieldValue;
kind: 'output';
};
export const useConnectionState = ({ export const useConnectionState = ({
nodeId, nodeId,
field, fieldName,
kind, kind,
}: UseConnectionStateProps) => { }: UseConnectionStateProps) => {
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo( const selectIsConnected = useMemo(
() => () =>
createSelector(stateSelector, ({ nodes }) => createSelector(stateSelector, ({ nodes }) =>
@ -37,23 +33,23 @@ export const useConnectionState = ({
return ( return (
(kind === 'input' ? edge.target : edge.source) === nodeId && (kind === 'input' ? edge.target : edge.source) === nodeId &&
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) === (kind === 'input' ? edge.targetHandle : edge.sourceHandle) ===
field.name fieldName
); );
}).length }).length
) )
), ),
[field.name, kind, nodeId] [fieldName, kind, nodeId]
); );
const selectConnectionError = useMemo( const selectConnectionError = useMemo(
() => () =>
makeConnectionErrorSelector( makeConnectionErrorSelector(
nodeId, nodeId,
field.name, fieldName,
kind === 'input' ? 'target' : 'source', kind === 'input' ? 'target' : 'source',
field.type fieldType
), ),
[nodeId, field.name, field.type, kind] [nodeId, fieldName, kind, fieldType]
); );
const selectIsConnectionStartField = useMemo( const selectIsConnectionStartField = useMemo(
@ -61,12 +57,12 @@ export const useConnectionState = ({
createSelector(stateSelector, ({ nodes }) => createSelector(stateSelector, ({ nodes }) =>
Boolean( Boolean(
nodes.connectionStartParams?.nodeId === nodeId && nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === field.name && nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleType === nodes.connectionStartParams?.handleType ===
{ input: 'target', output: 'source' }[kind] { input: 'target', output: 'source' }[kind]
) )
), ),
[field.name, kind, nodeId] [fieldName, kind, nodeId]
); );
const isConnected = useAppSelector(selectIsConnected); const isConnected = useAppSelector(selectIsConnected);

View File

@ -0,0 +1,286 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map, some } from 'lodash-es';
import { useMemo } from 'react';
import { FOOTER_FIELDS, IMAGE_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
const KIND_MAP = {
input: 'inputs' as const,
output: 'outputs' as const,
};
export const useNodeTemplate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};
export const useNodeData = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
return node?.data;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeData = useAppSelector(selector);
return nodeData;
};
export const useFieldData = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data.inputs[fieldName];
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const fieldData = useAppSelector(selector);
return fieldData;
};
export const useFieldType = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
return fieldType;
};
export const useFieldNames = (nodeId: string, kind: 'input' | 'output') => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return [];
}
return map(node.data[KIND_MAP[kind]], (field) => field.name).filter(
(fieldName) => fieldName !== 'is_intermediate'
);
},
defaultSelectorOptions
),
[kind, nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames;
};
export const useWithFooter = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
FOOTER_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const withFooter = useAppSelector(selector);
return withFooter;
};
export const useHasImageOutput = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const hasImageOutput = useAppSelector(selector);
return hasImageOutput;
};
export const useIsIntermediate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return Boolean(node.data.inputs.is_intermediate?.value);
},
defaultSelectorOptions
),
[nodeId]
);
const is_intermediate = useAppSelector(selector);
return is_intermediate;
};
export const useNodeLabel = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.label;
},
defaultSelectorOptions
),
[nodeId]
);
const label = useAppSelector(selector);
return label;
};
export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = node
? nodes.nodeTemplates[node.data.type]
: undefined;
return nodeTemplate?.title;
},
defaultSelectorOptions
),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};
export const useFieldTemplate = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return Boolean(node?.data.inputs[fieldName]?.value);
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const doesFieldHaveValue = useAppSelector(selector);
return doesFieldHaveValue;
};

View File

@ -9,9 +9,13 @@ export const makeConnectionErrorSelector = (
nodeId: string, nodeId: string,
fieldName: string, fieldName: string,
handleType: HandleType, handleType: HandleType,
fieldType: FieldType fieldType?: FieldType
) => ) =>
createSelector(stateSelector, (state) => { createSelector(stateSelector, (state) => {
if (!fieldType) {
return 'No field type';
}
const { currentConnectionFieldType, connectionStartParams, nodes, edges } = const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes; state.nodes;

View File

@ -6,6 +6,9 @@ export const NODE_WIDTH = 320;
export const NODE_MIN_WIDTH = 320; export const NODE_MIN_WIDTH = 320;
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
export const COLLECTION_TYPES: FieldType[] = [ export const COLLECTION_TYPES: FieldType[] = [
'Collection', 'Collection',
'IntegerCollection', 'IntegerCollection',

View File

@ -457,12 +457,13 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
}; };
export const isInputFieldValue = ( export const isInputFieldValue = (
field: InputFieldValue | OutputFieldValue field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => field.fieldKind === 'input'; ): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
export const isInputFieldTemplate = ( export const isInputFieldTemplate = (
fieldTemplate: InputFieldTemplate | OutputFieldTemplate fieldTemplate?: InputFieldTemplate | OutputFieldTemplate
): fieldTemplate is InputFieldTemplate => fieldTemplate.fieldKind === 'input'; ): fieldTemplate is InputFieldTemplate =>
Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input');
/** /**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES * JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
@ -632,20 +633,22 @@ export type NodeData =
export const isInvocationNode = ( export const isInvocationNode = (
node?: Node<NodeData> node?: Node<NodeData>
): node is Node<InvocationNodeData> => node?.type === 'invocation'; ): node is Node<InvocationNodeData> =>
Boolean(node && node.type === 'invocation');
export const isInvocationNodeData = ( export const isInvocationNodeData = (
node?: NodeData node?: NodeData
): node is InvocationNodeData => ): node is InvocationNodeData =>
!['notes', 'current_image'].includes(node?.type ?? ''); Boolean(node && !['notes', 'current_image'].includes(node.type));
export const isNotesNode = ( export const isNotesNode = (
node?: Node<NodeData> node?: Node<NodeData>
): node is Node<NotesNodeData> => node?.type === 'notes'; ): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
export const isProgressImageNode = ( export const isProgressImageNode = (
node?: Node<NodeData> node?: Node<NodeData>
): node is Node<CurrentImageNodeData> => node?.type === 'current_image'; ): node is Node<CurrentImageNodeData> =>
Boolean(node && node.type === 'current_image');
export enum NodeStatus { export enum NodeStatus {
PENDING, PENDING,

View File

@ -3,4 +3,7 @@ import { UIState } from './uiTypes';
/** /**
* UI slice persist denylist * UI slice persist denylist
*/ */
export const uiPersistDenylist: (keyof UIState)[] = ['shouldShowImageDetails']; export const uiPersistDenylist: (keyof UIState)[] = [
'shouldShowImageDetails',
'globalContextMenuCloseTrigger',
];

View File

@ -20,6 +20,7 @@ export const initialUIState: UIState = {
shouldShowProgressInViewer: true, shouldShowProgressInViewer: true,
shouldShowEmbeddingPicker: false, shouldShowEmbeddingPicker: false,
favoriteSchedulers: [], favoriteSchedulers: [],
globalContextMenuCloseTrigger: 0,
}; };
export const uiSlice = createSlice({ export const uiSlice = createSlice({
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
toggleEmbeddingPicker: (state) => { toggleEmbeddingPicker: (state) => {
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker; state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
}, },
contextMenusClosed: (state) => {
state.globalContextMenuCloseTrigger += 1;
},
}, },
extraReducers(builder) { extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => { builder.addCase(initialImageChanged, (state) => {
@ -122,6 +126,7 @@ export const {
setShouldShowProgressInViewer, setShouldShowProgressInViewer,
favoriteSchedulersChanged, favoriteSchedulersChanged,
toggleEmbeddingPicker, toggleEmbeddingPicker,
contextMenusClosed,
} = uiSlice.actions; } = uiSlice.actions;
export default uiSlice.reducer; export default uiSlice.reducer;

View File

@ -26,4 +26,5 @@ export interface UIState {
shouldShowProgressInViewer: boolean; shouldShowProgressInViewer: boolean;
shouldShowEmbeddingPicker: boolean; shouldShowEmbeddingPicker: boolean;
favoriteSchedulers: SchedulerParam[]; favoriteSchedulers: SchedulerParam[];
globalContextMenuCloseTrigger: number;
} }

View File

@ -573,7 +573,7 @@ export type components = {
file: Blob; file: Blob;
}; };
/** /**
* Boolean Collection * Boolean Primitive Collection
* @description A collection of boolean primitive values * @description A collection of boolean primitive values
*/ */
BooleanCollectionInvocation: { BooleanCollectionInvocation: {
@ -619,7 +619,7 @@ export type components = {
collection?: (boolean)[]; collection?: (boolean)[];
}; };
/** /**
* Boolean * Boolean Primitive
* @description A boolean primitive value * @description A boolean primitive value
*/ */
BooleanInvocation: { BooleanInvocation: {
@ -1002,7 +1002,7 @@ export type components = {
clip?: components["schemas"]["ClipField"]; clip?: components["schemas"]["ClipField"];
}; };
/** /**
* Conditioning Collection * Conditioning Primitive Collection
* @description A collection of conditioning tensor primitive values * @description A collection of conditioning tensor primitive values
*/ */
ConditioningCollectionInvocation: { ConditioningCollectionInvocation: {
@ -1770,7 +1770,7 @@ export type components = {
field: string; field: string;
}; };
/** /**
* Float Collection * Float Primitive Collection
* @description A collection of float primitive values * @description A collection of float primitive values
*/ */
FloatCollectionInvocation: { FloatCollectionInvocation: {
@ -1816,7 +1816,7 @@ export type components = {
collection?: (number)[]; collection?: (number)[];
}; };
/** /**
* Float * Float Primitive
* @description A float primitive value * @description A float primitive value
*/ */
FloatInvocation: { FloatInvocation: {
@ -2161,7 +2161,7 @@ export type components = {
channel?: "A" | "R" | "G" | "B"; channel?: "A" | "R" | "G" | "B";
}; };
/** /**
* Image Collection * Image Primitive Collection
* @description A collection of image primitive values * @description A collection of image primitive values
*/ */
ImageCollectionInvocation: { ImageCollectionInvocation: {
@ -3113,7 +3113,7 @@ export type components = {
seed?: number; seed?: number;
}; };
/** /**
* Integer Collection * Integer Primitive Collection
* @description A collection of integer primitive values * @description A collection of integer primitive values
*/ */
IntegerCollectionInvocation: { IntegerCollectionInvocation: {
@ -3159,7 +3159,7 @@ export type components = {
collection?: (number)[]; collection?: (number)[];
}; };
/** /**
* Integer * Integer Primitive
* @description An integer primitive value * @description An integer primitive value
*/ */
IntegerInvocation: { IntegerInvocation: {
@ -3256,7 +3256,7 @@ export type components = {
item?: unknown; item?: unknown;
}; };
/** /**
* Latents Collection * Latents Primitive Collection
* @description A collection of latents tensor primitive values * @description A collection of latents tensor primitive values
*/ */
LatentsCollectionInvocation: { LatentsCollectionInvocation: {
@ -5786,7 +5786,7 @@ export type components = {
show_easing_plot?: boolean; show_easing_plot?: boolean;
}; };
/** /**
* String Collection * String Primitive Collection
* @description A collection of string primitive values * @description A collection of string primitive values
*/ */
StringCollectionInvocation: { StringCollectionInvocation: {
@ -5832,7 +5832,7 @@ export type components = {
collection?: (string)[]; collection?: (string)[];
}; };
/** /**
* String * String Primitive
* @description A string primitive value * @description A string primitive value
*/ */
StringInvocation: { StringInvocation: {
@ -6193,24 +6193,6 @@ export type components = {
ui_hidden: boolean; ui_hidden: boolean;
ui_type?: components["schemas"]["UIType"]; ui_type?: components["schemas"]["UIType"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusionOnnxModelFormat * StableDiffusionOnnxModelFormat
* @description An enumeration. * @description An enumeration.
@ -6223,6 +6205,24 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;