merge with main

This commit is contained in:
Lincoln Stein 2023-06-26 13:53:59 -04:00
commit 011adfc958
54 changed files with 1507 additions and 846 deletions

View File

@ -259,8 +259,8 @@ def _convert_ckpt_and_cache(
""" """
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_dir / model_config.path weights = app_config.root_path / model_config.path
config_file = app_config.root_dir / model_config.config config_file = app_config.root_path / model_config.config
output_path = Path(output_path) output_path = Path(output_path)
# return cached version if it exists # return cached version if it exists

View File

@ -0,0 +1,14 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { nodePolyfills } from 'vite-plugin-node-polyfills';
export const commonPlugins: UserConfig['plugins'] = [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
nodePolyfills(),
];

View File

@ -1,17 +1,9 @@
import react from '@vitejs/plugin-react-swc'; import { UserConfig } from 'vite';
import { visualizer } from 'rollup-plugin-visualizer'; import { commonPlugins } from './common';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
export const appConfig: UserConfig = { export const appConfig: UserConfig = {
base: './', base: './',
plugins: [ plugins: [...commonPlugins],
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
],
build: { build: {
chunkSizeWarningLimit: 1500, chunkSizeWarningLimit: 1500,
}, },

View File

@ -1,19 +1,13 @@
import react from '@vitejs/plugin-react-swc';
import path from 'path'; import path from 'path';
import { visualizer } from 'rollup-plugin-visualizer'; import { UserConfig } from 'vite';
import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts'; import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
import { commonPlugins } from './common';
export const packageConfig: UserConfig = { export const packageConfig: UserConfig = {
base: './', base: './',
plugins: [ plugins: [
react(), ...commonPlugins,
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
dts({ dts({
insertTypesEntry: true, insertTypesEntry: true,
}), }),

View File

@ -53,6 +53,7 @@
] ]
}, },
"dependencies": { "dependencies": {
"@apidevtools/swagger-parser": "^10.1.0",
"@chakra-ui/anatomy": "^2.1.1", "@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/icons": "^2.0.19", "@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.7.1", "@chakra-ui/react": "^2.7.1",
@ -154,6 +155,7 @@
"vite-plugin-css-injected-by-js": "^3.1.1", "vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0", "vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1", "vite-plugin-eslint": "^1.8.1",
"vite-plugin-node-polyfills": "^0.9.0",
"vite-tsconfig-paths": "^4.2.0", "vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19" "yarn": "^1.22.19"
} }

View File

@ -15,7 +15,7 @@ import { ImageDTO } from 'services/api/types';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelecter } from 'features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
@ -30,7 +30,7 @@ export const selectImageUsage = createSelector(
[ [
generationSelector, generationSelector,
canvasSelector, canvasSelector,
nodesSelecter, nodesSelector,
controlNetSelector, controlNetSelector,
(state: RootState, image_name?: string) => image_name, (state: RootState, image_name?: string) => image_name,
], ],

View File

@ -1,6 +1,7 @@
import { AnyAction } from '@reduxjs/toolkit'; import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { Graph } from 'services/api/types'; import { Graph } from 'services/api/types';
export const actionSanitizer = <A extends AnyAction>(action: A): A => { export const actionSanitizer = <A extends AnyAction>(action: A): A => {
@ -8,17 +9,6 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (action.payload.nodes) { if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {}; const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return { return {
...action, ...action,
payload: { ...action.payload, nodes: sanitizedNodes }, payload: { ...action.payload, nodes: sanitizedNodes },
@ -26,5 +16,19 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
} }
} }
if (receivedOpenAPISchema.fulfilled.match(action)) {
return {
...action,
payload: '<OpenAPI schema omitted>',
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
return action; return action;
}; };

View File

@ -82,6 +82,7 @@ import {
addImageRemovedFromBoardFulfilledListener, addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener, addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard'; } from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -205,3 +206,6 @@ addImageAddedToBoardRejectedListener();
addImageRemovedFromBoardFulfilledListener(); addImageRemovedFromBoardFulfilledListener();
addImageRemovedFromBoardRejectedListener(); addImageRemovedFromBoardRejectedListener();
addBoardIdSelectedListener(); addBoardIdSelectedListener();
// Node schemas
addReceivedOpenAPISchemaListener();

View File

@ -0,0 +1,35 @@
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { parseSchema } from 'features/nodes/util/parseSchema';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { size } from 'lodash-es';
const schemaLog = log.child({ namespace: 'schema' });
export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch, getState }) => {
const schemaJSON = action.payload;
schemaLog.info({ data: { schemaJSON } }, 'Dereferenced OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
schemaLog.info(
{ data: { nodeTemplates } },
`Built ${size(nodeTemplates)} node templates`
);
dispatch(nodeTemplatesBuilt(nodeTemplates));
},
});
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: (action, { dispatch, getState }) => {
schemaLog.error('Problem dereferencing OpenAPI Schema');
},
});
};

