feat: node editor

squashed rebase on main after backendd refactor
This commit is contained in:
psychedelicious
2023-08-14 13:23:09 +10:00
parent d6c9bf5b38
commit f49fc7fb55
188 changed files with 8541 additions and 4660 deletions

View File

@ -61,6 +61,7 @@
"@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
"@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.1",

View File

@ -0,0 +1,34 @@
export const COLORS = {
reset: '\x1b[0m',
bright: '\x1b[1m',
dim: '\x1b[2m',
underscore: '\x1b[4m',
blink: '\x1b[5m',
reverse: '\x1b[7m',
hidden: '\x1b[8m',
fg: {
black: '\x1b[30m',
red: '\x1b[31m',
green: '\x1b[32m',
yellow: '\x1b[33m',
blue: '\x1b[34m',
magenta: '\x1b[35m',
cyan: '\x1b[36m',
white: '\x1b[37m',
gray: '\x1b[90m',
crimson: '\x1b[38m',
},
bg: {
black: '\x1b[40m',
red: '\x1b[41m',
green: '\x1b[42m',
yellow: '\x1b[43m',
blue: '\x1b[44m',
magenta: '\x1b[45m',
cyan: '\x1b[46m',
white: '\x1b[47m',
gray: '\x1b[100m',
crimson: '\x1b[48m',
},
};

View File

@ -1,23 +1,83 @@
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';
import { COLORS } from './colors.js';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
async function main() {
process.stdout.write(
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
transform: (schemaObject) => {
transform: (schemaObject, metadata) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
/**
* Because invocations may have required fields that accept connection input, the generated
* types may be incorrect.
*
* For example, the ImageResizeInvocation has a required `image` field, but because it accepts
* connection input, it should be optional on instantiation of the field.
*
* To handle this, the schema exposes an `input` property that can be used to determine if the
* field accepts connection input. If it does, we can make the field optional.
*/
// Check if we are generating types for an invocation
const isInvocationPath = metadata.path.match(
/^#\/components\/schemas\/\w*Invocation$/
);
const hasInvocationProperties =
schemaObject.properties &&
['id', 'is_intermediate', 'type'].every(
(prop) => prop in schemaObject.properties
);
if (isInvocationPath && hasInvocationProperties) {
// We only want to make fields optional if they are required
if (!Array.isArray(schemaObject?.required)) {
schemaObject.required = ['id', 'type'];
return;
}
schemaObject.required.forEach((prop) => {
const acceptsConnection = ['any', 'connection'].includes(
schemaObject.properties?.[prop]?.['input']
);
if (acceptsConnection) {
// remove this prop from the required array
const invocationName = metadata.path.split('/').pop();
console.log(
`Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
);
schemaObject.required = schemaObject.required.filter(
(r) => r !== prop
);
}
});
schemaObject.required = [
...new Set(schemaObject.required.concat(['id', 'type'])),
];
return;
}
// if (
// 'input' in schemaObject &&
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
// ) {
// schemaObject.required = false;
// }
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(` OK!\r\n`);
process.stdout.write(`\nOK!\r\n`);
}
main();

View File

@ -1,8 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import {
ctrlKeyPressed,
metaKeyPressed,
shiftKeyPressed,
} from 'features/ui/store/hotkeysSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
@ -16,11 +20,11 @@ import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
(hotkeys, ui) => {
const { shift } = hotkeys;
[stateSelector],
({ hotkeys, ui }) => {
const { shift, ctrl, meta } = hotkeys;
const { shouldPinParametersPanel, shouldPinGallery } = ui;
return { shift, shouldPinGallery, shouldPinParametersPanel };
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
},
{
memoizeOptions: {
@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
*/
const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch();
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
globalHotkeysSelector
);
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
useAppSelector(globalHotkeysSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
useHotkeys(
@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
} else {
shift && dispatch(shiftKeyPressed(false));
}
if (isHotkeyPressed('ctrl')) {
!ctrl && dispatch(ctrlKeyPressed(true));
} else {
ctrl && dispatch(ctrlKeyPressed(false));
}
if (isHotkeyPressed('meta')) {
!meta && dispatch(metaKeyPressed(true));
} else {
meta && dispatch(metaKeyPressed(false));
}
},
{ keyup: true, keydown: true },
[shift]
[shift, ctrl, meta]
);
useHotkeys('o', () => {

View File

@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading';
import '../../i18n';
import ImageDndContext from './ImageDnd/ImageDndContext';
import AppDndContext from '../../features/dnd/components/AppDndContext';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -80,9 +80,9 @@ const InvokeAIUI = ({
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<ImageDndContext>
<AppDndContext>
<App config={config} headerComponent={headerComponent} />
</ImageDndContext>
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>
</Provider>

View File

@ -19,7 +19,8 @@ type LoggerNamespace =
| 'nodes'
| 'system'
| 'socketio'
| 'session';
| 'session'
| 'dnd';
export const logger = (namespace: LoggerNamespace) =>
$logger.get().child({ namespace });

View File

@ -15,7 +15,7 @@ export const actionsDenylist = [
'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
// every time user presses shift
'hotkeys/shiftKeyPressed',
// 'hotkeys/shiftKeyPressed',
// this happens after every state change
'@@REMEMBER_PERSISTED',
];

View File

@ -1,16 +1,20 @@
import { createAction } from '@reduxjs/toolkit';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
fieldImageValueChanged,
workflowExposedFieldAdded,
} from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../';
import { parseify } from 'common/util/serialize';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
startAppListening({
actionCreator: dndDropped,
effect: async (action, { dispatch }) => {
const log = logger('images');
const log = logger('dnd');
const { activeData, overData } = action.payload;
if (activeData.payloadType === 'IMAGE_DTO') {
@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
{ activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped`
);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug(
{ activeData: parseify(activeData), overData: parseify(overData) },
'Node field dropped'
);
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
if (
overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
activeData.payloadType === 'NODE_FIELD'
) {
const { nodeId, field } = activeData.payload;
dispatch(
workflowExposedFieldAdded({
nodeId,
fieldName: field.name,
})
);
}
/**
* Image dropped on current image
*/
@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
fieldImageValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,

View File

@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO }));
dispatch(
fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
);
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,

View File

@ -15,12 +15,21 @@ import {
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
import {
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
predicate: (state, action) =>
predicate: (
action
): action is TypeGuardFor<
typeof modelsApi.endpoints.getMainModels.matchFulfilled
> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().generation.model;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
if (models.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const result = zMainOrOnnxModel.safeParse(firstModel);
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
if (isCurrentModelAvailable) {
return;
}
const result = zMainOrOnnxModel.safeParse(models[0]);
if (!result.success) {
log.error(
@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
},
});
startAppListening({
predicate: (state, action) =>
predicate: (
action
): action is TypeGuardFor<
typeof modelsApi.endpoints.getMainModels.matchFulfilled
> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setShouldUseSDXLRefiner(false));
return;
}
const result = zSDXLRefinerModel.safeParse(firstModel);
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
if (isCurrentModelAvailable) {
return;
}
const result = zSDXLRefinerModel.safeParse(models[0]);
if (!result.success) {
log.error(

View File

@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
const log = logger('system');
const schemaJSON = action.payload;
log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema');
log.debug({ schemaJSON }, 'Received OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: () => {
effect: (action) => {
const log = logger('system');
log.error('Problem dereferencing OpenAPI Schema');
log.error(
{ error: parseify(action.error) },
'Problem retrieving OpenAPI Schema'
);
},
});
};

View File

@ -19,7 +19,7 @@ import {
} from 'services/events/actions';
import { startAppListening } from '../..';
const nodeDenylist = ['dataURL_image'];
const nodeDenylist = ['load_image'];
export const addInvocationCompleteEventListener = () => {
startAppListening({

View File

@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
const log = logger('session');
const state = getState();
const graph = buildNodesGraph(state);
const graph = buildNodesGraph(state.nodes);
dispatch(nodesGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Nodes graph built');

View File

@ -1,86 +1,7 @@
import {
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt';
// These are old types from the model management UI
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
// export type Model = {
// status: ModelStatus;
// description: string;
// weights: string;
// config?: string;
// vae?: string;
// width?: number;
// height?: number;
// default?: boolean;
// format?: string;
// };
// export type DiffusersModel = {
// status: ModelStatus;
// description: string;
// repo_id?: string;
// path?: string;
// vae?: {
// repo_id?: string;
// path?: string;
// };
// format?: string;
// default?: boolean;
// };
// export type ModelList = Record<string, Model & DiffusersModel>;
// export type FoundModel = {
// name: string;
// location: string;
// };
// export type InvokeModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// config: string | undefined;
// weights: string | undefined;
// vae: string | undefined;
// width: number | undefined;
// height: number | undefined;
// default: boolean | undefined;
// format: string | undefined;
// };
// export type InvokeDiffusersModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// repo_id: string | undefined;
// path: string | undefined;
// default: boolean | undefined;
// format: string | undefined;
// vae: {
// repo_id: string | undefined;
// path: string | undefined;
// };
// };
// export type InvokeModelConversionProps = {
// model_name: string;
// save_location: string;
// custom_location: string | null;
// };
// export type InvokeModelMergingProps = {
// models_to_merge: string[];
// alpha: number;
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
// force: boolean;
// merged_model_name: string;
// model_merge_save_path: string | null;
// };
/**
* A disable-able application feature
*/

View File

@ -6,10 +6,6 @@ import {
useColorMode,
useColorModeValue,
} from '@chakra-ui/react';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIIconButton from 'common/components/IAIIconButton';
import {
IAILoadingImageFallback,
@ -17,6 +13,10 @@ import {
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import {
MouseEvent,
@ -157,11 +157,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
<IAILoadingImageFallback image={imageDTO} />
)
}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
w: imageDTO.width,
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
@ -213,13 +212,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick}
/>
)}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
@ -244,6 +236,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}}
/>
)}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
</Flex>
)}
</ImageContextMenu>

View File

@ -1,22 +1,19 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDraggableData,
useDraggable,
} from 'app/components/ImageDnd/typesafeDnd';
import { MouseEvent, memo, useRef } from 'react';
import { Box, BoxProps } from '@chakra-ui/react';
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { TypesafeDraggableData } from 'features/dnd/types';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = {
type IAIDraggableProps = BoxProps & {
disabled?: boolean;
data?: TypesafeDraggableData;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
};
const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, onClick } = props;
const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggable({
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current,
disabled,
data,
@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
return (
<Box
onClick={onClick}
ref={setNodeRef}
position="absolute"
w="full"
@ -33,6 +29,7 @@ const IAIDraggable = (props: IAIDraggableProps) => {
insetInlineStart={0}
{...attributes}
{...listeners}
{...rest}
/>
);
};

View File

@ -1,9 +1,7 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDroppableData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({
const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,

View File

@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {
type IAINoImageFallbackProps = {
label?: string;
icon?: As;
icon?: As | null;
boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
};
@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
...props.sx,
}}
>
<Icon as={icon} boxSize={boxSize} opacity={0.7} />
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>}
</Flex>
);

View File

@ -1,10 +1,13 @@
import {
Flex,
FormControl,
FormControlProps,
FormHelperText,
FormLabel,
FormLabelProps,
Switch,
SwitchProps,
Text,
Tooltip,
} from '@chakra-ui/react';
import { memo } from 'react';
@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
formControlProps?: FormControlProps;
formLabelProps?: FormLabelProps;
tooltip?: string;
helperText?: string;
}
/**
@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
formControlProps,
formLabelProps,
tooltip,
helperText,
...rest
} = props;
return (
@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => {
<FormControl
isDisabled={isDisabled}
width={width}
display="flex"
alignItems="center"
{...formControlProps}
>
{label && (
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
pe: 4,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}
<Switch {...rest} />
<Flex sx={{ flexDir: 'column', w: 'full' }}>
<Flex sx={{ alignItems: 'center', w: 'full' }}>
{label && (
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
pe: 4,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}
<Switch {...rest} />
</Flex>
{helperText && (
<FormHelperText>
<Text variant="subtext">{helperText}</Text>
</FormHelperText>
)}
</Flex>
</FormControl>
</Tooltip>
);

View File

@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
accent850,
accent900,
accent950,
baseAlpha50,
baseAlpha100,
baseAlpha150,
baseAlpha200,
baseAlpha250,
baseAlpha300,
baseAlpha350,
baseAlpha400,
baseAlpha450,
baseAlpha500,
baseAlpha550,
baseAlpha600,
baseAlpha650,
baseAlpha700,
baseAlpha750,
baseAlpha800,
baseAlpha850,
baseAlpha900,
baseAlpha950,
accentAlpha50,
accentAlpha100,
accentAlpha150,
accentAlpha200,
accentAlpha250,
accentAlpha300,
accentAlpha350,
accentAlpha400,
accentAlpha450,
accentAlpha500,
accentAlpha550,
accentAlpha600,
accentAlpha650,
accentAlpha700,
accentAlpha750,
accentAlpha800,
accentAlpha850,
accentAlpha900,
accentAlpha950,
] = useToken('colors', [
'base.50',
'base.100',
@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
'accent.850',
'accent.900',
'accent.950',
'baseAlpha.50',
'baseAlpha.100',
'baseAlpha.150',
'baseAlpha.200',
'baseAlpha.250',
'baseAlpha.300',
'baseAlpha.350',
'baseAlpha.400',
'baseAlpha.450',
'baseAlpha.500',
'baseAlpha.550',
'baseAlpha.600',
'baseAlpha.650',
'baseAlpha.700',
'baseAlpha.750',
'baseAlpha.800',
'baseAlpha.850',
'baseAlpha.900',
'baseAlpha.950',
'accentAlpha.50',
'accentAlpha.100',
'accentAlpha.150',
'accentAlpha.200',
'accentAlpha.250',
'accentAlpha.300',
'accentAlpha.350',
'accentAlpha.400',
'accentAlpha.450',
'accentAlpha.500',
'accentAlpha.550',
'accentAlpha.600',
'accentAlpha.650',
'accentAlpha.700',
'accentAlpha.750',
'accentAlpha.800',
'accentAlpha.850',
'accentAlpha.900',
'accentAlpha.950',
]);
return {
@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
accent850,
accent900,
accent950,
baseAlpha50,
baseAlpha100,
baseAlpha150,
baseAlpha200,
baseAlpha250,
baseAlpha300,
baseAlpha350,
baseAlpha400,
baseAlpha450,
baseAlpha500,
baseAlpha550,
baseAlpha600,
baseAlpha650,
baseAlpha700,
baseAlpha750,
baseAlpha800,
baseAlpha850,
baseAlpha900,
baseAlpha950,
accentAlpha50,
accentAlpha100,
accentAlpha150,
accentAlpha200,
accentAlpha250,
accentAlpha300,
accentAlpha350,
accentAlpha400,
accentAlpha450,
accentAlpha500,
accentAlpha550,
accentAlpha600,
accentAlpha650,
accentAlpha700,
accentAlpha750,
accentAlpha800,
accentAlpha850,
accentAlpha900,
accentAlpha950,
};
};

View File

@ -1,4 +1,10 @@
/**
* Serialize an object to JSON and back to a new object
*/
export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj));
export const parseify = (obj: unknown) => {
try {
return JSON.parse(JSON.stringify(obj));
} catch {
return 'Error parsing object';
}
};

View File

@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
} from 'features/dnd/types';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';

View File

@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
export type RequiredControlNetProcessorNode =
export type RequiredControlNetProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation;
| RequiredZoeDepthImageProcessorInvocation,
'id'
>;
/**
* Type guard for CannyImageProcessorInvocation

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { isInvocationNode } from 'features/nodes/types/types';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
input.type === 'ImageField' && input.value?.image_name === image_name
);
});

View File

@ -6,23 +6,18 @@ import {
useSensor,
useSensors,
} from '@dnd-kit/core';
import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { logger } from 'app/logging/logger';
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo, useCallback, useState } from 'react';
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
import { DndContextTypesafe } from './DndContextTypesafe';
import DragPreview from './DragPreview';
import {
DndContext,
DragEndEvent,
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { logger } from 'app/logging/logger';
type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => {
const AppDndContext = (props: PropsWithChildren) => {
const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null);
const log = logger('images');
@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragStart = useCallback(
(event: DragStartEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag started');
log.trace(
{ dragData: parseify(event.active.data.current) },
'Drag started'
);
const activeData = event.active.data.current;
if (!activeData) {
return;
@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragEnd = useCallback(
(event: DragEndEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag ended');
log.trace(
{ dragData: parseify(event.active.data.current) },
'Drag ended'
);
const overData = event.over?.data.current;
if (!activeDragData || !overData) {
return;
@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const sensors = useSensors(mouseSensor, touchSensor);
const scaledModifier = useScaledModifer();
return (
<DndContext
<DndContextTypesafe
onDragStart={handleDragStart}
onDragEnd={handleDragEnd}
sensors={sensors}
collisionDetection={pointerWithin}
autoScroll={false}
>
{props.children}
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
<DragOverlay
dropAnimation={null}
modifiers={[scaledModifier]}
style={{
width: 'min-content',
height: 'min-content',
cursor: 'none',
userSelect: 'none',
// expand overlay to prevent cursor from going outside it and displaying
padding: '10rem',
}}
>
<AnimatePresence>
{activeDragData && (
<motion.div
@ -98,8 +113,8 @@ const ImageDndContext = (props: ImageDndContextProps) => {
)}
</AnimatePresence>
</DragOverlay>
</DndContext>
</DndContextTypesafe>
);
};
export default memo(ImageDndContext);
export default memo(AppDndContext);

View File

@ -0,0 +1,6 @@
import { DndContext } from '@dnd-kit/core';
import { DndContextTypesafeProps } from '../types';
export function DndContextTypesafe(props: DndContextTypesafeProps) {
return <DndContext {...props} />;
}

View File

@ -1,6 +1,6 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd';
import { TypesafeDraggableData } from '../types';
type OverlayDragImageProps = {
dragData: TypesafeDraggableData | null;
@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
return null;
}
if (props.dragData.payloadType === 'NODE_FIELD') {
const { field, fieldTemplate } = props.dragData.payload;
return (
<Box
sx={{
position: 'relative',
p: 2,
px: 3,
opacity: 0.7,
bg: 'base.300',
borderRadius: 'base',
boxShadow: 'dark-lg',
whiteSpace: 'nowrap',
fontSize: 'sm',
}}
>
<Text>{field.label || fieldTemplate.title}</Text>
</Box>
);
}
if (props.dragData.payloadType === 'IMAGE_DTO') {
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
return (
<Box
sx={{
position: 'relative',
width: '100%',
height: '100%',
width: 'full',
height: 'full',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'none',
}}
>
<Image
@ -62,8 +81,6 @@ const DragPreview = (props: OverlayDragImageProps) => {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',

View File

@ -0,0 +1,15 @@
import { useDraggable, useDroppable } from '@dnd-kit/core';
import {
UseDraggableTypesafeArguments,
UseDraggableTypesafeReturnValue,
UseDroppableTypesafeArguments,
UseDroppableTypesafeReturnValue,
} from '../types';
export function useDroppableTypesafe(props: UseDroppableTypesafeArguments) {
return useDroppable(props) as UseDroppableTypesafeReturnValue;
}
export function useDraggableTypesafe(props: UseDraggableTypesafeArguments) {
return useDraggable(props) as UseDraggableTypesafeReturnValue;
}

View File

@ -0,0 +1,50 @@
import type { Modifier } from '@dnd-kit/core';
import { getEventCoordinates } from '@dnd-kit/utilities';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
const selectZoom = createSelector(
[stateSelector, activeTabNameSelector],
({ nodes }, activeTabName) => (activeTabName === 'nodes' ? nodes.zoom : 1)
);
/**
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
*/
export const useScaledModifer = () => {
const zoom = useAppSelector(selectZoom);
const modifier: Modifier = useCallback(
({ activatorEvent, draggingNodeRect, transform }) => {
if (draggingNodeRect && activatorEvent) {
const activatorCoordinates = getEventCoordinates(activatorEvent);
if (!activatorCoordinates) {
return transform;
}
const offsetX = activatorCoordinates.x - draggingNodeRect.left;
const offsetY = activatorCoordinates.y - draggingNodeRect.top;
const x = transform.x + offsetX - draggingNodeRect.width / 2;
const y = transform.y + offsetY - draggingNodeRect.height / 2;
const scaleX = transform.scaleX * zoom;
const scaleY = transform.scaleY * zoom;
return {
x,
y,
scaleX,
scaleY,
};
}
return transform;
},
[zoom]
);
return modifier;
};

View File

@ -3,7 +3,6 @@ import {
Active,
Collision,
DndContextProps,
DndContext as OriginalDndContext,
Over,
Translate,
UseDraggableArguments,
@ -11,6 +10,10 @@ import {
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
export type AddFieldToLinearViewDropData = BaseDropData & {
actionType: 'ADD_FIELD_TO_LINEAR';
};
export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
@ -71,12 +78,22 @@ export type TypesafeDroppableData =
| AddToBatchDropData
| NodesMultiImageDropData
| AddToBoardDropData
| RemoveFromBoardDropData;
| RemoveFromBoardDropData
| AddFieldToLinearViewDropData;
type BaseDragData = {
id: string;
};
export type NodeFieldDraggableData = BaseDragData & {
payloadType: 'NODE_FIELD';
payload: {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
};
export type ImageDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTO';
payload: { imageDTO: ImageDTO };
@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
payload: { imageDTOs: ImageDTO[] };
};
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
export type TypesafeDraggableData =
| NodeFieldDraggableData
| ImageDraggableData
| ImageDTOsDraggableData;
interface UseDroppableTypesafeArguments
export interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
data?: TypesafeDroppableData;
}
type UseDroppableTypesafeReturnValue = Omit<
export type UseDroppableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDroppable>,
'active' | 'over'
> & {
@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
over: TypesafeOver | null;
};
export function useDroppable(props: UseDroppableTypesafeArguments) {
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
}
interface UseDraggableTypesafeArguments
export interface UseDraggableTypesafeArguments
extends Omit<UseDraggableArguments, 'data'> {
data?: TypesafeDraggableData;
}
type UseDraggableTypesafeReturnValue = Omit<
export type UseDraggableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDraggable>,
'active' | 'over'
> & {
@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
over: TypesafeOver | null;
};
export function useDraggable(props: UseDraggableTypesafeArguments) {
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
}
interface TypesafeActive extends Omit<Active, 'data'> {
export interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}
interface TypesafeOver extends Omit<Over, 'data'> {
export interface TypesafeOver extends Omit<Over, 'data'> {
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
}
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
default:
return false;
}
};
interface DragEvent {
activatorEvent: Event;
active: TypesafeActive;
@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
onDragEnd?(event: DragEndEvent): void;
onDragCancel?(event: DragCancelEvent): void;
}
export function DndContext(props: DndContextTypesafeProps) {
return <OriginalDndContext {...props} />;
}