View File

@ -3,7 +3,7 @@ import { startAppListening } from '..';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { nodesSelecter } from 'features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { forEach, uniqBy } from 'lodash-es'; import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
@ -16,7 +16,7 @@ const selectAllUsedImages = createSelector(
[ [
generationSelector, generationSelector,
canvasSelector, canvasSelector,
nodesSelecter, nodesSelector,
controlNetSelector, controlNetSelector,
selectImagesEntities, selectImagesEntities,
], ],

View File

@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
@ -48,6 +49,7 @@ const allReducers = {
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
// session: sessionReducer, // session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system', 'system',
'ui', 'ui',
'controlNet', 'controlNet',
'dynamicPrompts',
// 'boards', // 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',
@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>; export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>; export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch; export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;

View File

@ -171,6 +171,14 @@ export type AppConfig = {
fineStep: number; fineStep: number;
coarseStep: number; coarseStep: number;
}; };
dynamicPrompts: {
maxPrompts: {
initial: number;
min: number;
sliderMax: number;
inputMax: number;
};
};
}; };
}; };

View File

@ -27,7 +27,6 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
borderWidth: '2px', borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)', borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)', color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24, paddingRight: 24,
fontWeight: 600, fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' }, '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },

View File

@ -34,6 +34,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus': { '&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)', borderColor: 'var(--invokeai-colors-accent-600)',
}, },
'&:disabled': {
backgroundColor: 'var(--invokeai-colors-base-700)',
color: 'var(--invokeai-colors-base-400)',
},
}, },
dropdown: { dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)', backgroundColor: 'var(--invokeai-colors-base-800)',
@ -64,7 +68,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
}, },
}, },
rightSection: { rightSection: {
width: 24, width: 32,
}, },
})} })}
{...rest} {...rest}

View File

@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
{...formControlProps} {...formControlProps}
> >
{label && ( {label && (
<FormLabel my={1} flexGrow={1} {...formLabelProps}> <FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
}}
{...formLabelProps}
>
{label} {label}
</FormLabel> </FormLabel>
)} )}

View File

@ -9,10 +9,12 @@ type IAICanvasImageProps = {
}; };
const IAICanvasImage = (props: IAICanvasImageProps) => { const IAICanvasImage = (props: IAICanvasImageProps) => {
const { width, height, x, y, imageName } = props.canvasImage; const { width, height, x, y, imageName } = props.canvasImage;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); const { currentData: imageDTO, isError } = useGetImageDTOQuery(
imageName ?? skipToken
);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous'); const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) { if (isError) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />; return <Rect x={x} y={y} width={width} height={height} fill="red" />;
} }

View File

@ -174,7 +174,10 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1', aspectRatio: '1/1',
}} }}
> >
<ControlNetImagePreview controlNet={props.controlNet} /> <ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex> </Flex>
)} )}
</Flex> </Flex>

View File

@ -0,0 +1,45 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial />
</Flex>
</IAICollapse>
);
};
export default ParamDynamicPromptsCollapse;

View File

@ -0,0 +1,36 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import IAISwitch from 'common/components/IAISwitch';
const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;
return { combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(combinatorialToggled());
}, [dispatch]);
return (
<IAISwitch
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}
/>
);
};
export default ParamDynamicPromptsCombinatorial;

View File

@ -0,0 +1,53 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax };
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => {
dispatch(maxPromptsChanged(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(maxPromptsReset());
}, [dispatch]);
return (
<IAISlider
label="Max Prompts"
min={min}
max={sliderMax}
value={maxPrompts}
onChange={handleChange}
sliderNumberInputProps={{ max: inputMax }}
withSliderMarks
withInput
inputReadOnly
withReset
handleReset={handleReset}
/>
);
};
export default ParamDynamicPromptsMaxPrompts;

View File

@ -0,0 +1,50 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
export interface DynamicPromptsState {
isEnabled: boolean;
maxPrompts: number;
combinatorial: boolean;
}
export const initialDynamicPromptsState: DynamicPromptsState = {
isEnabled: false,
maxPrompts: 100,
combinatorial: true,
};
const initialState: DynamicPromptsState = initialDynamicPromptsState;
export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
},
maxPromptsReset: (state) => {
state.maxPrompts = initialDynamicPromptsState.maxPrompts;
},
combinatorialToggled: (state) => {
state.combinatorial = !state.combinatorial;
},
isEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
},
extraReducers: (builder) => {
//
},
});
export const {
isEnabledToggled,
maxPromptsChanged,
maxPromptsReset,
combinatorialToggled,
} = dynamicPromptsSlice.actions;
export default dynamicPromptsSlice.reducer;
export const dynamicPromptsSelector = (state: RootState) =>
state.dynamicPrompts;

View File