View File

@ -0,0 +1,87 @@
import { TypesafeActive, TypesafeDroppableData } from '../types';
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'ADD_FIELD_TO_LINEAR':
return payloadType === 'NODE_FIELD';
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
default:
return false;
}
};

View File

@ -11,7 +11,6 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
import AutoAddIcon from '../AutoAddIcon';
import BoardContextMenu from '../BoardContextMenu';
import { AddToBoardDropData } from 'features/dnd/types';
interface GalleryBoardProps {
board: BoardDTO;

View File

@ -1,7 +1,7 @@
import { As, Badge, Flex } from '@chakra-ui/react';
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { TypesafeDroppableData } from 'features/dnd/types';
import { BoardId } from 'features/gallery/store/types';
import { ReactNode } from 'react';
import BoardContextMenu from '../BoardContextMenu';

View File

@ -1,15 +1,15 @@
import { Box, Flex, Image, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import InvokeAILogoImage from 'assets/images/logo.png';
import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { RemoveFromBoardDropData } from 'features/dnd/types';
import {
boardIdSelected,
autoAddBoardIdChanged,
boardIdSelected,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';

View File

@ -1,14 +1,14 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { AnimatePresence, motion } from 'framer-motion';

View File

@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
return (
<VStack
layerStyle="first"
sx={{
flexDirection: 'column',
h: 'full',
w: 'full',
borderRadius: 'base',
p: 2,
}}
>
<Box sx={{ w: 'full' }}>

View File

@ -1,9 +1,4 @@
import { Box, Flex } from '@chakra-ui/react';
import {
ImageDTOsDraggableData,
ImageDraggableData,
TypesafeDraggableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
@ -12,6 +7,11 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import {
ImageDTOsDraggableData,
ImageDraggableData,
TypesafeDraggableData,
} from 'features/dnd/types';
interface HoverableImageProps {
imageName: string;

View File

@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'leave',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},

View File

@ -1,26 +1,40 @@
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useMemo } from 'react';
import { FaCopy } from 'react-icons/fa';
import { useCallback, useMemo } from 'react';
import { FaCopy, FaSave } from 'react-icons/fa';
type Props = {
copyTooltip: string;
label: string;
jsonObject: object;
fileName?: string;
};
const ImageMetadataJSON = (props: Props) => {
const { copyTooltip, jsonObject } = props;
const { label, jsonObject, fileName } = props;
const jsonString = useMemo(
() => JSON.stringify(jsonObject, null, 2),
[jsonObject]
);
const handleCopy = useCallback(() => {
navigator.clipboard.writeText(jsonString);
}, [jsonString]);
const handleSave = useCallback(() => {
const blob = new Blob([jsonString]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
a.download = `${fileName || label}.json`;
document.body.appendChild(a);
a.click();
a.remove();
}, [jsonString, label, fileName]);
return (
<Flex
layerStyle="second"
sx={{
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
flexGrow: 1,
w: 'full',
h: 'full',
@ -36,6 +50,7 @@ const ImageMetadataJSON = (props: Props) => {
bottom: 0,
overflow: 'auto',
p: 4,
fontSize: 'sm',
}}
>
<OverlayScrollbarsComponent
@ -44,7 +59,7 @@ const ImageMetadataJSON = (props: Props) => {
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => {
</OverlayScrollbarsComponent>
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
<Tooltip label={copyTooltip}>
<Tooltip label={`Save ${label} JSON`}>
<IconButton
aria-label={copyTooltip}
aria-label={`Save ${label} JSON`}
icon={<FaSave />}
variant="ghost"
opacity={0.7}
onClick={handleSave}
/>
</Tooltip>
<Tooltip label={`Copy ${label} JSON`}>
<IconButton
aria-label={`Copy ${label} JSON`}
icon={<FaCopy />}
variant="ghost"
onClick={() => navigator.clipboard.writeText(jsonString)}
opacity={0.7}
onClick={handleCopy}
/>
</Tooltip>
</Flex>

View File

@ -10,7 +10,8 @@ import {
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { memo, useMemo } from 'react';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const metadata = currentData?.metadata;
const graph = currentData?.graph;
const tabData = useMemo(() => {
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
if (metadata) {
_tabData.push({
label: 'Core Metadata',
data: metadata,
copyTooltip: 'Copy Core Metadata JSON',
});
}
if (image) {
_tabData.push({
label: 'Image Details',
data: image,
copyTooltip: 'Copy Image Details JSON',
});
}
if (graph) {
_tabData.push({
label: 'Graph',
data: graph,
copyTooltip: 'Copy Graph JSON',
});
}
return _tabData;
}, [metadata, graph, image]);
return (
<Flex
layerStyle="first"
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'baseAlpha.200',
_dark: {
bg: 'blackAlpha.600',
},
borderRadius: 'base',
position: 'absolute',
overflow: 'hidden',
@ -103,32 +71,33 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
{tabData.map((tab) => (
<Tab
key={tab.label}
sx={{
borderTopRadius: 'base',
}}
>
<Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
{tab.label}
</Text>
</Tab>
))}
<Tab>Core Metadata</Tab>
<Tab>Image Details</Tab>
<Tab>Graph</Tab>
</TabList>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabData.map((tab) => (
<TabPanel
key={tab.label}
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
>
<ImageMetadataJSON
jsonObject={tab.data}
copyTooltip={tab.copyTooltip}
/>
</TabPanel>
))}
<TabPanels>
<TabPanel>
{metadata ? (
<ImageMetadataJSON jsonObject={metadata} label="Core Metadata" />
) : (
<IAINoContentFallback label="No core metadata found" />
)}
</TabPanel>
<TabPanel>
{image ? (
<ImageMetadataJSON jsonObject={image} label="Image Details" />
) : (
<IAINoContentFallback label="No image details found" />
)}
</TabPanel>
<TabPanel>
{graph ? (
<ImageMetadataJSON jsonObject={graph} label="Graph" />
) : (
<IAINoContentFallback label="No graph found" />
)}
</TabPanel>
</TabPanels>
</Tabs>
</Flex>

View File

@ -9,30 +9,40 @@ import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { useBuildNodeData } from '../hooks/useBuildNodeData';
import { nodeAdded } from '../store/nodesSlice';
type NodeTemplate = {
label: string;
value: string;
description: string;
tags: string[];
};
const selector = createSelector(
[stateSelector],
({ nodes }) => {
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
tags: template.tags,
};
});
data.push({
label: 'Progress Image',
value: 'progress_image',
description: 'Displays the progress image in the Node Editor',
value: 'current_image',
description: 'Displays the current image in the Node Editor',
tags: ['progress'],
});
data.push({
label: 'Notes',
value: 'notes',
description: 'Add notes about your workflow',
tags: ['notes'],
});
return { data };
@ -44,7 +54,7 @@ const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const buildInvocation = useBuildInvocation();
const buildInvocation = useBuildNodeData();
const toaster = useAppToaster();
@ -89,11 +99,12 @@ const AddNodeMenu = () => {
filter={(value, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.description.toLowerCase().includes(value.toLowerCase().trim())
item.description.toLowerCase().includes(value.toLowerCase().trim()) ||
item.tags.includes(value.toLowerCase().trim())
}
onChange={handleChange}
sx={{
width: '18rem',
width: '24rem',
}}
/>
</Flex>

View File

@ -0,0 +1,61 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { FIELDS, colorTokenToCssVar } from '../types/constants';
const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
nodes;
const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
: colorTokenToCssVar('base.500');
let className = 'react-flow__custom_connection-path';
if (shouldAnimateEdges) {
className = className.concat(' animated');
}
return {
stroke,
className,
};
});
export const CustomConnectionLine = ({
fromX,
fromY,
fromPosition,
toX,
toY,
toPosition,
}: ConnectionLineComponentProps) => {
const { stroke, className } = useAppSelector(selector);
const pathParams = {
sourceX: fromX,
sourceY: fromY,
sourcePosition: fromPosition,
targetX: toX,
targetY: toY,
targetPosition: toPosition,
};
const [dAttr] = getBezierPath(pathParams);
return (
<g>
<path
fill="none"
stroke={stroke}
strokeWidth={2}
className={className}
d={dAttr}
style={{ opacity: 0.8 }}
/>
</g>
);
};

View File

@ -0,0 +1,183 @@
import { Badge, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useMemo } from 'react';
import {
BaseEdge,
EdgeLabelRenderer,
EdgeProps,
getBezierPath,
} from 'reactflow';
import { FIELDS, colorTokenToCssVar } from '../types/constants';
import { isInvocationNode } from '../types/types';
const makeEdgeSelector = (
source: string,
sourceHandleId: string | null | undefined,
target: string,
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createSelector(stateSelector, ({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
const sourceType = isInvocationToInvocationEdge
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
const CollapsedEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
);
const { isSelected, shouldAnimate } = useAppSelector(selector);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
const { base500 } = useChakraThemeTokens();
return (
<>
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
: undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
const DefaultEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
);
};
export const edgeTypes = {
collapsed: CollapsedEdge,
default: DefaultEdge,
};

View File

@ -0,0 +1,9 @@
import InvocationNode from './nodes/InvocationNode';
import CurrentImageNode from './nodes/CurrentImageNode';
import NotesNode from './nodes/NotesNode';
export const nodeTypes = {
invocation: InvocationNode,
current_image: CurrentImageNode,
notes: NotesNode,
};

View File

@ -1,64 +0,0 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo } from 'react';
import { Handle, Position, Connection, HandleType } from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
};
const inputHandleStyles: CSSProperties = {
left: '-1rem',
};
const outputHandleStyles: CSSProperties = {
right: '-0.5rem',
};
// const requiredConnectionStyles: CSSProperties = {
// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
// };
type FieldHandleProps = {
nodeId: string;
field: InputFieldTemplate | OutputFieldTemplate;
isValidConnection: (connection: Connection) => boolean;
handleType: HandleType;
styles?: CSSProperties;
};
const FieldHandle = (props: FieldHandleProps) => {
const { field, isValidConnection, handleType, styles } = props;
const { name, type } = field;
return (
<Tooltip
label={type}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
isValidConnection={isValidConnection}
position={handleType === 'target' ? Position.Left : Position.Right}
style={{
backgroundColor: FIELDS[type].colorCssVar,
...styles,
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>
</Tooltip>
);
};
export default memo(FieldHandle);

View File

@ -1,8 +1,8 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, Flex } from '@chakra-ui/react';
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
import { map } from 'lodash-es';
import { FIELDS } from '../types/constants';
import { memo } from 'react';
import 'reactflow/dist/style.css';
import { FIELDS } from '../types/constants';
const FieldTypeLegend = () => {
return (
@ -10,8 +10,14 @@ const FieldTypeLegend = () => {
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge
colorScheme={color}
sx={{ userSelect: 'none' }}
sx={{
userSelect: 'none',
color:
parseInt(color.split('.')[1] ?? '0', 10) < 500
? 'base.800'
: 'base.50',
bg: color,
}}
textAlign="center"
>
{title}

View File

@ -1,4 +1,3 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import {
@ -7,35 +6,49 @@ import {
OnConnectEnd,
OnConnectStart,
OnEdgesChange,
OnEdgesDelete,
OnInit,
OnMove,
OnNodesChange,
OnNodesDelete,
OnSelectionChangeFunc,
ProOptions,
ReactFlow,
} from 'reactflow';
import { useIsValidConnection } from '../hooks/useIsValidConnection';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
edgesDeleted,
nodesChanged,
setEditorInstance,
nodesDeleted,
selectedEdgesChanged,
selectedNodesChanged,
zoomChanged,
} from '../store/nodesSlice';
import { InvocationComponent } from './InvocationComponent';
import ProgressImageNode from './ProgressImageNode';
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
import MinimapPanel from './panels/MinimapPanel';
import TopCenterPanel from './panels/TopCenterPanel';
import TopLeftPanel from './panels/TopLeftPanel';
import TopRightPanel from './panels/TopRightPanel';
import { CustomConnectionLine } from './CustomConnectionLine';
import { edgeTypes } from './CustomEdges';
import { nodeTypes } from './CustomNodes';
import BottomLeftPanel from './editorPanels/BottomLeftPanel';
import MinimapPanel from './editorPanels/MinimapPanel';
import TopCenterPanel from './editorPanels/TopCenterPanel';
import TopLeftPanel from './editorPanels/TopLeftPanel';
import TopRightPanel from './editorPanels/TopRightPanel';
const nodeTypes = {
invocation: InvocationComponent,
progress_image: ProgressImageNode,
};
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
const proOptions: ProOptions = { hideAttribution: true };
export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const nodes = useAppSelector((state) => state.nodes.nodes);
const edges = useAppSelector((state) => state.nodes.edges);
const shouldSnapToGrid = useAppSelector(
(state) => state.nodes.shouldSnapToGrid
);
const isValidConnection = useIsValidConnection();
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
@ -69,10 +82,36 @@ export const Flow = () => {
dispatch(connectionEnded());
}, [dispatch]);
const onInit: OnInit = useCallback(
(v) => {
dispatch(setEditorInstance(v));
if (v) v.fitView();
const onInit: OnInit = useCallback((v) => {
v.fitView();
}, []);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
},
[dispatch]
);
const onNodesDelete: OnNodesDelete = useCallback(
(nodes) => {
dispatch(nodesDeleted(nodes));
},
[dispatch]
);
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
({ nodes, edges }) => {
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
},
[dispatch]
);
const handleMove: OnMove = useCallback(
(e, viewport) => {
const { zoom } = viewport;
dispatch(zoomChanged(zoom));
},
[dispatch]
);
@ -80,24 +119,33 @@ export const Flow = () => {
return (
<ReactFlow
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
onMove={handleMove}
connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange}
onInit={onInit}
defaultEdgeOptions={{
style: { strokeWidth: 2 },
}}
isValidConnection={isValidConnection}
minZoom={0.2}
snapToGrid={shouldSnapToGrid}
snapGrid={[25, 25]}
connectionRadius={30}
proOptions={proOptions}
>
<TopLeftPanel />
<TopCenterPanel />
<TopRightPanel />
<BottomLeftPanel />
<Background />
<MinimapPanel />
<Background />
</ReactFlow>
);
};

View File

@ -1,55 +0,0 @@
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
import { memo } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
interface IAINodeHeaderProps {
nodeId?: string;
title?: string;
description?: string;
}
const IAINodeHeader = (props: IAINodeHeaderProps) => {
const { nodeId, title, description } = props;
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{
borderTopRadius: 'md',
alignItems: 'center',
justifyContent: 'space-between',
px: 2,
py: 1,
bg: 'base.100',
_dark: { bg: 'base.900' },
}}
>
<Tooltip label={nodeId}>
<Heading
size="xs"
sx={{
fontWeight: 600,
color: 'base.900',
_dark: { color: 'base.200' },
}}
>
{title}
</Heading>
</Tooltip>
<Tooltip label={description} placement="top" hasArrow shouldWrapChildren>
<Icon
sx={{
h: 'min-content',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
as={FaInfoCircle}
/>
</Tooltip>
</Flex>
);
};
export default memo(IAINodeHeader);

View File

@ -1,149 +0,0 @@
import {
Box,
Divider,
Flex,
FormControl,
FormLabel,
HStack,
Tooltip,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { ReactNode, memo, useCallback } from 'react';
import FieldHandle from '../FieldHandle';
import InputFieldComponent from '../InputFieldComponent';
interface IAINodeInputProps {
nodeId: string;
input: InputFieldValue;
template?: InputFieldTemplate | undefined;
connected: boolean;
}
function IAINodeInput(props: IAINodeInputProps) {
const { nodeId, input, template, connected } = props;
const isValidConnection = useIsValidConnection();
return (
<Box
className="nopan"
position="relative"
borderColor={
!template
? 'error.400'
: !connected &&
['always', 'connectionOnly'].includes(
String(template?.inputRequirement)
) &&
input.value === undefined
? 'warning.400'
: undefined
}
>
<FormControl isDisabled={!template ? true : connected} pl={2}>
{!template ? (
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>Unknown input: {input.name}</FormLabel>
</HStack>
) : (
<>
<HStack justifyContent="space-between" alignItems="center">
<HStack>
<Tooltip
label={template?.description}
placement="top"
hasArrow
shouldWrapChildren
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<FormLabel>{template?.title}</FormLabel>
</Tooltip>
</HStack>
<InputFieldComponent
nodeId={nodeId}
field={input}
template={template}
/>
</HStack>
{!['never', 'directOnly'].includes(
template?.inputRequirement ?? ''
) && (
<FieldHandle
nodeId={nodeId}
field={template}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</>
)}
</FormControl>
</Box>
);
}
interface IAINodeInputsProps {
nodeId: string;
template: InvocationTemplate;
inputs: Record<string, InputFieldValue>;
}
const IAINodeInputs = (props: IAINodeInputsProps) => {
const { nodeId, template, inputs } = props;
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const renderIAINodeInputs = useCallback(() => {
const IAINodeInputsToRender: ReactNode[] = [];
const inputSockets = map(inputs);
inputSockets.forEach((inputSocket, index) => {
const inputTemplate = template.inputs[inputSocket.name];
const isConnected = Boolean(
edges.filter((connectedInput) => {
return (
connectedInput.target === nodeId &&
connectedInput.targetHandle === inputSocket.name
);
}).length
);
if (index < inputSockets.length) {
IAINodeInputsToRender.push(
<Divider key={`${inputSocket.id}.divider`} />
);
}
IAINodeInputsToRender.push(
<IAINodeInput
key={inputSocket.id}
nodeId={nodeId}
input={inputSocket}
template={inputTemplate}
connected={isConnected}
/>
);
});
return (
<Flex className="nopan" flexDir="column" gap={2} p={2}>
{IAINodeInputsToRender}
</Flex>
);
}, [edges, inputs, nodeId, template.inputs]);
return renderIAINodeInputs();
};
export default memo(IAINodeInputs);

View File

@ -1,97 +0,0 @@
import {
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { memo, ReactNode, useCallback } from 'react';
import { map } from 'lodash-es';
import { useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
import FieldHandle from '../FieldHandle';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
interface IAINodeOutputProps {
nodeId: string;
output: OutputFieldValue;
template?: OutputFieldTemplate | undefined;
connected: boolean;
}
function IAINodeOutput(props: IAINodeOutputProps) {
const { nodeId, output, template, connected } = props;
const isValidConnection = useIsValidConnection();
return (
<Box position="relative">
<FormControl isDisabled={!template ? true : connected} paddingRight={3}>
{!template ? (
<HStack justifyContent="space-between" alignItems="center">
<FormLabel color="error.400">
Unknown Output: {output.name}
</FormLabel>
</HStack>
) : (
<>
<FormLabel textAlign="end" padding={1}>
{template?.title}
</FormLabel>
<FieldHandle
key={output.id}
nodeId={nodeId}
field={template}
isValidConnection={isValidConnection}
handleType="source"
/>
</>
)}
</FormControl>
</Box>
);
}
interface IAINodeOutputsProps {
nodeId: string;
template: InvocationTemplate;
outputs: Record<string, OutputFieldValue>;
}
const IAINodeOutputs = (props: IAINodeOutputsProps) => {
const { nodeId, template, outputs } = props;
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const renderIAINodeOutputs = useCallback(() => {
const IAINodeOutputsToRender: ReactNode[] = [];
const outputSockets = map(outputs);
outputSockets.forEach((outputSocket) => {
const outputTemplate = template.outputs[outputSocket.name];
const isConnected = Boolean(
edges.filter((connectedInput) => {
return (
connectedInput.source === nodeId &&
connectedInput.sourceHandle === outputSocket.name
);
}).length
);
IAINodeOutputsToRender.push(
<IAINodeOutput
key={outputSocket.id}
nodeId={nodeId}
output={outputSocket}
template={outputTemplate}
connected={isConnected}
/>
);
});
return <Flex flexDir="column">{IAINodeOutputsToRender}</Flex>;
}, [edges, nodeId, outputs, template.outputs]);
return renderIAINodeOutputs();
};
export default memo(IAINodeOutputs);

View File

@ -1,252 +0,0 @@
import { Box } from '@chakra-ui/react';
import { memo } from 'react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
field: InputFieldValue;
template: InputFieldTemplate;
};
// build an individual input element based on the schema
const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field, template } = props;
const { type } = field;
if (type === 'string' && template.type === 'string') {
return (
<StringInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'boolean' && template.type === 'boolean') {
return (
<BooleanInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (
(type === 'integer' && template.type === 'integer') ||
(type === 'float' && template.type === 'float')
) {
return (
<NumberInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'enum' && template.type === 'enum') {
return (
<EnumInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'image' && template.type === 'image') {
return (
<ImageInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'latents' && template.type === 'latents') {
return (
<LatentsInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'conditioning' && template.type === 'conditioning') {
return (
<ConditioningInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'unet' && template.type === 'unet') {
return (
<UnetInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'clip' && template.type === 'clip') {
return (
<ClipInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae' && template.type === 'vae') {
return (
<VaeInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'control' && template.type === 'control') {
return (
<ControlInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') {
return (
<ModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'refiner_model' && template.type === 'refiner_model') {
return (
<RefinerModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'lora_model' && template.type === 'lora_model') {
return (
<LoRAModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'item' && template.type === 'item') {
return (
<ItemInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'color' && template.type === 'color') {
return (
<ColorInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'item' && template.type === 'item') {
return (
<ItemInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'image_collection' && template.type === 'image_collection') {
return (
<ImageCollectionInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};
export default memo(InputFieldComponent);

View File

@ -0,0 +1,57 @@
import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react';
import { NodeProps, useUpdateNodeInternals } from 'reactflow';
interface Props {
nodeProps: NodeProps<NodeData>;
}
const NodeCollapseButton = (props: Props) => {
const { id: nodeId, isOpen } = props.nodeProps.data;
const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals();
const handleClick = useCallback(() => {
dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen }));
updateNodeInternals(nodeId);
}, [dispatch, isOpen, nodeId, updateNodeInternals]);
return (
<IAIIconButton
className="nopan"
onClick={handleClick}
aria-label="Minimize"
sx={{
minW: 8,
w: 8,
h: 8,
color: 'base.500',
_dark: {
color: 'base.500',
},
_hover: {
color: 'base.700',
_dark: {
color: 'base.300',
},
},
}}
variant="link"
icon={
<ChevronUpIcon
sx={{
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
}
/>
);
};
export default memo(NodeCollapseButton);

View File

@ -0,0 +1,74 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
}
const NodeCollapsedHandles = (props: Props) => {
const { data } = props.nodeProps;
const { base400, base600 } = useChakraThemeTokens();
const backgroundColor = useColorModeValue(base400, base600);
const dummyHandleStyles: CSSProperties = useMemo(
() => ({
borderWidth: 0,
borderRadius: '3px',
width: '1rem',
height: '1rem',
backgroundColor,
zIndex: -1,
}),
[backgroundColor]
);
return (
<>
<Handle
type="target"
id={`${data.id}-collapsed-target`}
isConnectable={false}
position={Position.Left}
style={{ ...dummyHandleStyles, left: '-0.5rem' }}
/>
{map(data.inputs, (input) => (
<Handle
key={`${data.id}-${input.name}-collapsed-input-handle`}
type="target"
id={input.name}
isValidConnection={() => false}
position={Position.Left}
style={{ visibility: 'hidden' }}
/>
))}
<Handle
type="source"
id={`${data.id}-collapsed-source`}
isValidConnection={() => false}
isConnectable={false}
position={Position.Right}
style={{ ...dummyHandleStyles, right: '-0.5rem' }}
/>
{map(data.outputs, (output) => (
<Handle
key={`${data.id}-${output.name}-collapsed-output-handle`}
type="source"
id={output.name}
isValidConnection={() => false}
position={Position.Right}
style={{ visibility: 'hidden' }}
/>
))}
</>
);
};
export default memo(NodeCollapsedHandles);

View File

@ -0,0 +1,77 @@
import {
Checkbox,
Flex,
FormControl,
FormLabel,
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
};
const NodeFooter = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const dispatch = useAppDispatch();
const hasImageOutput = useMemo(
() =>
some(nodeTemplate?.outputs, (output) =>
['ImageField', 'ImageCollection'].includes(output.type)
),
[nodeTemplate?.outputs]
);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: nodeProps.data.id,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeProps.data.id]
);
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeFooter"
sx={{
w: 'full',
borderBottomRadius: 'base',
px: 2,
py: 0,
h: 6,
}}
>
<Spacer />
{hasImageOutput && (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
/>
</FormControl>
)}
</Flex>
);
};
export default memo(NodeFooter);

View File

@ -0,0 +1,113 @@
import {
Flex,
FormControl,
FormLabel,
Icon,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
}
const NodeNotesEdit = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const { data } = nodeProps;
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
},
[data.id, dispatch]
);
return (
<>
<Tooltip
label={
nodeTemplate ? (
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
) : undefined
}
placement="top"
shouldWrapChildren
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
onClick={onOpen}
sx={{
alignItems: 'center',
justifyContent: 'center',
w: 8,
h: 8,
cursor: 'pointer',
}}
>
<Icon
as={FaInfoCircle}
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
/>
</Flex>
</Tooltip>
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>
{data.label || nodeTemplate?.title || 'Unknown Node'}
</ModalHeader>
<ModalCloseButton />
<ModalBody>
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
};
export default memo(NodeNotesEdit);
type TooltipContentProps = Props;
const TooltipContent = (props: TooltipContentProps) => {
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{props.nodeTemplate?.description}
</Text>
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
</Flex>
);
};

View File

@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants';
import { memo } from 'react';
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
const IAINodeResizer = (props: NodeResizerProps) => {
// this causes https://github.com/invoke-ai/InvokeAI/issues/4140
// not using it for now
const NodeResizer = (props: NodeResizerProps) => {
const { ...rest } = props;
return (
<NodeResizeControl
@ -21,4 +24,4 @@ const IAINodeResizer = (props: NodeResizerProps) => {
);
};
export default memo(IAINodeResizer);
export default memo(NodeResizer);

View File

@ -0,0 +1,69 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import IAISwitch from 'common/components/IAISwitch';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { InvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaBars } from 'react-icons/fa';
interface Props {
data: InvocationNodeData;
}
const NodeSettings = (props: Props) => {
const { data } = props;
const dispatch = useAppDispatch();
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: data.id,
fieldName: 'is_intermediate',
value: e.target.checked,
})
);
},
[data.id, dispatch]
);
return (
<IAIPopover
isLazy={false}
triggerComponent={
<IAIIconButton
className="nopan"
aria-label="Node Settings"
variant="link"
sx={{
minW: 8,
color: 'base.500',
_dark: {
color: 'base.500',
},
_hover: {
color: 'base.700',
_dark: {
color: 'base.300',
},
},
}}
icon={<FaBars />}
/>
}
>
<Flex sx={{ flexDir: 'column', gap: 4, w: 64 }}>
<IAISwitch
label="Intermediate"
isChecked={Boolean(data.inputs['is_intermediate']?.value)}
onChange={handleChangeIsIntermediate}
helperText="The outputs of intermediate nodes are considered temporary objects. Intermediate images are not added to the gallery."
/>
</Flex>
</IAIPopover>
);
};
export default memo(NodeSettings);

View File

@ -0,0 +1,185 @@
import {
Badge,
CircularProgress,
Flex,
Icon,
Image,
Text,
Tooltip,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
NodeExecutionState,
NodeStatus,
} from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
};
const iconBoxSize = 3;
const circleStyles = {
circle: {
transitionProperty: 'none',
transitionDuration: '0s',
},
'.chakra-progress__track': { stroke: 'transparent' },
};
const NodeStatusIndicator = (props: Props) => {
const nodeId = props.nodeProps.data.id;
const selectNodeExecutionState = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => nodes.nodeExecutionStates[nodeId]
),
[nodeId]
);
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
if (!nodeExecutionState) {
return null;
}
return (
<Tooltip
label={<TooltipLabel nodeExecutionState={nodeExecutionState} />}
placement="top"
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{
w: 5,
h: 'full',
alignItems: 'center',
justifyContent: 'flex-end',
}}
>
<StatusIcon nodeExecutionState={nodeExecutionState} />
</Flex>
</Tooltip>
);
};
export default memo(NodeStatusIndicator);
type TooltipLabelProps = {
nodeExecutionState: NodeExecutionState;
};
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState;
if (status === NodeStatus.PENDING) {
return <Text>Pending</Text>;
}
if (status === NodeStatus.IN_PROGRESS) {
if (progressImage) {
return (
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
<Image
src={progressImage.dataURL}
sx={{ w: 32, h: 32, borderRadius: 'base', objectFit: 'contain' }}
/>
{progress !== null && (
<Badge
variant="solid"
sx={{ pos: 'absolute', top: 2.5, insetInlineEnd: 1 }}
>
{Math.round(progress * 100)}%
</Badge>
)}
</Flex>
);
}
if (progress !== null) {
return <Text>In Progress ({Math.round(progress * 100)}%)</Text>;
}
return <Text>In Progress</Text>;
}
if (status === NodeStatus.COMPLETED) {
return <Text>Completed</Text>;
}
if (status === NodeStatus.FAILED) {
return <Text>nodeExecutionState.error</Text>;
}
return null;
};
type StatusIconProps = {
nodeExecutionState: NodeExecutionState;
};
const StatusIcon = (props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) {
return (
<Icon
as={FaEllipsisH}
sx={{
boxSize: iconBoxSize,
color: 'base.600',
_dark: { color: 'base.300' },
}}
/>
);
}
if (status === NodeStatus.IN_PROGRESS) {
return progress === null ? (
<CircularProgress
isIndeterminate
size="14px"
color="base.500"
thickness={14}
sx={circleStyles}
/>
) : (
<CircularProgress
value={Math.round(progress * 100)}
size="14px"
color="base.500"
thickness={14}
sx={circleStyles}
/>
);
}
if (status === NodeStatus.COMPLETED) {
return (
<Icon
as={FaCheck}
sx={{
boxSize: iconBoxSize,
color: 'ok.600',
_dark: { color: 'ok.300' },
}}
/>
);
}
if (status === NodeStatus.FAILED) {
return (
<Icon
as={FaExclamation}
sx={{
boxSize: iconBoxSize,
color: 'error.600',
_dark: { color: 'error.300' },
}}
/>
);
}
return null;
};

View File

@ -0,0 +1,123 @@
import {
Box,
Editable,
EditableInput,
EditablePreview,
Flex,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeData } from 'features/nodes/types/types';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
type Props = {
nodeData: NodeData;
title: string;
};
const NodeTitle = (props: Props) => {
const { title } = props;
const { id: nodeId, label } = props.nodeData;
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title);
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title);
},
[nodeId, dispatch, title]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || title);
}, [label, title]);
return (
<Flex
className="nopan"
sx={{
overflow: 'hidden',
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
cursor: 'text',
}}
>
<Editable
as={Flex}
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
sx={{
alignItems: 'center',
position: 'relative',
w: 'full',
h: 'full',
}}
>
<EditablePreview
fontSize="sm"
sx={{
p: 0,
w: 'full',
}}
noOfLines={1}
/>
<EditableInput
fontSize="sm"
sx={{
p: 0,
_focusVisible: {
p: 0,
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
</Flex>
);
};
export default memo(NodeTitle);
function EditableControls() {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Box
className={DRAG_HANDLE_CLASSNAME}
onDoubleClick={handleDoubleClick}
sx={{
position: 'absolute',
w: 'full',
h: 'full',
top: 0,
cursor: 'grab',
}}
/>
);
}

View File

@ -0,0 +1,96 @@
import {
Box,
ChakraProps,
useColorModeValue,
useToken,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeClicked } from 'features/nodes/store/nodesSlice';
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
import { NodeData } from 'features/nodes/types/types';
import { NodeProps } from 'reactflow';
const useNodeSelect = (nodeId: string) => {
const dispatch = useAppDispatch();
const selectNode = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey }));
},
[dispatch, nodeId]
);
return selectNode;
};
type NodeWrapperProps = PropsWithChildren & {
nodeProps: NodeProps<NodeData>;
width?: NonNullable<ChakraProps['sx']>['w'];
};
const NodeWrapper = (props: NodeWrapperProps) => {
const { width, children, nodeProps } = props;
const { data, selected } = nodeProps;
const nodeId = data.id;
const [
nodeSelectedOutlineLight,
nodeSelectedOutlineDark,
shadowsXl,
shadowsBase,
] = useToken('shadows', [
'nodeSelectedOutline.light',
'nodeSelectedOutline.dark',
'shadows.xl',
'shadows.base',
]);
const selectNode = useNodeSelect(nodeId);
const shadow = useColorModeValue(
nodeSelectedOutlineLight,
nodeSelectedOutlineDark
);
const shift = useAppSelector((state) => state.hotkeys.shift);
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
const className = useMemo(
() => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'),
[shift]
);
return (
<Box
onClickCapture={selectNode}
className={className}
sx={{
h: 'full',
position: 'relative',
borderRadius: 'base',
w: width ?? NODE_WIDTH,
transitionProperty: 'common',
transitionDuration: '0.1s',
shadow: selected ? shadow : undefined,
opacity,
}}
>
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
pointerEvents: 'none',
shadow: `${shadowsXl}, ${shadowsBase}, ${shadowsBase}`,
zIndex: -1,
}}
/>
{children}
</Box>
);
};
export default NodeWrapper;

View File

@ -1,74 +0,0 @@
import { Flex, Icon } from '@chakra-ui/react';
import { FaExclamationCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
import { InvocationValue } from '../types/types';
import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react';
import { makeTemplateSelector } from '../store/util/makeTemplateSelector';
import IAINodeHeader from './IAINode/IAINodeHeader';
import IAINodeInputs from './IAINode/IAINodeInputs';
import IAINodeOutputs from './IAINode/IAINodeOutputs';
import IAINodeResizer from './IAINode/IAINodeResizer';
import NodeWrapper from './NodeWrapper';
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
const { id: nodeId, data, selected } = props;
const { type, inputs, outputs } = data;
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
const template = useAppSelector(templateSelector);
if (!template) {
return (
<NodeWrapper selected={selected}>
<Flex
className="nopan"
sx={{
alignItems: 'center',
justifyContent: 'center',
cursor: 'auto',
}}
>
<Icon
as={FaExclamationCircle}
sx={{
boxSize: 32,
color: 'base.600',
_dark: { color: 'base.400' },
}}
></Icon>
<IAINodeResizer />
</Flex>
</NodeWrapper>
);
}
return (
<NodeWrapper selected={selected}>
<IAINodeHeader
nodeId={nodeId}
title={template.title}
description={template.description}
/>
<Flex
className={'nopan'}
sx={{
cursor: 'auto',
flexDirection: 'column',
borderBottomRadius: 'md',
py: 2,
bg: 'base.150',
_dark: { bg: 'base.800' },
}}
>
<IAINodeOutputs nodeId={nodeId} outputs={outputs} template={template} />
<IAINodeInputs nodeId={nodeId} inputs={inputs} template={template} />
</Flex>
<IAINodeResizer />
</NodeWrapper>
);
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@ -1,25 +1,45 @@
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo, useState } from 'react';
import { Panel, PanelGroup } from 'react-resizable-panels';
import 'reactflow/dist/style.css';
import { memo } from 'react';
import { Flow } from './Flow';
import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup';
const NodeEditor = () => {
const [isPanelCollapsed, setIsPanelCollapsed] = useState(false);
return (
<Box
layerStyle={'first'}
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'base',
}}
<PanelGroup
id="node-editor"
autoSaveId="node-editor"
direction="horizontal"
style={{ height: '100%', width: '100%' }}
>
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
</Box>
<Panel
id="node-editor-panel-group"
collapsible
onCollapse={setIsPanelCollapsed}
minSize={25}
>
<NodeEditorPanelGroup />
</Panel>
<ResizeHandle
collapsedDirection={isPanelCollapsed ? 'left' : undefined}
/>
<Panel id="node-editor-content">
<Box
layerStyle={'first'}
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'base',
}}
>
<Flow />
</Box>
</Panel>
</PanelGroup>
);
};