@ -1,28 +1,41 @@
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
import { memo, useCallback } from 'react'; import { useCallback, forwardRef } from 'react';
import { import { Flex, Text } from '@chakra-ui/react';
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
} from '@chakra-ui/react';
import { FaEllipsisV } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeAdded } from '../store/nodesSlice'; import { nodeAdded, nodesSelector } from '../store/nodesSlice';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { RootState } from 'app/store/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation'; import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { AnyInvocationType } from 'services/events/types'; import { AnyInvocationType } from 'services/events/types';
import IAIIconButton from 'common/components/IAIIconButton';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
type NodeTemplate = {
label: string;
value: string;
description: string;
};
const selector = createSelector(
nodesSelector,
(nodes) => {
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
};
});
return { data };
},
defaultSelectorOptions
);
const AddNodeMenu = () => { const AddNodeMenu = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
const buildInvocation = useBuildInvocation(); const buildInvocation = useBuildInvocation();
@ -46,23 +59,52 @@ const AddNodeMenu = () => {
); );
return ( return (
<Menu isLazy> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<MenuButton <IAIMantineMultiSelect
as={IAIIconButton} selectOnBlur={false}
aria-label="Add Node" placeholder="Add Node"
icon={<FaEllipsisV />} value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching nodes"
itemComponent={SelectItem}
filter={(value, selected, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.description.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={(v) => {
v[0] && addNode(v[0] as AnyInvocationType);
}}
sx={{
width: '18rem',
}}
/> />
<MenuList overflowY="scroll" height={400}> </Flex>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
); );
}; };
export default memo(AddNodeMenu); interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
<Text size="xs" color="base.600">
{description}
</Text>
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default AddNodeMenu;

View File

@ -1,10 +1,10 @@
import { memo } from 'react'; import { memo } from 'react';
import { Panel } from 'reactflow'; import { Panel } from 'reactflow';
import NodeSearch from '../search/NodeSearch'; import AddNodeMenu from '../AddNodeMenu';
const TopLeftPanel = () => ( const TopLeftPanel = () => (
<Panel position="top-left"> <Panel position="top-left">
<NodeSearch /> <AddNodeMenu />
</Panel> </Panel>
); );

View File

@ -14,9 +14,6 @@ import {
import { ImageField } from 'services/api/types'; import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { InvocationTemplate, InvocationValue } from '../types/types'; import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
@ -78,25 +75,17 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => { shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload; state.shouldShowGraphOverlay = action.payload;
}, },
parsedOpenAPISchema: (state, action: PayloadAction<OpenAPIV3.Document>) => { nodeTemplatesBuilt: (
try { state,
const parsedSchema = parseSchema(action.payload); action: PayloadAction<Record<string, InvocationTemplate>>
) => {
// TODO: Achtung! Side effect in a reducer! state.invocationTemplates = action.payload;
log.info(
{ namespace: 'schema', nodes: parsedSchema },
`Parsed ${size(parsedSchema)} nodes`
);
state.invocationTemplates = parsedSchema;
} catch (err) {
console.error(err);
}
}, },
nodeEditorReset: () => { nodeEditorReset: () => {
return { ...initialNodesState }; return { ...initialNodesState };
}, },
}, },
extraReducers(builder) { extraReducers: (builder) => {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload; state.schema = action.payload;
}); });
@ -112,10 +101,10 @@ export const {
connectionStarted, connectionStarted,
connectionEnded, connectionEnded,
shouldShowGraphOverlayChanged, shouldShowGraphOverlayChanged,
parsedOpenAPISchema, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
} = nodesSlice.actions; } = nodesSlice.actions;
export default nodesSlice.reducer; export default nodesSlice.reducer;
export const nodesSelecter = (state: RootState) => state.nodes; export const nodesSelector = (state: RootState) => state.nodes;

View File

@ -34,12 +34,10 @@ export type InvocationTemplate = {
* Array of invocation inputs * Array of invocation inputs
*/ */
inputs: Record<string, InputFieldTemplate>; inputs: Record<string, InputFieldTemplate>;
// inputs: InputField[];
/** /**
* Array of the invocation outputs * Array of the invocation outputs
*/ */
outputs: Record<string, OutputFieldTemplate>; outputs: Record<string, OutputFieldTemplate>;
// outputs: OutputField[];
}; };
export type FieldUIConfig = { export type FieldUIConfig = {
@ -335,7 +333,7 @@ export type TypeHints = {
}; };
export type InvocationSchemaExtra = { export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation output: OpenAPIV3.SchemaObject; // the output of the invocation
ui?: { ui?: {
tags?: string[]; tags?: string[];
type_hints?: TypeHints; type_hints?: TypeHints;

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { filter, forEach, size } from 'lodash-es'; import { filter } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types'; import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
(c.processorType === 'none' && Boolean(c.controlImage))) (c.processorType === 'none' && Boolean(c.controlImage)))
); );
// Add ControlNet if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (isControlNetEnabled && validControlNets.length > 0) { if (validControlNets.length > 1) {
if (size(controlNets) > 1) { // We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,
type: 'collect', type: 'collect',
@ -36,10 +36,9 @@ export const addControlNetToLinearGraph = (
}); });
} }
forEach(controlNets, (controlNet) => { validControlNets.forEach((controlNet) => {
const { const {
controlNetId, controlNetId,
isEnabled,
controlImage, controlImage,
processedControlImage, processedControlImage,
beginStepPct, beginStepPct,
@ -50,11 +49,6 @@ export const addControlNetToLinearGraph = (
weight, weight,
} = controlNet; } = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = { const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`, id: `control_net_${controlNetId}`,
type: 'controlnet', type: 'controlnet',
@ -82,7 +76,8 @@ export const addControlNetToLinearGraph = (
graph.nodes[controlNetNode.id] = controlNetNode; graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) { if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {
@ -91,6 +86,7 @@ export const addControlNetToLinearGraph = (
}, },
}); });
} else { } else {
// otherwise, link directly to the base node
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {

View File

@ -349,21 +349,11 @@ export const getFieldType = (
if (typeHints && name in typeHints) { if (typeHints && name in typeHints) {
rawFieldType = typeHints[name]; rawFieldType = typeHints[name];
} else if (!schemaObject.type) { } else if (!schemaObject.type && schemaObject.allOf) {
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf // if schemaObject has no type, then it should have one of allOf
if (schemaObject.allOf) { rawFieldType =
rawFieldType = refObjectToFieldType( (schemaObject.allOf[0] as OpenAPIV3.SchemaObject).title ??
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject 'Missing Field Type';
);
} else if (schemaObject.anyOf) {
rawFieldType = refObjectToFieldType(
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
rawFieldType = refObjectToFieldType(
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
}
} else if (schemaObject.enum) { } else if (schemaObject.enum) {
rawFieldType = 'enum'; rawFieldType = 'enum';
} else if (schemaObject.type) { } else if (schemaObject.type) {

View File

@ -0,0 +1,153 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DynamicPromptInvocation,
IterateInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
DYNAMIC_PROMPT,
ITERATE,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
import { unset } from 'lodash-es';
export const addDynamicPromptsToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
state.generation;
const {
combinatorial,
isEnabled: isDynamicPromptsEnabled,
maxPrompts,
} = state.dynamicPrompts;
if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
const dynamicPromptNode: DynamicPromptInvocation = {
id: DYNAMIC_PROMPT,
type: 'dynamic_prompt',
max_prompts: maxPrompts,
combinatorial,
prompt: positivePrompt,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[DYNAMIC_PROMPT] = dynamicPromptNode;
graph.nodes[ITERATE] = iterateNode;
// connect dynamic prompts to compel nodes
graph.edges.push(
{
source: {
node_id: DYNAMIC_PROMPT,
field: 'prompt_collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'prompt',
},
}
);
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: NOISE, field: 'seed' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[NOISE] as NoiseInvocation).seed = seed;
}
} else {
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
size: iterations,
step: 1,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[ITERATE] = iterateNode;
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graph.edges.push({
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
});
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
rangeOfSizeNode.start = seed;
}
}
};

View File

@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
import { import {
ImageDTO, ImageDTO,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
@ -10,7 +11,7 @@ import { log } from 'app/logging/useLogger';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -24,6 +25,7 @@ import {
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -75,31 +77,19 @@ export const buildCanvasImageToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
id: LATENTS_TO_LATENTS, id: LATENTS_TO_LATENTS,
@ -120,7 +110,7 @@ export const buildCanvasImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -130,7 +120,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -140,7 +130,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -148,26 +138,6 @@ export const buildCanvasImageToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -200,7 +170,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -210,7 +180,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -241,26 +211,6 @@ export const buildCanvasImageToImageGraph = (
], ],
}; };
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit` // handle `fit`
if (initialImage.width !== width || initialImage.height !== height) { if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
@ -306,9 +256,9 @@ export const buildCanvasImageToImageGraph = (
}); });
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.image_name, image_name: initialImage.image_name,
}); };
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node
graph.edges.push({ graph.edges.push({
@ -327,7 +277,10 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
// add controlnet // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;

View File

@ -9,7 +9,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { import {
ITERATE, ITERATE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {

View File

@ -4,7 +4,7 @@ import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -15,6 +15,7 @@ import {
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -62,13 +63,6 @@ export const buildCanvasTextToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
@ -82,19 +76,15 @@ export const buildCanvasTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
}, },
edges: [ edges: [
{ {
@ -119,7 +109,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -129,7 +119,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -139,7 +129,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -159,7 +149,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -167,26 +157,6 @@ export const buildCanvasTextToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -200,27 +170,10 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
// handle seed // add dynamic prompts, mutating `graph`
if (shouldRandomizeSeed) { addDynamicPromptsToGraph(graph, state);
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode; // add controlnet, mutating `graph`
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph; return graph;

View File

@ -1,28 +1,24 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ImageResizeInvocation, ImageResizeInvocation,
RandomIntInvocation, ImageToLatentsInvocation,
RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS, IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
RESIZE, RESIZE,
} from './constants'; } from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -44,9 +40,6 @@ export const buildLinearImageToImageGraph = (
shouldFitToWidthHeight, shouldFitToWidthHeight,
width, width,
height, height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation; } = state.generation;
/** /**
@ -79,31 +72,19 @@ export const buildLinearImageToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
id: LATENTS_TO_LATENTS, id: LATENTS_TO_LATENTS,
@ -124,7 +105,7 @@ export const buildLinearImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -134,7 +115,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -144,7 +125,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -152,26 +133,6 @@ export const buildLinearImageToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -204,7 +165,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -214,7 +175,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -245,26 +206,6 @@ export const buildLinearImageToImageGraph = (
], ],
}; };
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit` // handle `fit`
if ( if (
shouldFitToWidthHeight && shouldFitToWidthHeight &&
@ -313,9 +254,9 @@ export const buildLinearImageToImageGraph = (
}); });
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName, image_name: initialImage.imageName,
}); };
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node
graph.edges.push({ graph.edges.push({
@ -334,7 +275,10 @@ export const buildLinearImageToImageGraph = (
}); });
} }
// add controlnet // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;

View File

@ -1,33 +1,20 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
type TextToImageGraphOverrides = {
width: number;
height: number;
};
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState, state: RootState
overrides?: TextToImageGraphOverrides
): NonNullableGraph => { ): NonNullableGraph => {
const { const {
positivePrompt, positivePrompt,
@ -38,9 +25,6 @@ export const buildLinearTextToImageGraph = (
steps, steps,
width, width,
height, height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToPipelineModelField(modelId);
@ -68,18 +52,11 @@ export const buildLinearTextToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
width: overrides?.width || width, width,
height: overrides?.height || height, height,
}, },
[TEXT_TO_LATENTS]: { [TEXT_TO_LATENTS]: {
type: 't2l', type: 't2l',
@ -88,19 +65,15 @@ export const buildLinearTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
}, },
edges: [ edges: [
{ {
@ -125,7 +98,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -135,7 +108,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -145,7 +118,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -165,7 +138,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -173,26 +146,6 @@ export const buildLinearTextToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -206,27 +159,10 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
// handle seed // add dynamic prompts, mutating `graph`
if (shouldRandomizeSeed) { addDynamicPromptsToGraph(graph, state);
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode; // add controlnet, mutating `graph`
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph; return graph;

View File

@ -7,12 +7,13 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MODEL_LOADER = 'pipeline_model_loader'; export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';
export const INPAINT = 'inpaint'; export const INPAINT = 'inpaint';
export const CONTROL_NET_COLLECT = 'control_net_collect'; export const CONTROL_NET_COLLECT = 'control_net_collect';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { CompelInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildCompelNode = (
prompt: string,
state: RootState,
overrides: O.Partial<CompelInvocation, 'deep'> = {}
): CompelInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const { model } = generation;
const compelNode: CompelInvocation = {
id: nodeId,
type: 'compel',
prompt,
model,
};
Object.assign(compelNode, overrides);
return compelNode;
};

View File

@ -1,107 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api/types';
import { O } from 'ts-toolbelt';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
export const buildImg2ImgNode = (
state: RootState,
overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
): ImageToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const activeTabName = activeTabNameSelector(state);
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
initialImage,
} = generation;
// const initialImage = initialImageSelector(state);
const imageToImageNode: ImageToImageInvocation = {
id: nodeId,
type: 'img2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
// on Canvas tab, we do not manually specific init image
if (activeTabName !== 'unifiedCanvas') {
if (!initialImage) {
// TODO: handle this more better
throw 'no initial image';
}
imageToImageNode.image = {
image_name: initialImage.imageName,
};
}
if (!shouldRandomizeSeed) {
imageToImageNode.seed = seed;
}
Object.assign(imageToImageNode, overrides);
return imageToImageNode;
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
fit: true,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@ -1,48 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { InpaintInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildInpaintNode = (
state: RootState,
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
): InpaintInvocation => {
const nodeId = uuidv4();
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = state.generation;
const inpaintNode: InpaintInvocation = {
id: nodeId,
type: 'inpaint',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
if (!shouldRandomizeSeed) {
inpaintNode.seed = seed;
}
Object.assign(inpaintNode, overrides);
return inpaintNode;
};

View File

@ -1,13 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { IterateInvocation } from 'services/api/types';
export const buildIterateNode = (): IterateInvocation => {
const nodeId = uuidv4();
return {
id: nodeId,
type: 'iterate',
// collection: [],
// index: 0,
};
};

View File

@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { RandomRangeInvocation, RangeInvocation } from 'services/api/types';
export const buildRangeNode = (
state: RootState
): RangeInvocation | RandomRangeInvocation => {
const nodeId = uuidv4();
const { shouldRandomizeSeed, iterations, seed } = state.generation;
if (shouldRandomizeSeed) {
return {
id: nodeId,
type: 'random_range',
size: iterations,
};
}
return {
id: nodeId,
type: 'range',
start: seed,
stop: seed + iterations,
};
};

View File

@ -1,45 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { TextToImageInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildTxt2ImgNode = (
state: RootState,
overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
): TextToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
scheduler,
shouldRandomizeSeed,
model,
} = generation;
const textToImageNode: NonNullable<TextToImageInvocation> = {
id: nodeId,
type: 'txt2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale,
scheduler,
model,
};
if (!shouldRandomizeSeed) {
textToImageNode.seed = seed;
}
Object.assign(textToImageNode, overrides);
return textToImageNode;
};

View File

@ -5,127 +5,154 @@ import {
InputFieldTemplate, InputFieldTemplate,
InvocationSchemaObject, InvocationSchemaObject,
InvocationTemplate, InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate, OutputFieldTemplate,
} from '../types/types'; } from '../types/types';
import { import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
buildInputFieldTemplate, import { O } from 'ts-toolbelt';
buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; // recursively exclude all properties of type U from T
type DeepExclude<T, U> = T extends U
? never
: T extends object
? {
[K in keyof T]: DeepExclude<T[K], U>;
}
: T;
// The schema from swagger-parser is dereferenced, and we know `components` and `components.schemas` exist
type DereferencedOpenAPIDocument = DeepExclude<
O.Required<OpenAPIV3.Document, 'schemas' | 'components', 'deep'>,
OpenAPIV3.ReferenceObject
>;
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
const invocationDenylist = ['Graph', 'InvocationMeta']; const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { const nodeFilter = (
schema: DereferencedOpenAPIDocument['components']['schemas'][string],
key: string
) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem));
export const parseSchema = (openAPI: DereferencedOpenAPIDocument) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now
const filteredSchemas = filter( const filteredSchemas = filter(openAPI.components.schemas, nodeFilter);
openAPI.components!.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem))
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const invocations = filteredSchemas.reduce< const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate> Record<string, InvocationTemplate>
>((acc, schema) => { >((acc, s) => {
// only want SchemaObjects // cast to InvocationSchemaObject, we know the shape
if (isInvocationSchemaObject(schema)) { const schema = s as InvocationSchemaObject;
const type = schema.properties.type.default;
const title = schema.ui?.title ?? schema.title.replace('Invocation', ''); const type = schema.properties.type.default;
const typeHints = schema.ui?.type_hints; const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
const inputs: Record<string, InputFieldTemplate> = {}; const typeHints = schema.ui?.type_hints;
if (type === 'collect') { const inputs: Record<string, InputFieldTemplate> = {};
const itemProperty = schema.properties[
'item'
] as InvocationSchemaObject;
// Handle the special Collect node
inputs.item = {
type: 'item',
name: 'item',
description: itemProperty.description ?? '',
title: 'Collection Item',
inputKind: 'connection',
inputRequirement: 'always',
default: undefined,
};
} else if (type === 'iterate') {
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
inputs.collection = { if (type === 'collect') {
type: 'array', // Special handling for the Collect node
name: 'collection', const itemProperty = schema.properties['item'] as InvocationSchemaObject;
title: itemProperty.title ?? '', inputs.item = {
default: [], type: 'item',
description: itemProperty.description ?? '', name: 'item',
inputRequirement: 'always', description: itemProperty.description ?? '',
inputKind: 'connection', title: 'Collection Item',
}; inputKind: 'connection',
} else { inputRequirement: 'always',
// All other nodes default: undefined,
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
}
const rawOutput = (schema as InvocationSchemaObject).output;
let outputs: Record<string, OutputFieldTemplate>;
// some special handling is needed for collect, iterate and range nodes
if (type === 'iterate') {
// this is guaranteed to be a SchemaObject
const iterationOutput = openAPI.components!.schemas![
'IterateInvocationOutput'
] as OpenAPIV3.SchemaObject;
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
}; };
} else if (type === 'iterate') {
// Special handling for the Iterate node
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
Object.assign(acc, { [type]: invocation }); inputs.collection = {
type: 'array',
name: 'collection',
title: itemProperty.title ?? '',
default: [],
description: itemProperty.description ?? '',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
} }
let outputs: Record<string, OutputFieldTemplate>;
if (type === 'iterate') {
// Special handling for the Iterate node output
const iterationOutput =
openAPI.components.schemas['IterateInvocationOutput'];
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
// All other node outputs
outputs = reduce(
schema.output.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (!['type', 'id'].includes(propertyName)) {
const fieldType = getFieldType(property, propertyName, typeHints);
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
}
return outputsAccumulator;
},
{} as Record<string, OutputFieldTemplate>
);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
Object.assign(acc, { [type]: invocation });
return acc; return acc;
}, {}); }, {});

View File

@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
@ -10,27 +11,26 @@ import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector([stateSelector], (state) => {
[generationSelector, configSelector, uiSelector, hotkeysSelector], const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
(generation, config, ui, hotkeys) => { state.config.sd.iterations;
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = const { iterations } = state.generation;
config.sd.iterations; const { shouldUseSliders } = state.ui;
const { iterations } = generation; const isDisabled = state.dynamicPrompts.isEnabled;
const { shouldUseSliders } = ui;
const step = hotkeys.shift ? fineStep : coarseStep; const step = state.hotkeys.shift ? fineStep : coarseStep;
return { return {
iterations, iterations,
initial, initial,
min, min,
sliderMax, sliderMax,
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
}; isDisabled,
} };
); });
const ParamIterations = () => { const ParamIterations = () => {
const { const {
@ -41,6 +41,7 @@ const ParamIterations = () => {
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
isDisabled,
} = useAppSelector(selector); } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -58,6 +59,7 @@ const ParamIterations = () => {
return shouldUseSliders ? ( return shouldUseSliders ? (
<IAISlider <IAISlider
isDisabled={isDisabled}
label={t('parameters.images')} label={t('parameters.images')}
step={step} step={step}
min={min} min={min}
@ -72,6 +74,7 @@ const ParamIterations = () => {
/> />
) : ( ) : (
<IAINumberInput <IAINumberInput
isDisabled={isDisabled}
label={t('parameters.images')} label={t('parameters.images')}
step={step} step={step}
min={min} min={min}

View File

@ -60,6 +60,14 @@ export const initialConfigState: AppConfig = {
fineStep: 0.01, fineStep: 0.01,
coarseStep: 0.05, coarseStep: 0.05,
}, },
dynamicPrompts: {
maxPrompts: {
initial: 100,
min: 1,
sliderMax: 1000,
inputMax: 10000,
},
},
}, },
}; };

View File

@ -4,7 +4,6 @@ import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { import {
@ -26,6 +25,7 @@ import {
} from 'services/api/thunks/session'; } from 'services/api/thunks/session';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker'; import { LANGUAGES } from '../components/LanguagePicker';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -382,7 +382,7 @@ export const systemSlice = createSlice({
/** /**
* OpenAPI schema was parsed * OpenAPI schema was parsed
*/ */
builder.addCase(parsedOpenAPISchema, (state) => { builder.addCase(nodeTemplatesBuilt, (state) => {
state.wasSchemaParsed = true; state.wasSchemaParsed = true;
}); });

View File

@ -8,6 +8,7 @@ import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Sym
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const ImageToImageTabParameters = () => { const ImageToImageTabParameters = () => {
return ( return (
@ -16,6 +17,7 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />

View File

@ -9,6 +9,7 @@ import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const TextToImageTabParameters = () => { const TextToImageTabParameters = () => {
return ( return (
@ -17,6 +18,7 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />

View File

@ -8,6 +8,7 @@ import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const UnifiedCanvasParameters = () => { const UnifiedCanvasParameters = () => {
return ( return (
@ -16,6 +17,7 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -15,6 +15,7 @@ export const imagesApi = api.injectEndpoints({
} }
return tags; return tags;
}, },
keepUnusedDataFor: 86400, // 24 hours
}), }),
}), }),
}); });

View File

@ -2917,7 +2917,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -4177,18 +4177,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -1,20 +1,45 @@
import SwaggerParser from '@apidevtools/swagger-parser';
import { createAsyncThunk } from '@reduxjs/toolkit'; import { createAsyncThunk } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
const schemaLog = log.child({ namespace: 'schema' }); const schemaLog = log.child({ namespace: 'schema' });
function getCircularReplacer() {
const ancestors: Record<string, any>[] = [];
return function (key: string, value: any) {
if (typeof value !== 'object' || value === null) {
return value;
}
// `this` is the object that value is contained in,
// i.e., its direct parent.
// @ts-ignore
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
ancestors.pop();
}
if (ancestors.includes(value)) {
return '[Circular]';
}
ancestors.push(value);
return value;
};
}
export const receivedOpenAPISchema = createAsyncThunk( export const receivedOpenAPISchema = createAsyncThunk(
'nodes/receivedOpenAPISchema', 'nodes/receivedOpenAPISchema',
async (_, { dispatch }): Promise<OpenAPIV3.Document> => { async (_, { dispatch, rejectWithValue }) => {
const response = await fetch(`openapi.json`); try {
const openAPISchema = await response.json(); const dereferencedSchema = (await SwaggerParser.dereference(
'openapi.json'
)) as OpenAPIV3.Document;
schemaLog.info({ openAPISchema }, 'Received OpenAPI schema'); const schemaJSON = JSON.parse(
JSON.stringify(dereferencedSchema, getCircularReplacer())
);
dispatch(parsedOpenAPISchema(openAPISchema as OpenAPIV3.Document)); return schemaJSON;
} catch (error) {
return openAPISchema; return rejectWithValue({ error });
}
} }
); );

View File

@ -1,7 +1,14 @@
import { O } from 'ts-toolbelt';
import { components } from './schema'; import { components } from './schema';
type schemas = components['schemas']; type schemas = components['schemas'];
/**
* Helper type to extract the invocation type from the schema.
* Also flags the `type` property as required.
*/
type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
/** /**
* Types from the API, re-exported from the types generated by `openapi-typescript`. * Types from the API, re-exported from the types generated by `openapi-typescript`.
*/ */
@ -31,42 +38,51 @@ export type Edge = schemas['Edge'];
export type GraphExecutionState = schemas['GraphExecutionState']; export type GraphExecutionState = schemas['GraphExecutionState'];
// General nodes // General nodes
export type CollectInvocation = schemas['CollectInvocation']; export type CollectInvocation = Invocation<'CollectInvocation'>;
export type IterateInvocation = schemas['IterateInvocation']; export type IterateInvocation = Invocation<'IterateInvocation'>;
export type RangeInvocation = schemas['RangeInvocation']; export type RangeInvocation = Invocation<'RangeInvocation'>;
export type RandomRangeInvocation = schemas['RandomRangeInvocation']; export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
export type RangeOfSizeInvocation = schemas['RangeOfSizeInvocation']; export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
export type InpaintInvocation = schemas['InpaintInvocation']; export type InpaintInvocation = Invocation<'InpaintInvocation'>;
export type ImageResizeInvocation = schemas['ImageResizeInvocation']; export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
export type RandomIntInvocation = schemas['RandomIntInvocation']; export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
export type CompelInvocation = schemas['CompelInvocation']; export type CompelInvocation = Invocation<'CompelInvocation'>;
export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>;
export type NoiseInvocation = Invocation<'NoiseInvocation'>;
export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation =
Invocation<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation =
Invocation<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = schemas['ControlNetInvocation']; export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
export type CannyImageProcessorInvocation = export type CannyImageProcessorInvocation =
schemas['CannyImageProcessorInvocation']; Invocation<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation = export type ContentShuffleImageProcessorInvocation =
schemas['ContentShuffleImageProcessorInvocation']; Invocation<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation = export type HedImageProcessorInvocation =
schemas['HedImageProcessorInvocation']; Invocation<'HedImageProcessorInvocation'>;
export type LineartAnimeImageProcessorInvocation = export type LineartAnimeImageProcessorInvocation =
schemas['LineartAnimeImageProcessorInvocation']; Invocation<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation = export type LineartImageProcessorInvocation =
schemas['LineartImageProcessorInvocation']; Invocation<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation = export type MediapipeFaceProcessorInvocation =
schemas['MediapipeFaceProcessorInvocation']; Invocation<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation = export type MidasDepthImageProcessorInvocation =
schemas['MidasDepthImageProcessorInvocation']; Invocation<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation = export type MlsdImageProcessorInvocation =
schemas['MlsdImageProcessorInvocation']; Invocation<'MlsdImageProcessorInvocation'>;
export type NormalbaeImageProcessorInvocation = export type NormalbaeImageProcessorInvocation =
schemas['NormalbaeImageProcessorInvocation']; Invocation<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation = export type OpenposeImageProcessorInvocation =
schemas['OpenposeImageProcessorInvocation']; Invocation<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation = export type PidiImageProcessorInvocation =
schemas['PidiImageProcessorInvocation']; Invocation<'PidiImageProcessorInvocation'>;
export type ZoeDepthImageProcessorInvocation = export type ZoeDepthImageProcessorInvocation =
schemas['ZoeDepthImageProcessorInvocation']; Invocation<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs // Node Outputs
export type ImageOutput = schemas['ImageOutput']; export type ImageOutput = schemas['ImageOutput'];

View File

@ -9,6 +9,7 @@
"vite.config.ts", "vite.config.ts",
"./config/vite.app.config.ts", "./config/vite.app.config.ts",
"./config/vite.package.config.ts", "./config/vite.package.config.ts",
"./config/vite.common.config.ts" "./config/vite.common.config.ts",
"./config/common.ts"
] ]
} }

File diff suppressed because it is too large Load Diff