View File

@ -0,0 +1,139 @@
import {
Divider,
Flex,
Heading,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalHeader,
ModalOverlay,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, useCallback } from 'react';
import { FaCog } from 'react-icons/fa';
import {
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from '../store/nodesSlice';
const selector = createSelector(stateSelector, ({ nodes }) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
};
});
const NodeEditorSettings = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = useAppSelector(selector);
const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldValidateGraphChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldAnimate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldAnimateEdgesChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldSnap = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldSnapToGridChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldColor = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldColorEdgesChanged(e.target.checked));
},
[dispatch]
);
return (
<>
<IAIIconButton
aria-label="Node Editor Settings"
icon={<FaCog />}
onClick={onOpen}
/>
<Modal isOpen={isOpen} onClose={onClose} size="2xl" isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>Node Editor Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
py: 4,
}}
>
<Heading size="sm">General</Heading>
<IAISwitch
onChange={handleChangeShouldAnimate}
isChecked={shouldAnimateEdges}
label="Animated Edges"
helperText="Animate selected edges and edges connected to selected nodes"
/>
<Divider />
<IAISwitch
isChecked={shouldSnapToGrid}
onChange={handleChangeShouldSnap}
label="Snap to Grid"
helperText="Snap nodes to grid when moved"
/>
<Divider />
<IAISwitch
isChecked={shouldColorEdges}
onChange={handleChangeShouldColor}
label="Color-Code Edges"
helperText="Color-code edges according to their connected fields"
/>
<Heading size="sm" pt={4}>
Advanced
</Heading>
<IAISwitch
isChecked={shouldValidateGraph}
onChange={handleChangeShouldValidate}
label="Validate Connections and Graph"
helperText="Prevent invalid connections from being made, and invalid graphs from being invoked"
/>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
};
export default NodeEditorSettings;

View File

@ -1,34 +1,26 @@
import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { omit } from 'lodash-es';
import { useMemo } from 'react';
import { useDebounce } from 'use-debounce';
import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph';
const NodeGraphOverlay = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
as="pre"
sx={{
fontFamily: 'monospace',
position: 'absolute',
top: 2,
right: 2,
opacity: 0.7,
p: 2,
maxHeight: 500,
maxWidth: 500,
overflowY: 'scroll',
borderRadius: 'base',
bg: 'base.200',
_dark: { bg: 'base.800' },
}}
>
{JSON.stringify(graph, null, 2)}
</Box>
const useNodesGraph = () => {
const nodes = useAppSelector((state: RootState) => state.nodes);
const [debouncedNodes] = useDebounce(nodes, 300);
const graph = useMemo(
() => omit(buildNodesGraph(debouncedNodes), 'id'),
[debouncedNodes]
);
return graph;
};
export default memo(NodeGraphOverlay);
const NodeGraph = () => {
const graph = useNodesGraph();
return <ImageMetadataJSON jsonObject={graph} label="Graph" />;
};
export default NodeGraph;

View File

@ -0,0 +1,42 @@
import {
Box,
Slider,
SliderFilledTrack,
SliderThumb,
SliderTrack,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { nodeOpacityChanged } from '../store/nodesSlice';
export default function NodeOpacitySlider() {
const dispatch = useAppDispatch();
const nodeOpacity = useAppSelector((state) => state.nodes.nodeOpacity);
const handleChange = useCallback(
(v: number) => {
dispatch(nodeOpacityChanged(v));
},
[dispatch]
);
return (
<Box>
<Slider
aria-label="Node Opacity"
value={nodeOpacity}
min={0.5}
max={1}
step={0.01}
onChange={handleChange}
orientation="vertical"
defaultValue={30}
>
<SliderTrack>
<SliderFilledTrack />
</SliderTrack>
<SliderThumb />
</Slider>
</Box>
);
}

View File

@ -1,36 +0,0 @@
import { Box, useToken } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren } from 'react';
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
import { NODE_MIN_WIDTH } from '../types/constants';
type NodeWrapperProps = PropsWithChildren & {
selected: boolean;
};
const NodeWrapper = (props: NodeWrapperProps) => {
const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [
'nodeSelectedOutline',
'dark-lg',
]);
const shift = useAppSelector((state) => state.hotkeys.shift);
return (
<Box
className={shift ? DRAG_HANDLE_CLASSNAME : 'nopan'}
sx={{
position: 'relative',
borderRadius: 'md',
minWidth: NODE_MIN_WIDTH,
shadow: props.selected
? `${nodeSelectedOutline}, ${nodeShadow}`
: `${nodeShadow}`,
}}
>
{props.children}
</Box>
);
};
export default NodeWrapper;

View File

@ -1,73 +0,0 @@
import { Flex, Image } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useDispatch, useSelector } from 'react-redux';
import { NodeProps, OnResize } from 'reactflow';
import { setProgressNodeSize } from '../store/nodesSlice';
import IAINodeHeader from './IAINode/IAINodeHeader';
import IAINodeResizer from './IAINode/IAINodeResizer';
import NodeWrapper from './NodeWrapper';
const ProgressImageNode = (props: NodeProps) => {
const progressImage = useSelector(
(state: RootState) => state.system.progressImage
);
const progressNodeSize = useSelector(
(state: RootState) => state.nodes.progressNodeSize
);
const dispatch = useDispatch();
const { selected } = props;
const handleResize: OnResize = (_, newSize) => {
dispatch(setProgressNodeSize(newSize));
};
return (
<NodeWrapper selected={selected}>
<IAINodeHeader
title="Progress Image"
description="Displays the progress image in the Node Editor"
/>
<Flex
sx={{
flexDirection: 'column',
flexShrink: 0,
borderBottomRadius: 'md',
bg: 'base.200',
_dark: { bg: 'base.800' },
width: progressNodeSize.width - 2,
height: progressNodeSize.height - 2,
minW: 250,
minH: 250,
overflow: 'hidden',
}}
>
{progressImage ? (
<Image
src={progressImage.dataURL}
sx={{
w: 'full',
h: 'full',
objectFit: 'contain',
}}
/>
) : (
<Flex
sx={{
minW: 250,
minH: 250,
width: progressNodeSize.width - 2,
height: progressNodeSize.height - 2,
}}
>
<IAINoContentFallback />
</Flex>
)}
</Flex>
<IAINodeResizer onResize={handleResize} />
</NodeWrapper>
);
};
export default memo(ProgressImageNode);

View File

@ -2,18 +2,16 @@ import { ButtonGroup, Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react';
import {
FaCode,
FaExpand,
FaMinus,
FaPlus,
FaInfo,
FaMapMarkerAlt,
} from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
import { useTranslation } from 'react-i18next';
import {
shouldShowGraphOverlayChanged,
FaExpand,
FaInfo,
FaMapMarkerAlt,
FaMinus,
FaPlus,
} from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
import {
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from '../store/nodesSlice';
@ -22,9 +20,6 @@ const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay
);
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
@ -44,10 +39,6 @@ const ViewportControls = () => {
fitView();
}, [fitView]);
const handleClickedToggleGraphOverlay = useCallback(() => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]);
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
@ -79,20 +70,6 @@ const ViewportControls = () => {
icon={<FaExpand />}
/>
</Tooltip>
<Tooltip
label={
shouldShowGraphOverlay
? t('nodes.hideGraphNodes')
: t('nodes.showGraphNodes')
}
>
<IAIIconButton
aria-label="Toggle nodes graph overlay"
isChecked={shouldShowGraphOverlay}
onClick={handleClickedToggleGraphOverlay}
icon={<FaCode />}
/>
</Tooltip>
<Tooltip
label={
shouldShowFieldTypeLegend

View File

@ -1,10 +1,15 @@
import { memo } from 'react';
import { Panel } from 'reactflow';
import ViewportControls from '../ViewportControls';
import NodeOpacitySlider from '../NodeOpacitySlider';
import { Flex } from '@chakra-ui/react';
const BottomLeftPanel = () => (
<Panel position="bottom-left">
<ViewportControls />
<Flex sx={{ gap: 2 }}>
<ViewportControls />
<NodeOpacitySlider />
</Flex>
</Panel>
);

View File

@ -20,7 +20,7 @@ const MinimapPanel = () => {
const nodeColor = useColorModeValue(
'var(--invokeai-colors-accent-300)',
'var(--invokeai-colors-accent-700)'
'var(--invokeai-colors-accent-600)'
);
const maskColor = useColorModeValue(
@ -32,10 +32,9 @@ const MinimapPanel = () => {
<>
{shouldShowMinimapPanel && (
<MiniMap
nodeStrokeWidth={3}
pannable
zoomable
nodeBorderRadius={30}
nodeBorderRadius={15}
style={miniMapStyle}
nodeColor={nodeColor}
maskColor={maskColor}

View File

@ -2,11 +2,10 @@ import { HStack } from '@chakra-ui/react';
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
import { memo } from 'react';
import { Panel } from 'reactflow';
import NodeEditorSettings from '../NodeEditorSettings';
import ClearGraphButton from '../ui/ClearGraphButton';
import LoadGraphButton from '../ui/LoadGraphButton';
import NodeInvokeButton from '../ui/NodeInvokeButton';
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
import SaveGraphButton from '../ui/SaveGraphButton';
const TopCenterPanel = () => {
return (
@ -15,9 +14,8 @@ const TopCenterPanel = () => {
<NodeInvokeButton />
<CancelButton />
<ReloadSchemaButton />
<SaveGraphButton />
<LoadGraphButton />
<ClearGraphButton />
<NodeEditorSettings />
</HStack>
</Panel>
);

View File

@ -1,22 +1,16 @@
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { Panel } from 'reactflow';
import FieldTypeLegend from '../FieldTypeLegend';
import NodeGraphOverlay from '../NodeGraphOverlay';
const TopRightPanel = () => {
const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay
);
const shouldShowFieldTypeLegend = useAppSelector(
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
(state) => state.nodes.shouldShowFieldTypeLegend
);
return (
<Panel position="top-right">
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
{shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel>
);
};

View File

@ -1,15 +0,0 @@
import {
ArrayInputFieldTemplate,
ArrayInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FaList } from 'react-icons/fa';
import { FieldComponentProps } from './types';
const ArrayInputFieldComponent = (
_props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
) => {
return <FaList />;
};
export default memo(ArrayInputFieldComponent);

View File

@ -1,37 +0,0 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
} from 'features/nodes/types/types';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeId, field, template } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{template.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};
export default memo(EnumInputFieldComponent);

View File

@ -0,0 +1,47 @@
import { MenuItem, MenuList } from '@chakra-ui/react';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
import { MouseEvent, useCallback } from 'react';
import { menuListMotionProps } from 'theme/components/menu';
type Props = {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
children: ContextMenuProps<HTMLDivElement>['children'];
};
const FieldContextMenu = (props: Props) => {
const skipEvent = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.preventDefault();
}, []);
return (
<ContextMenu<HTMLDivElement>
menuProps={{
size: 'sm',
isLazy: true,
}}
menuButtonProps={{
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={() => (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuItem>Test</MenuItem>
</MenuList>
)}
>
{props.children}
</ContextMenu>
);
};
export default FieldContextMenu;

View File

@ -0,0 +1,122 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, NodeProps, Position } from 'reactflow';
import {
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
colorTokenToCssVar,
} from '../../types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from '../../types/types';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
zIndex: 1,
};
export const inputHandleStyles: CSSProperties = {
left: '-1rem',
};
export const outputHandleStyles: CSSProperties = {
right: '-0.5rem',
};
type FieldHandleProps = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
connectionError: string | null;
};
const FieldHandle = (props: FieldHandleProps) => {
const {
fieldTemplate,
handleType,
isConnectionInProgress,
isConnectionStartField,
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color, title } = FIELDS[type];
const styles: CSSProperties = useMemo(() => {
const s: CSSProperties = {
backgroundColor: colorTokenToCssVar(color),
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
zIndex: 1,
};
if (handleType === 'target') {
s.insetInlineStart = '-1rem';
} else {
s.insetInlineEnd = '-1rem';
}
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
s.filter = 'opacity(0.4) grayscale(0.7)';
}
if (isConnectionInProgress && connectionError) {
if (isConnectionStartField) {
s.cursor = 'grab';
} else {
s.cursor = 'not-allowed';
}
} else {
s.cursor = 'crosshair';
}
return s;
}, [
color,
connectionError,
handleType,
isConnectionInProgress,
isConnectionStartField,
]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return title;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? title;
}
return title;
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
return (
<Tooltip
label={tooltip}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
position={handleType === 'target' ? Position.Left : Position.Right}
style={styles}
/>
</Tooltip>
);
};
export default memo(FieldHandle);

View File

@ -0,0 +1,161 @@
import {
Editable,
EditableInput,
EditablePreview,
Flex,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDraggable from 'common/components/IAIDraggable';
import { NodeFieldDraggableData } from 'features/dnd/types';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import {
MouseEvent,
memo,
useCallback,
useEffect,
useMemo,
useState,
} from 'react';
interface Props {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
isDraggable?: boolean;
}
const FieldTitle = (props: Props) => {
const { nodeData, field, fieldTemplate, isDraggable = false } = props;
const { label } = field;
const { title, input } = fieldTemplate;
const { id: nodeId } = nodeData;
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title);
const draggableData: NodeFieldDraggableData | undefined = useMemo(
() =>
input !== 'connection' && isDraggable
? {
id: `${nodeId}-${field.name}`,
payloadType: 'NODE_FIELD',
payload: { nodeId, field, fieldTemplate },
}
: undefined,
[field, fieldTemplate, input, isDraggable, nodeId]
);
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(
fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle })
);
setLocalTitle(newTitle || title);
},
[dispatch, nodeId, field.name, title]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || title);
}, [label, title]);
return (
<Flex
className="nopan"
sx={{
position: 'relative',
overflow: 'hidden',
h: 'full',
alignItems: 'flex-start',
justifyContent: 'flex-start',
gap: 1,
}}
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
sx={{
position: 'relative',
}}
>
<EditablePreview
sx={{
p: 0,
textAlign: 'left',
}}
noOfLines={1}
/>
<EditableInput
sx={{
p: 0,
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls draggableData={draggableData} />
</Editable>
</Flex>
);
};
export default memo(FieldTitle);
type EditableControlsProps = {
draggableData?: NodeFieldDraggableData;
};
function EditableControls(props: EditableControlsProps) {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
if (props.draggableData) {
return (
<IAIDraggable
data={props.draggableData}
onDoubleClick={handleDoubleClick}
cursor={props.draggableData ? 'grab' : 'text'}
/>
);
}
return (
<Flex
onDoubleClick={handleDoubleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
}

View File

@ -0,0 +1,41 @@
import { Flex, Text } from '@chakra-ui/react';
import { FIELDS } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
isInputFieldTemplate,
isInputFieldValue,
} from 'features/nodes/types/types';
import { startCase } from 'lodash-es';
interface Props {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
}
const FieldTooltipContent = ({ field, fieldTemplate }: Props) => {
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>
{isInputFieldValue(field) && field.label
? `${field.label} (${fieldTemplate.title})`
: fieldTemplate.title}
</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{fieldTemplate.description}
</Text>
<Text>Type: {FIELDS[fieldTemplate.type].title}</Text>
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex>
);
};
export default FieldTooltipContent;

View File

@ -0,0 +1,153 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
}
const InputField = (props: Props) => {
const { nodeProps, nodeTemplate, field } = props;
const { id: nodeId } = nodeProps.data;
const {
isConnected,
isConnectionInProgress,
isConnectionStartField,
connectionError,
shouldDim,
} = useConnectionState({ nodeId, field, kind: 'input' });
const fieldTemplate = useMemo(
() => nodeTemplate.inputs[field.name],
[field.name, nodeTemplate.inputs]
);
const isMissingInput = useMemo(() => {
if (!fieldTemplate) {
return false;
}
if (!fieldTemplate.required) {
return false;
}
if (!isConnected && fieldTemplate.input === 'connection') {
return true;
}
if (!field.value && !isConnected && fieldTemplate.input === 'any') {
return true;
}
}, [fieldTemplate, isConnected, field.value]);
if (!fieldTemplate) {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
>
Unknown input: {field.name}
</FormControl>
</InputFieldWrapper>
);
}
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
as={Flex}
isInvalid={isMissingInput}
isDisabled={isConnected}
sx={{
alignItems: 'center',
justifyContent: 'space-between',
ps: 2,
gap: 2,
}}
>
<Tooltip
label={
<FieldTooltipContent
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
shouldWrapChildren
hasArrow
>
<FormLabel sx={{ mb: 0 }}>
<FieldTitle
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
isDraggable
/>
</FormLabel>
</Tooltip>
<InputFieldRenderer
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl>
{fieldTemplate.input !== 'direct' && (
<FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
/>
)}
</InputFieldWrapper>
);
};
export default InputField;
type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => (
<Flex
className="nopan"
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
w: 'full',
h: 'full',
}}
>
{children}
</Flex>
);

View File

@ -0,0 +1,293 @@
import { Box } from '@chakra-ui/react';
import { memo } from 'react';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from '../../types/types';
import BooleanInputField from './fieldTypes/BooleanInputField';
import ClipInputField from './fieldTypes/ClipInputField';
import CollectionInputField from './fieldTypes/CollectionInputField';
import CollectionItemInputField from './fieldTypes/CollectionItemInputField';
import ColorInputField from './fieldTypes/ColorInputField';
import ConditioningInputField from './fieldTypes/ConditioningInputField';
import ControlInputField from './fieldTypes/ControlInputField';
import ControlNetModelInputField from './fieldTypes/ControlNetModelInputField';
import EnumInputField from './fieldTypes/EnumInputField';
import ImageCollectionInputField from './fieldTypes/ImageCollectionInputField';
import ImageInputField from './fieldTypes/ImageInputField';
import LatentsInputField from './fieldTypes/LatentsInputField';
import LoRAModelInputField from './fieldTypes/LoRAModelInputField';
import MainModelInputField from './fieldTypes/MainModelInputField';
import NumberInputField from './fieldTypes/NumberInputField';
import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
import StringInputField from './fieldTypes/StringInputField';
import UnetInputField from './fieldTypes/UnetInputField';
import VaeInputField from './fieldTypes/VaeInputField';
import VaeModelInputField from './fieldTypes/VaeModelInputField';
type InputFieldProps = {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
// build an individual input element based on the schema
const InputFieldRenderer = (props: InputFieldProps) => {
const { nodeData, nodeTemplate, field, fieldTemplate } = props;
const { type } = field;
if (type === 'string' && fieldTemplate.type === 'string') {
return (
<StringInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'boolean' && fieldTemplate.type === 'boolean') {
return (
<BooleanInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
(type === 'integer' && fieldTemplate.type === 'integer') ||
(type === 'float' && fieldTemplate.type === 'float') ||
(type === 'Seed' && fieldTemplate.type === 'Seed')
) {
return (
<NumberInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'enum' && fieldTemplate.type === 'enum') {
return (
<EnumInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ImageField' && fieldTemplate.type === 'ImageField') {
return (
<ImageInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
return (
<LatentsInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
type === 'ConditioningField' &&
fieldTemplate.type === 'ConditioningField'
) {
return (
<ConditioningInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'UNetField' && fieldTemplate.type === 'UNetField') {
return (
<UnetInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ClipField' && fieldTemplate.type === 'ClipField') {
return (
<ClipInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'VaeField' && fieldTemplate.type === 'VaeField') {
return (
<VaeInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
return (
<ControlInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
return (
<MainModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
type === 'SDXLRefinerModelField' &&
fieldTemplate.type === 'SDXLRefinerModelField'
) {
return (
<RefinerModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
return (
<VaeModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
return (
<LoRAModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
type === 'ControlNetModelField' &&
fieldTemplate.type === 'ControlNetModelField'
) {
return (
<ControlNetModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'Collection' && fieldTemplate.type === 'Collection') {
return (
<CollectionInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
return (
<CollectionItemInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
return (
<ColorInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
return (
<ImageCollectionInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
type === 'SDXLMainModelField' &&
fieldTemplate.type === 'SDXLMainModelField'
) {
return (
<SDXLMainModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};
export default memo(InputFieldRenderer);

View File

@ -1,15 +0,0 @@
import {
ItemInputFieldTemplate,
ItemInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FaAddressCard } from 'react-icons/fa';
import { FieldComponentProps } from './types';
const ItemInputFieldComponent = (
_props: FieldComponentProps<ItemInputFieldValue, ItemInputFieldTemplate>
) => {
return <FaAddressCard />;
};
export default memo(ItemInputFieldComponent);

View File

@ -0,0 +1,88 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
type Props = {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
const LinearViewField = ({
nodeData,
nodeTemplate,
field,
fieldTemplate,
}: Props) => {
// const dispatch = useAppDispatch();
// const handleRemoveField = useCallback(() => {
// dispatch(
// workflowExposedFieldRemoved({
// nodeId: nodeData.id,
// fieldName: field.name,
// })
// );
// }, [dispatch, field.name, nodeData.id]);
return (
<Flex
layerStyle="second"
sx={{
position: 'relative',
borderRadius: 'base',
w: 'full',
p: 2,
}}
>
<FormControl as={Flex} sx={{ flexDir: 'column', gap: 1, flexShrink: 1 }}>
<Tooltip
label={
<FieldTooltipContent
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
shouldWrapChildren
hasArrow
>
<FormLabel
sx={{
display: 'flex',
justifyContent: 'space-between',
mb: 0,
}}
>
<FieldTitle
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormLabel>
</Tooltip>
<InputFieldRenderer
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl>
</Flex>
);
};
export default memo(LinearViewField);

View File

@ -0,0 +1,114 @@
import {
Flex,
FormControl,
FormLabel,
Spacer,
Tooltip,
} from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle';
import FieldTooltipContent from './FieldTooltipContent';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: OutputFieldValue;
}
const OutputField = (props: Props) => {
const { nodeTemplate, nodeProps, field } = props;
const {
isConnected,
isConnectionInProgress,
isConnectionStartField,
connectionError,
shouldDim,
} = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' });
const fieldTemplate = useMemo(
() => nodeTemplate.outputs[field.name],
[field.name, nodeTemplate]
);
if (!fieldTemplate) {
return (
<OutputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }}
>
Unknown output: {field.name}
</FormControl>
</OutputFieldWrapper>
);
}
return (
<OutputFieldWrapper shouldDim={shouldDim}>
<Spacer />
<Tooltip
label={
<FieldTooltipContent
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
shouldWrapChildren
hasArrow
>
<FormControl isDisabled={isConnected} pe={2}>
<FormLabel sx={{ mb: 0, fontWeight: 500 }}>
{fieldTemplate?.title}
</FormLabel>
</FormControl>
</Tooltip>
<FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
handleType="source"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
/>
</OutputFieldWrapper>
);
};
export default OutputField;
type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const OutputFieldWrapper = ({
shouldDim,
children,
}: OutputFieldWrapperProps) => (
<Flex
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
}}
>
{children}
</Flex>
);

View File

@ -1,36 +0,0 @@
import { Input, Textarea } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
StringInputFieldTemplate,
StringInputFieldValue,
} from 'features/nodes/types/types';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (
e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return ['prompt', 'style'].includes(field.name.toLowerCase()) ? (
<Textarea onChange={handleValueChanged} value={field.value} rows={2} />
) : (
<Input onChange={handleValueChanged} value={field.value} />
);
};
export default memo(StringInputFieldComponent);

View File

@ -1,29 +1,33 @@
import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
} from 'features/nodes/types/types';
import { ChangeEvent, memo } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { FieldComponentProps } from './types';
const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.checked,
})
);
};
const handleValueChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: field.name,
value: e.target.checked,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>

View File

@ -0,0 +1,17 @@
import {
CollectionInputFieldTemplate,
CollectionInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const CollectionInputFieldComponent = (
_props: FieldComponentProps<
CollectionInputFieldValue,
CollectionInputFieldTemplate
>
) => {
return null;
};
export default memo(CollectionInputFieldComponent);

View File

@ -0,0 +1,17 @@
import {
CollectionItemInputFieldTemplate,
CollectionItemInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const CollectionItemInputFieldComponent = (
_props: FieldComponentProps<
CollectionItemInputFieldValue,
CollectionItemInputFieldTemplate
>
) => {
return null;
};
export default memo(CollectionItemInputFieldComponent);

View File

@ -1,23 +1,33 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice';
import {
ColorInputFieldTemplate,
ColorInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
import { memo, useCallback } from 'react';
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { FieldComponentProps } from './types';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const handleValueChanged = (value: RgbaColor) => {
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
};
const handleValueChanged = useCallback(
(value: RgbaColor) => {
dispatch(
fieldColorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<RgbaColorPicker

View File

@ -1,7 +1,7 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
@ -19,7 +19,8 @@ const ControlNetModelInputFieldComponent = (
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
@ -73,7 +74,7 @@ const ControlNetModelInputFieldComponent = (
}
dispatch(
fieldValueChanged({
fieldControlNetModelValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
@ -85,10 +86,8 @@ const ControlNetModelInputFieldComponent = (
return (
<IAIMantineSelect
className="nowheel"
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}

View File

@ -0,0 +1,45 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FieldComponentProps } from './types';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const handleValueChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldEnumModelValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<Select
className="nowheel"
onChange={handleValueChanged}
value={field.value}
>
{fieldTemplate.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};
export default memo(EnumInputFieldComponent);

View File

@ -5,13 +5,11 @@ import {
import { memo } from 'react';
import { Flex } from '@chakra-ui/react';
import {
NodesMultiImageDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { NodesMultiImageDropData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { FieldComponentProps } from './types';
@ -21,7 +19,8 @@ const ImageCollectionInputFieldComponent = (
ImageCollectionInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
// const dispatch = useAppDispatch();
@ -41,14 +40,14 @@ const ImageCollectionInputFieldComponent = (
const droppableData: NodesMultiImageDropData = {
id: `node-${nodeId}-${field.name}`,
actionType: 'SET_MULTI_NODES_IMAGE',
context: { nodeId, fieldName: field.name },
context: { nodeId: nodeId, fieldName: field.name },
};
const {
isOver,
setNodeRef: setDroppableRef,
active,
} = useDroppable({
} = useDroppableTypesafe({
id: `node_${nodeId}`,
data: droppableData,
});

View File

@ -1,12 +1,12 @@
import { Flex } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
} from 'features/dnd/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
ImageInputFieldTemplate,
ImageInputFieldValue,
@ -19,8 +19,8 @@ import { FieldComponentProps } from './types';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery(
@ -29,7 +29,7 @@ const ImageInputFieldComponent = (
const handleReset = useCallback(() => {
dispatch(
fieldValueChanged({
fieldImageValueChanged({
nodeId,
fieldName: field.name,
value: undefined,
@ -79,6 +79,9 @@ const ImageInputFieldComponent = (
droppableData={droppableData}
draggableData={draggableData}
onClickReset={handleReset}
withResetIcon
thumbnail
useThumbailFallback
postUploadAction={postUploadAction}
/>
</Flex>

View File

@ -3,7 +3,7 @@ import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
LoRAModelInputFieldTemplate,
LoRAModelInputFieldValue,
@ -21,7 +21,8 @@ const LoRAModelInputFieldComponent = (
LoRAModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const lora = field.value;
const dispatch = useAppDispatch();
const { data: loraModels } = useGetLoRAModelsQuery();
@ -68,7 +69,7 @@ const LoRAModelInputFieldComponent = (
}
dispatch(
fieldValueChanged({
fieldLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value: newLoRAModel,
@ -90,11 +91,8 @@ const LoRAModelInputFieldComponent = (
return (
<IAIMantineSearchableSelect
className="nowheel"
value={selectedLoRAModel?.id ?? null}
label={
selectedLoRAModel?.base_model &&
MODEL_TYPE_MAP[selectedLoRAModel?.base_model]
}
placeholder={data.length > 0 ? 'Select a LoRA' : 'No LoRAs available'}
data={data}
nothingFound="No matching LoRAs"

View File

@ -0,0 +1,144 @@
import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
MainModelInputFieldTemplate,
MainModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const MainModelInputFieldComponent = (
props: FieldComponentProps<
MainModelInputFieldValue,
MainModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: onnxModels, isLoading: isLoadingOnnxModels } =
useGetOnnxModelsQuery(NON_SDXL_MAIN_MODELS);
const { data: mainModels, isLoading: isLoadingMainModels } =
useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
const isLoadingModels = useMemo(
() => isLoadingOnnxModels || isLoadingMainModels,
[isLoadingOnnxModels, isLoadingMainModels]
);
const data = useMemo(() => {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
if (onnxModels) {
forEach(onnxModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
}
return data;
}, [mainModels, onnxModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
(mainModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ||
onnxModels?.entities[
`${field.value?.base_model}/onnx/${field.value?.model_name}`
]) ??
null,
[
field.value?.base_model,
field.value?.model_name,
mainModels?.entities,
onnxModels?.entities,
]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<Flex sx={{ w: 'full', alignItems: 'center', gap: 2 }}>
{isLoadingModels ? (
<Text variant="subtext">Loading...</Text>
) : (
<IAIMantineSearchableSelect
className="nowheel"
tooltip={selectedModel?.description}
value={selectedModel?.id}
placeholder={
data.length > 0 ? 'Select a model' : 'No models available'
}
data={data}
error={!selectedModel}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
)}
{isSyncModelEnabled && <SyncModelsButton iconMode />}
</Flex>
);
};
export default memo(MainModelInputFieldComponent);

View File

@ -7,27 +7,34 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import {
FloatInputFieldTemplate,
FloatInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
SeedInputFieldTemplate,
SeedInputFieldValue,
} from 'features/nodes/types/types';
import { memo, useEffect, useState } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
import { FieldComponentProps } from './types';
const NumberInputFieldComponent = (
props: FieldComponentProps<
IntegerInputFieldValue | FloatInputFieldValue,
IntegerInputFieldTemplate | FloatInputFieldTemplate
IntegerInputFieldValue | FloatInputFieldValue | SeedInputFieldValue,
IntegerInputFieldTemplate | FloatInputFieldTemplate | SeedInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState<string>(
String(field.value)
);
const isIntegerField = useMemo(
() => fieldTemplate.type === 'integer' || fieldTemplate.type === 'Seed',
[fieldTemplate.type]
);
const handleValueChanged = (v: string) => {
setValueAsString(v);
@ -35,13 +42,10 @@ const NumberInputFieldComponent = (
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
dispatch(
fieldValueChanged({
fieldNumberValueChanged({
nodeId,
fieldName: field.name,
value:
props.template.type === 'integer'
? Math.floor(Number(v))
: Number(v),
value: isIntegerField ? Math.floor(Number(v)) : Number(v),
})
);
}
@ -60,8 +64,8 @@ const NumberInputFieldComponent = (
<NumberInput
onChange={handleValueChanged}
value={valueAsString}
step={props.template.type === 'integer' ? 1 : 0.1}
precision={props.template.type === 'integer' ? 0 : 3}
step={isIntegerField ? 1 : 0.1}
precision={isIntegerField ? 0 : 3}
>
<NumberInputField />
<NumberInputStepper>

Some files were not shown because too many files have changed in this diff Show More