Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/controlnet_extras

This commit is contained in:
user1 2023-06-26 16:39:31 -07:00
commit 862bfa2c36
64 changed files with 1628 additions and 983 deletions

View File

@ -1,10 +1,16 @@
name: Test invoke.py pip
# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable
on:
pull_request:
paths:
- '**'
- '!pyproject.toml'
- '!invokeai/**'
- '!tests/**'
- 'invokeai/frontend/web/**'
merge_group:
workflow_dispatch:
@ -19,48 +25,26 @@ jobs:
strategy:
matrix:
python-version:
# - '3.9'
- '3.10'
pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
steps:
- run: 'echo "No build required"'
- name: skip
run: echo "no build required"

View File

@ -11,6 +11,7 @@ on:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- 'tests/**'
- '!invokeai/frontend/web/**'
types:
- 'ready_for_review'
@ -32,19 +33,12 @@ jobs:
# - '3.9'
- '3.10'
pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
@ -62,14 +56,6 @@ jobs:
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
env:
@ -100,40 +86,38 @@ jobs:
id: run-pytest
run: pytest
- name: run invokeai-configure
id: run-preload-models
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
run: >
invokeai-configure
--yes
--default_only
--full-precision
# can't use fp16 weights without a GPU
# - name: run invokeai-configure
# env:
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
# run: >
# invokeai-configure
# --yes
# --default_only
# --full-precision
# # can't use fp16 weights without a GPU
- name: run invokeai
id: run-invokeai
env:
# Set offline mode to make sure configure preloaded successfully.
HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: >
invokeai
--no-patchmatch
--no-nsfw_checker
--precision=float32
--always_use_cpu
--use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}
# - name: run invokeai
# id: run-invokeai
# env:
# # Set offline mode to make sure configure preloaded successfully.
# HF_HUB_OFFLINE: 1
# HF_DATASETS_OFFLINE: 1
# TRANSFORMERS_OFFLINE: 1
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# run: >
# invokeai
# --no-patchmatch
# --no-nsfw_checker
# --precision=float32
# --always_use_cpu
# --use_memory_db
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
# --from_file ${{ env.TEST_PROMPTS }}
- name: Archive results
id: archive-results
env:
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
uses: actions/upload-artifact@v3
with:
name: results
path: ${{ env.INVOKEAI_OUTDIR }}
# - name: Archive results
# env:
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# uses: actions/upload-artifact@v3
# with:
# name: results
# path: ${{ env.INVOKEAI_OUTDIR }}

View File

@ -38,6 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
echo.
echo See %INSTRUCTIONS% for more details.
echo.
echo "For the best user experience we suggest enlarging or maximizing this window now."
pause
@rem ---------------------------- check Python version ---------------

View File

@ -26,6 +26,7 @@ done
if [ -z "$PYTHON" ]; then
echo "A suitable Python interpreter could not be found"
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
echo "For the best user experience we suggest enlarging or maximizing this window now."
read -p "Press any key to exit"
exit -1
fi

View File

@ -293,6 +293,8 @@ def introduction() -> None:
"3. Create initial configuration files.",
"",
"[i]At any point you may interrupt this program and resume later.",
"",
"[b]For the best user experience, please enlarge or maximize this window",
),
)
)

View File

@ -279,8 +279,8 @@ def _convert_ckpt_and_cache(
raise Exception(f"Model variant {model_config.variant} not supported for {version}")
weights = app_config.root_dir / model_config.path
config_file = app_config.root_dir / model_config.config
weights = app_config.root_path / model_config.path
config_file = app_config.root_path / model_config.config
output_path = Path(output_path)
if version == BaseModelType.StableDiffusion1:

View File

@ -965,13 +965,15 @@ def main():
logger.error(
"Insufficient vertical space for the interface. Please make your window taller and try again"
)
elif str(e).startswith("addwstr"):
input('Press any key to continue...')
except Exception as e:
if str(e).startswith("addwstr"):
logger.error(
"Insufficient horizontal space for the interface. Please make your window wider and try again."
)
except Exception as e:
print(f'An exception has occurred: {str(e)} Details:')
print(traceback.format_exc(), file=sys.stderr)
else:
print(f'An exception has occurred: {str(e)} Details:')
print(traceback.format_exc(), file=sys.stderr)
input('Press any key to continue...')

View File

@ -42,6 +42,18 @@ def set_terminal_size(columns: int, lines: int, launch_command: str=None):
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height)
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print('\033[1mThis window is too narrow for the user interface. Please make it wider.\033[0m')
pause = True
if ts.lines < lines:
print('\033[1mThis window is too short for the user interface. Please make it taller.\033[0m')
pause = True
if pause:
input('Press any key to continue..')
def _set_terminal_size_powershell(width: int, height: int):
script=f'''
$pshost = get-host

View File

@ -0,0 +1,14 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { nodePolyfills } from 'vite-plugin-node-polyfills';
export const commonPlugins: UserConfig['plugins'] = [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
nodePolyfills(),
];

View File

@ -1,17 +1,9 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { UserConfig } from 'vite';
import { commonPlugins } from './common';
export const appConfig: UserConfig = {
base: './',
plugins: [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
],
plugins: [...commonPlugins],
build: {
chunkSizeWarningLimit: 1500,
},

View File

@ -1,19 +1,13 @@
import react from '@vitejs/plugin-react-swc';
import path from 'path';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import { UserConfig } from 'vite';
import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
import { commonPlugins } from './common';
export const packageConfig: UserConfig = {
base: './',
plugins: [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
...commonPlugins,
dts({
insertTypesEntry: true,
}),

View File

@ -53,6 +53,7 @@
]
},
"dependencies": {
"@apidevtools/swagger-parser": "^10.1.0",
"@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.7.1",
@ -154,6 +155,7 @@
"vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1",
"vite-plugin-node-polyfills": "^0.9.0",
"vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19"
}

View File

@ -15,7 +15,7 @@ import { ImageDTO } from 'services/api/types';
import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es';
@ -30,7 +30,7 @@ export const selectImageUsage = createSelector(
[
generationSelector,
canvasSelector,
nodesSelecter,
nodesSelector,
controlNetSelector,
(state: RootState, image_name?: string) => image_name,
],

View File

@ -1,6 +1,7 @@
import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { Graph } from 'services/api/types';
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
@ -8,17 +9,6 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return {
...action,
payload: { ...action.payload, nodes: sanitizedNodes },
@ -26,5 +16,19 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
}
}
if (receivedOpenAPISchema.fulfilled.match(action)) {
return {
...action,
payload: '<OpenAPI schema omitted>',
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
return action;
};

View File

@ -82,6 +82,7 @@ import {
addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
export const listenerMiddleware = createListenerMiddleware();
@ -205,3 +206,6 @@ addImageAddedToBoardRejectedListener();
addImageRemovedFromBoardFulfilledListener();
addImageRemovedFromBoardRejectedListener();
addBoardIdSelectedListener();
// Node schemas
addReceivedOpenAPISchemaListener();

View File

@ -0,0 +1,35 @@
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { parseSchema } from 'features/nodes/util/parseSchema';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { size } from 'lodash-es';
const schemaLog = log.child({ namespace: 'schema' });
export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch, getState }) => {
const schemaJSON = action.payload;
schemaLog.info({ data: { schemaJSON } }, 'Dereferenced OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
schemaLog.info(
{ data: { nodeTemplates } },
`Built ${size(nodeTemplates)} node templates`
);
dispatch(nodeTemplatesBuilt(nodeTemplates));
},
});
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: (action, { dispatch, getState }) => {
schemaLog.error('Problem dereferencing OpenAPI Schema');
},
});
};

View File

@ -3,7 +3,7 @@ import { startAppListening } from '..';
import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/api/thunks/image';
@ -16,7 +16,7 @@ const selectAllUsedImages = createSelector(
[
generationSelector,
canvasSelector,
nodesSelecter,
nodesSelector,
controlNetSelector,
selectImagesEntities,
],

View File

@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import { listenerMiddleware } from './middleware/listenerMiddleware';
@ -48,6 +49,7 @@ const allReducers = {
controlNet: controlNetReducer,
boards: boardsReducer,
// session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer,
[api.reducerPath]: api.reducer,
};
@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system',
'ui',
'controlNet',
'dynamicPrompts',
// 'boards',
// 'hotkeys',
// 'config',
@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;

View File

@ -171,6 +171,14 @@ export type AppConfig = {
fineStep: number;
coarseStep: number;
};
dynamicPrompts: {
maxPrompts: {
initial: number;
min: number;
sliderMax: number;
inputMax: number;
};
};
};
};

View File

@ -27,7 +27,6 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },

View File

@ -34,6 +34,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
'&:disabled': {
backgroundColor: 'var(--invokeai-colors-base-700)',
color: 'var(--invokeai-colors-base-400)',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
@ -64,7 +68,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
},
},
rightSection: {
width: 24,
width: 32,
},
})}
{...rest}

View File

@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
{...formControlProps}
>
{label && (
<FormLabel my={1} flexGrow={1} {...formLabelProps}>
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}

View File

@ -9,10 +9,12 @@ type IAICanvasImageProps = {
};
const IAICanvasImage = (props: IAICanvasImageProps) => {
const { width, height, x, y, imageName } = props.canvasImage;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
imageName ?? skipToken
);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) {
if (isError) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
}

View File

@ -174,7 +174,10 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex>
)}
</Flex>

View File

@ -0,0 +1,45 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial />
</Flex>
</IAICollapse>
);
};
export default ParamDynamicPromptsCollapse;

View File

@ -0,0 +1,36 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import IAISwitch from 'common/components/IAISwitch';
const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;
return { combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(combinatorialToggled());
}, [dispatch]);
return (
<IAISwitch
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}
/>
);
};
export default ParamDynamicPromptsCombinatorial;

View File

@ -0,0 +1,53 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax };
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => {
dispatch(maxPromptsChanged(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(maxPromptsReset());
}, [dispatch]);
return (
<IAISlider
label="Max Prompts"
min={min}
max={sliderMax}
value={maxPrompts}
onChange={handleChange}
sliderNumberInputProps={{ max: inputMax }}
withSliderMarks
withInput
inputReadOnly
withReset
handleReset={handleReset}
/>
);
};
export default ParamDynamicPromptsMaxPrompts;

View File

@ -0,0 +1,50 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
export interface DynamicPromptsState {
isEnabled: boolean;
maxPrompts: number;
combinatorial: boolean;
}
export const initialDynamicPromptsState: DynamicPromptsState = {
isEnabled: false,
maxPrompts: 100,
combinatorial: true,
};
const initialState: DynamicPromptsState = initialDynamicPromptsState;
export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
},
maxPromptsReset: (state) => {
state.maxPrompts = initialDynamicPromptsState.maxPrompts;
},
combinatorialToggled: (state) => {
state.combinatorial = !state.combinatorial;
},
isEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
},
extraReducers: (builder) => {
//
},
});
export const {
isEnabledToggled,
maxPromptsChanged,
maxPromptsReset,
combinatorialToggled,
} = dynamicPromptsSlice.actions;
export default dynamicPromptsSlice.reducer;
export const dynamicPromptsSelector = (state: RootState) =>
state.dynamicPrompts;

View File

@ -1,28 +1,41 @@
import 'reactflow/dist/style.css';
import { memo, useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
} from '@chakra-ui/react';
import { FaEllipsisV } from 'react-icons/fa';
import { useCallback, forwardRef } from 'react';
import { Flex, Text } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { nodeAdded, nodesSelector } from '../store/nodesSlice';
import { map } from 'lodash-es';
import { RootState } from 'app/store/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { AnyInvocationType } from 'services/events/types';
import IAIIconButton from 'common/components/IAIIconButton';
import { useAppToaster } from 'app/components/Toaster';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
type NodeTemplate = {
label: string;
value: string;
description: string;
};
const selector = createSelector(
nodesSelector,
(nodes) => {
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
};
});
return { data };
},
defaultSelectorOptions
);
const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
const { data } = useAppSelector(selector);
const buildInvocation = useBuildInvocation();
@ -46,23 +59,52 @@ const AddNodeMenu = () => {
);
return (
<Menu isLazy>
<MenuButton
as={IAIIconButton}
aria-label="Add Node"
icon={<FaEllipsisV />}
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAIMantineMultiSelect
selectOnBlur={false}
placeholder="Add Node"
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching nodes"
itemComponent={SelectItem}
filter={(value, selected, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.description.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={(v) => {
v[0] && addNode(v[0] as AnyInvocationType);
}}
sx={{
width: '18rem',
}}
/>
<MenuList overflowY="scroll" height={400}>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
</Flex>
);
};
export default memo(AddNodeMenu);
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
<Text size="xs" color="base.600">
{description}
</Text>
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default AddNodeMenu;

View File

@ -1,10 +1,10 @@
import { memo } from 'react';
import { Panel } from 'reactflow';
import NodeSearch from '../search/NodeSearch';
import AddNodeMenu from '../AddNodeMenu';
const TopLeftPanel = () => (
<Panel position="top-left">
<NodeSearch />
<AddNodeMenu />
</Panel>
);

View File

@ -14,9 +14,6 @@ import {
import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store';
@ -78,25 +75,17 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload;
},
parsedOpenAPISchema: (state, action: PayloadAction<OpenAPIV3.Document>) => {
try {
const parsedSchema = parseSchema(action.payload);
// TODO: Achtung! Side effect in a reducer!
log.info(
{ namespace: 'schema', nodes: parsedSchema },
`Parsed ${size(parsedSchema)} nodes`
);
state.invocationTemplates = parsedSchema;
} catch (err) {
console.error(err);
}
nodeTemplatesBuilt: (
state,
action: PayloadAction<Record<string, InvocationTemplate>>
) => {
state.invocationTemplates = action.payload;
},
nodeEditorReset: () => {
return { ...initialNodesState };
},
},
extraReducers(builder) {
extraReducers: (builder) => {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload;
});
@ -112,10 +101,10 @@ export const {
connectionStarted,
connectionEnded,
shouldShowGraphOverlayChanged,
parsedOpenAPISchema,
nodeTemplatesBuilt,
nodeEditorReset,
} = nodesSlice.actions;
export default nodesSlice.reducer;
export const nodesSelecter = (state: RootState) => state.nodes;
export const nodesSelector = (state: RootState) => state.nodes;

View File

@ -34,12 +34,10 @@ export type InvocationTemplate = {
* Array of invocation inputs
*/
inputs: Record<string, InputFieldTemplate>;
// inputs: InputField[];
/**
* Array of the invocation outputs
*/
outputs: Record<string, OutputFieldTemplate>;
// outputs: OutputField[];
};
export type FieldUIConfig = {
@ -335,7 +333,7 @@ export type TypeHints = {
};
export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation
output: OpenAPIV3.SchemaObject; // the output of the invocation
ui?: {
tags?: string[];
type_hints?: TypeHints;

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { filter, forEach, size } from 'lodash-es';
import { filter } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
(c.processorType === 'none' && Boolean(c.controlImage)))
);
// Add ControlNet
if (isControlNetEnabled && validControlNets.length > 0) {
if (size(controlNets) > 1) {
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
@ -36,10 +36,9 @@ export const addControlNetToLinearGraph = (
});
}
forEach(controlNets, (controlNet) => {
validControlNets.forEach((controlNet) => {
const {
controlNetId,
isEnabled,
controlImage,
processedControlImage,
beginStepPct,
@ -50,11 +49,6 @@ export const addControlNetToLinearGraph = (
weight,
} = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
@ -82,7 +76,8 @@ export const addControlNetToLinearGraph = (
graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) {
if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
@ -91,6 +86,7 @@ export const addControlNetToLinearGraph = (
},
});
} else {
// otherwise, link directly to the base node
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {

View File

@ -349,21 +349,11 @@ export const getFieldType = (
if (typeHints && name in typeHints) {
rawFieldType = typeHints[name];
} else if (!schemaObject.type) {
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) {
rawFieldType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.anyOf) {
rawFieldType = refObjectToFieldType(
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
rawFieldType = refObjectToFieldType(
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
}
} else if (!schemaObject.type && schemaObject.allOf) {
// if schemaObject has no type, then it should have one of allOf
rawFieldType =
(schemaObject.allOf[0] as OpenAPIV3.SchemaObject).title ??
'Missing Field Type';
} else if (schemaObject.enum) {
rawFieldType = 'enum';
} else if (schemaObject.type) {

View File

@ -0,0 +1,153 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DynamicPromptInvocation,
IterateInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
DYNAMIC_PROMPT,
ITERATE,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
import { unset } from 'lodash-es';
export const addDynamicPromptsToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
state.generation;
const {
combinatorial,
isEnabled: isDynamicPromptsEnabled,
maxPrompts,
} = state.dynamicPrompts;
if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
const dynamicPromptNode: DynamicPromptInvocation = {
id: DYNAMIC_PROMPT,
type: 'dynamic_prompt',
max_prompts: maxPrompts,
combinatorial,
prompt: positivePrompt,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[DYNAMIC_PROMPT] = dynamicPromptNode;
graph.nodes[ITERATE] = iterateNode;
// connect dynamic prompts to compel nodes
graph.edges.push(
{
source: {
node_id: DYNAMIC_PROMPT,
field: 'prompt_collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'prompt',
},
}
);
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: NOISE, field: 'seed' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[NOISE] as NoiseInvocation).seed = seed;
}
} else {
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
size: iterations,
step: 1,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[ITERATE] = iterateNode;
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graph.edges.push({
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
});
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
rangeOfSizeNode.start = seed;
}
}
};

View File

@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
import {
ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
@ -10,7 +11,7 @@ import { log } from 'app/logging/useLogger';
import {
ITERATE,
LATENTS_TO_IMAGE,
MODEL_LOADER,
PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
@ -24,6 +25,7 @@ import {
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@ -75,31 +77,19 @@ export const buildCanvasImageToImageGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: {
type: 'noise',
id: NOISE,
},
[MODEL_LOADER]: {
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: MODEL_LOADER,
id: PIPELINE_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
},
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
id: LATENTS_TO_LATENTS,
@ -120,7 +110,7 @@ export const buildCanvasImageToImageGraph = (
edges: [
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -130,7 +120,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -140,7 +130,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -148,26 +138,6 @@ export const buildCanvasImageToImageGraph = (
field: 'vae',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{
source: {
node_id: LATENTS_TO_LATENTS,
@ -200,7 +170,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -210,7 +180,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -241,26 +211,6 @@ export const buildCanvasImageToImageGraph = (
],
};
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit`
if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
@ -306,9 +256,9 @@ export const buildCanvasImageToImageGraph = (
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.image_name,
});
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
@ -327,7 +277,10 @@ export const buildCanvasImageToImageGraph = (
});
}
// add controlnet
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph;

View File

@ -9,7 +9,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import {
ITERATE,
MODEL_LOADER,
PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
RANDOM_INT,
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[MODEL_LOADER]: {
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: MODEL_LOADER,
id: PIPELINE_MODEL_LOADER,
model,
},
[RANGE_OF_SIZE]: {
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {

View File

@ -4,7 +4,7 @@ import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import {
ITERATE,
LATENTS_TO_IMAGE,
MODEL_LOADER,
PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
@ -15,6 +15,7 @@ import {
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/**
* Builds the Canvas tab's Text to Image graph.
@ -62,13 +63,6 @@ export const buildCanvasTextToImageGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: {
type: 'noise',
id: NOISE,
@ -82,19 +76,15 @@ export const buildCanvasTextToImageGraph = (
scheduler,
steps,
},
[MODEL_LOADER]: {
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: MODEL_LOADER,
id: PIPELINE_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
},
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
},
edges: [
{
@ -119,7 +109,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -129,7 +119,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -139,7 +129,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -159,7 +149,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -167,26 +157,6 @@ export const buildCanvasTextToImageGraph = (
field: 'vae',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{
source: {
node_id: NOISE,
@ -200,27 +170,10 @@ export const buildCanvasTextToImageGraph = (
],
};
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph;

View File

@ -1,28 +1,24 @@
import { RootState } from 'app/store/store';
import {
ImageResizeInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
ImageToLatentsInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import {
ITERATE,
LATENTS_TO_IMAGE,
MODEL_LOADER,
PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE,
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@ -44,9 +40,6 @@ export const buildLinearImageToImageGraph = (
shouldFitToWidthHeight,
width,
height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation;
/**
@ -79,31 +72,19 @@ export const buildLinearImageToImageGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: {
type: 'noise',
id: NOISE,
},
[MODEL_LOADER]: {
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: MODEL_LOADER,
id: PIPELINE_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
},
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
id: LATENTS_TO_LATENTS,
@ -124,7 +105,7 @@ export const buildLinearImageToImageGraph = (
edges: [
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -134,7 +115,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -144,7 +125,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -152,26 +133,6 @@ export const buildLinearImageToImageGraph = (
field: 'vae',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{
source: {
node_id: LATENTS_TO_LATENTS,
@ -204,7 +165,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -214,7 +175,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -245,26 +206,6 @@ export const buildLinearImageToImageGraph = (
],
};
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit`
if (
shouldFitToWidthHeight &&
@ -313,9 +254,9 @@ export const buildLinearImageToImageGraph = (
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName,
});
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
@ -334,7 +275,10 @@ export const buildLinearImageToImageGraph = (
});
}
// add controlnet
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph;

View File

@ -1,33 +1,20 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
ITERATE,
LATENTS_TO_IMAGE,
MODEL_LOADER,
PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
type TextToImageGraphOverrides = {
width: number;
height: number;
};
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
export const buildLinearTextToImageGraph = (
state: RootState,
overrides?: TextToImageGraphOverrides
state: RootState
): NonNullableGraph => {
const {
positivePrompt,
@ -38,9 +25,6 @@ export const buildLinearTextToImageGraph = (
steps,
width,
height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation;
const model = modelIdToPipelineModelField(modelId);
@ -68,18 +52,11 @@ export const buildLinearTextToImageGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: {
type: 'noise',
id: NOISE,
width: overrides?.width || width,
height: overrides?.height || height,
width,
height,
},
[TEXT_TO_LATENTS]: {
type: 't2l',
@ -88,19 +65,15 @@ export const buildLinearTextToImageGraph = (
scheduler,
steps,
},
[MODEL_LOADER]: {
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: MODEL_LOADER,
id: PIPELINE_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
},
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
},
edges: [
{
@ -125,7 +98,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -135,7 +108,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -145,7 +118,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -165,7 +138,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: MODEL_LOADER,
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -173,26 +146,6 @@ export const buildLinearTextToImageGraph = (
field: 'vae',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{
source: {
node_id: NOISE,
@ -206,27 +159,10 @@ export const buildLinearTextToImageGraph = (
],
};
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph;

View File

@ -7,12 +7,13 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MODEL_LOADER = 'pipeline_model_loader';
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
export const INPAINT = 'inpaint';
export const CONTROL_NET_COLLECT = 'control_net_collect';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { CompelInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildCompelNode = (
prompt: string,
state: RootState,
overrides: O.Partial<CompelInvocation, 'deep'> = {}
): CompelInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const { model } = generation;
const compelNode: CompelInvocation = {
id: nodeId,
type: 'compel',
prompt,
model,
};
Object.assign(compelNode, overrides);
return compelNode;
};

View File

@ -1,107 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api/types';
import { O } from 'ts-toolbelt';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
export const buildImg2ImgNode = (
state: RootState,
overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
): ImageToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const activeTabName = activeTabNameSelector(state);
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
initialImage,
} = generation;
// const initialImage = initialImageSelector(state);
const imageToImageNode: ImageToImageInvocation = {
id: nodeId,
type: 'img2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
// on Canvas tab, we do not manually specific init image
if (activeTabName !== 'unifiedCanvas') {
if (!initialImage) {
// TODO: handle this more better
throw 'no initial image';
}
imageToImageNode.image = {
image_name: initialImage.imageName,
};
}
if (!shouldRandomizeSeed) {
imageToImageNode.seed = seed;
}
Object.assign(imageToImageNode, overrides);
return imageToImageNode;
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
fit: true,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@ -1,48 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { InpaintInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildInpaintNode = (
state: RootState,
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
): InpaintInvocation => {
const nodeId = uuidv4();
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = state.generation;
const inpaintNode: InpaintInvocation = {
id: nodeId,
type: 'inpaint',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
if (!shouldRandomizeSeed) {
inpaintNode.seed = seed;
}
Object.assign(inpaintNode, overrides);
return inpaintNode;
};

View File

@ -1,13 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { IterateInvocation } from 'services/api/types';
export const buildIterateNode = (): IterateInvocation => {
const nodeId = uuidv4();
return {
id: nodeId,
type: 'iterate',
// collection: [],
// index: 0,
};
};

View File

@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { RandomRangeInvocation, RangeInvocation } from 'services/api/types';
export const buildRangeNode = (
state: RootState
): RangeInvocation | RandomRangeInvocation => {
const nodeId = uuidv4();
const { shouldRandomizeSeed, iterations, seed } = state.generation;
if (shouldRandomizeSeed) {
return {
id: nodeId,
type: 'random_range',
size: iterations,
};
}
return {
id: nodeId,
type: 'range',
start: seed,
stop: seed + iterations,
};
};

View File

@ -1,45 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { TextToImageInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildTxt2ImgNode = (
state: RootState,
overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
): TextToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
scheduler,
shouldRandomizeSeed,
model,
} = generation;
const textToImageNode: NonNullable<TextToImageInvocation> = {
id: nodeId,
type: 'txt2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale,
scheduler,
model,
};
if (!shouldRandomizeSeed) {
textToImageNode.seed = seed;
}
Object.assign(textToImageNode, overrides);
return textToImageNode;
};

View File

@ -5,127 +5,154 @@ import {
InputFieldTemplate,
InvocationSchemaObject,
InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate,
} from '../types/types';
import {
buildInputFieldTemplate,
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
import { O } from 'ts-toolbelt';
// recursively exclude all properties of type U from T
type DeepExclude<T, U> = T extends U
? never
: T extends object
? {
[K in keyof T]: DeepExclude<T[K], U>;
}
: T;
// The schema from swagger-parser is dereferenced, and we know `components` and `components.schemas` exist
type DereferencedOpenAPIDocument = DeepExclude<
O.Required<OpenAPIV3.Document, 'schemas' | 'components', 'deep'>,
OpenAPIV3.ReferenceObject
>;
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
const nodeFilter = (
schema: DereferencedOpenAPIDocument['components']['schemas'][string],
key: string
) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem));
export const parseSchema = (openAPI: DereferencedOpenAPIDocument) => {
// filter out non-invocation schemas, plus some tricky invocations for now
const filteredSchemas = filter(
openAPI.components!.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem))
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const filteredSchemas = filter(openAPI.components.schemas, nodeFilter);
const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate>
>((acc, schema) => {
// only want SchemaObjects
if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default;
>((acc, s) => {
// cast to InvocationSchemaObject, we know the shape
const schema = s as InvocationSchemaObject;
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
const type = schema.properties.type.default;
const typeHints = schema.ui?.type_hints;
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
const inputs: Record<string, InputFieldTemplate> = {};
const typeHints = schema.ui?.type_hints;
if (type === 'collect') {
const itemProperty = schema.properties[
'item'
] as InvocationSchemaObject;
// Handle the special Collect node
inputs.item = {
type: 'item',
name: 'item',
description: itemProperty.description ?? '',
title: 'Collection Item',
inputKind: 'connection',
inputRequirement: 'always',
default: undefined,
};
} else if (type === 'iterate') {
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
const inputs: Record<string, InputFieldTemplate> = {};
inputs.collection = {
type: 'array',
name: 'collection',
title: itemProperty.title ?? '',
default: [],
description: itemProperty.description ?? '',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
}
const rawOutput = (schema as InvocationSchemaObject).output;
let outputs: Record<string, OutputFieldTemplate>;
// some special handling is needed for collect, iterate and range nodes
if (type === 'iterate') {
// this is guaranteed to be a SchemaObject
const iterationOutput = openAPI.components!.schemas![
'IterateInvocationOutput'
] as OpenAPIV3.SchemaObject;
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
if (type === 'collect') {
// Special handling for the Collect node
const itemProperty = schema.properties['item'] as InvocationSchemaObject;
inputs.item = {
type: 'item',
name: 'item',
description: itemProperty.description ?? '',
title: 'Collection Item',
inputKind: 'connection',
inputRequirement: 'always',
default: undefined,
};
} else if (type === 'iterate') {
// Special handling for the Iterate node
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
Object.assign(acc, { [type]: invocation });
inputs.collection = {
type: 'array',
name: 'collection',
title: itemProperty.title ?? '',
default: [],
description: itemProperty.description ?? '',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
}
let outputs: Record<string, OutputFieldTemplate>;
if (type === 'iterate') {
// Special handling for the Iterate node output
const iterationOutput =
openAPI.components.schemas['IterateInvocationOutput'];
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
// All other node outputs
outputs = reduce(
schema.output.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (!['type', 'id'].includes(propertyName)) {
const fieldType = getFieldType(property, propertyName, typeHints);
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
}
return outputsAccumulator;
},
{} as Record<string, OutputFieldTemplate>
);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
Object.assign(acc, { [type]: invocation });
return acc;
}, {});

View File

@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
@ -10,27 +11,26 @@ import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[generationSelector, configSelector, uiSelector, hotkeysSelector],
(generation, config, ui, hotkeys) => {
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
config.sd.iterations;
const { iterations } = generation;
const { shouldUseSliders } = ui;
const selector = createSelector([stateSelector], (state) => {
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
state.config.sd.iterations;
const { iterations } = state.generation;
const { shouldUseSliders } = state.ui;
const isDisabled = state.dynamicPrompts.isEnabled;
const step = hotkeys.shift ? fineStep : coarseStep;
const step = state.hotkeys.shift ? fineStep : coarseStep;
return {
iterations,
initial,
min,
sliderMax,
inputMax,
step,
shouldUseSliders,
};
}
);
return {
iterations,
initial,
min,
sliderMax,
inputMax,
step,
shouldUseSliders,
isDisabled,
};
});
const ParamIterations = () => {
const {
@ -41,6 +41,7 @@ const ParamIterations = () => {
inputMax,
step,
shouldUseSliders,
isDisabled,
} = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -58,6 +59,7 @@ const ParamIterations = () => {
return shouldUseSliders ? (
<IAISlider
isDisabled={isDisabled}
label={t('parameters.images')}
step={step}
min={min}
@ -72,6 +74,7 @@ const ParamIterations = () => {
/>
) : (
<IAINumberInput
isDisabled={isDisabled}
label={t('parameters.images')}
step={step}
min={min}

View File

@ -60,6 +60,14 @@ export const initialConfigState: AppConfig = {
fineStep: 0.01,
coarseStep: 0.05,
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
min: 1,
sliderMax: 1000,
inputMax: 10000,
},
},
},
};

View File

@ -4,7 +4,6 @@ import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr';
import {
@ -26,6 +25,7 @@ import {
} from 'services/api/thunks/session';
import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
export type CancelStrategy = 'immediate' | 'scheduled';
@ -382,7 +382,7 @@ export const systemSlice = createSlice({
/**
* OpenAPI schema was parsed
*/
builder.addCase(parsedOpenAPISchema, (state) => {
builder.addCase(nodeTemplatesBuilt, (state) => {
state.wasSchemaParsed = true;
});

View File

@ -8,6 +8,7 @@ import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Sym
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const ImageToImageTabParameters = () => {
return (
@ -16,6 +17,7 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<ImageToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />

View File

@ -9,6 +9,7 @@ import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const TextToImageTabParameters = () => {
return (
@ -17,6 +18,7 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />

View File

@ -8,6 +8,7 @@ import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const UnifiedCanvasParameters = () => {
return (
@ -16,6 +17,7 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<UnifiedCanvasCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamSymmetryCollapse />

View File

@ -15,6 +15,7 @@ export const imagesApi = api.injectEndpoints({
}
return tags;
},
keepUnusedDataFor: 86400, // 24 hours
}),
}),
});

View File

@ -2917,7 +2917,7 @@ export type components = {
/** ModelsList */
ModelsList: {
/** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
};
/**
* MultiplyInvocation
@ -4177,18 +4177,18 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;

View File

@ -1,20 +1,45 @@
import SwaggerParser from '@apidevtools/swagger-parser';
import { createAsyncThunk } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { OpenAPIV3 } from 'openapi-types';
const schemaLog = log.child({ namespace: 'schema' });
function getCircularReplacer() {
const ancestors: Record<string, any>[] = [];
return function (key: string, value: any) {
if (typeof value !== 'object' || value === null) {
return value;
}
// `this` is the object that value is contained in,
// i.e., its direct parent.
// @ts-ignore
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
ancestors.pop();
}
if (ancestors.includes(value)) {
return '[Circular]';
}
ancestors.push(value);
return value;
};
}
export const receivedOpenAPISchema = createAsyncThunk(
'nodes/receivedOpenAPISchema',
async (_, { dispatch }): Promise<OpenAPIV3.Document> => {
const response = await fetch(`openapi.json`);
const openAPISchema = await response.json();
async (_, { dispatch, rejectWithValue }) => {
try {
const dereferencedSchema = (await SwaggerParser.dereference(
'openapi.json'
)) as OpenAPIV3.Document;
schemaLog.info({ openAPISchema }, 'Received OpenAPI schema');
const schemaJSON = JSON.parse(
JSON.stringify(dereferencedSchema, getCircularReplacer())
);
dispatch(parsedOpenAPISchema(openAPISchema as OpenAPIV3.Document));
return openAPISchema;
return schemaJSON;
} catch (error) {
return rejectWithValue({ error });
}
}
);

View File

@ -1,7 +1,14 @@
import { O } from 'ts-toolbelt';
import { components } from './schema';
type schemas = components['schemas'];
/**
* Helper type to extract the invocation type from the schema.
* Also flags the `type` property as required.
*/
type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
/**
* Types from the API, re-exported from the types generated by `openapi-typescript`.
*/
@ -31,42 +38,51 @@ export type Edge = schemas['Edge'];
export type GraphExecutionState = schemas['GraphExecutionState'];
// General nodes
export type CollectInvocation = schemas['CollectInvocation'];
export type IterateInvocation = schemas['IterateInvocation'];
export type RangeInvocation = schemas['RangeInvocation'];
export type RandomRangeInvocation = schemas['RandomRangeInvocation'];
export type RangeOfSizeInvocation = schemas['RangeOfSizeInvocation'];
export type InpaintInvocation = schemas['InpaintInvocation'];
export type ImageResizeInvocation = schemas['ImageResizeInvocation'];
export type RandomIntInvocation = schemas['RandomIntInvocation'];
export type CompelInvocation = schemas['CompelInvocation'];
export type CollectInvocation = Invocation<'CollectInvocation'>;
export type IterateInvocation = Invocation<'IterateInvocation'>;
export type RangeInvocation = Invocation<'RangeInvocation'>;
export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
export type InpaintInvocation = Invocation<'InpaintInvocation'>;
export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
export type CompelInvocation = Invocation<'CompelInvocation'>;
export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>;
export type NoiseInvocation = Invocation<'NoiseInvocation'>;
export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation =
Invocation<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation =
Invocation<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes
export type ControlNetInvocation = schemas['ControlNetInvocation'];
export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
export type CannyImageProcessorInvocation =
schemas['CannyImageProcessorInvocation'];
Invocation<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation =
schemas['ContentShuffleImageProcessorInvocation'];
Invocation<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation =
schemas['HedImageProcessorInvocation'];
Invocation<'HedImageProcessorInvocation'>;
export type LineartAnimeImageProcessorInvocation =
schemas['LineartAnimeImageProcessorInvocation'];
Invocation<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation =
schemas['LineartImageProcessorInvocation'];
Invocation<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation =
schemas['MediapipeFaceProcessorInvocation'];
Invocation<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation =
schemas['MidasDepthImageProcessorInvocation'];
Invocation<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation =
schemas['MlsdImageProcessorInvocation'];
Invocation<'MlsdImageProcessorInvocation'>;
export type NormalbaeImageProcessorInvocation =
schemas['NormalbaeImageProcessorInvocation'];
Invocation<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation =
schemas['OpenposeImageProcessorInvocation'];
Invocation<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation =
schemas['PidiImageProcessorInvocation'];
Invocation<'PidiImageProcessorInvocation'>;
export type ZoeDepthImageProcessorInvocation =
schemas['ZoeDepthImageProcessorInvocation'];
Invocation<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs
export type ImageOutput = schemas['ImageOutput'];

View File

@ -9,6 +9,7 @@
"vite.config.ts",
"./config/vite.app.config.ts",
"./config/vite.package.config.ts",
"./config/vite.common.config.ts"
"./config/vite.common.config.ts",
"./config/common.ts"
]
}

File diff suppressed because it is too large Load Diff

30
tests/conftest.py Normal file
View File

@ -0,0 +1,30 @@
import pytest
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
from invokeai.app.services.processor import DefaultInvocationProcessor
# Ignore these files as they need to be rewritten following the model manager refactor
collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
@pytest.fixture(scope="session", autouse=True)
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
board_images=None, # type: ignore
boards=None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -1,14 +1,18 @@
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import pytest
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.graph import (CollectInvocation, Graph,
GraphExecutionState,
IterateInvocation)
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest
from .test_invoker import create_edge
from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
PromptTestInvocation)
@pytest.fixture
@ -19,30 +23,11 @@ def simple_graph():
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()
if n is None:
return (None, None)
print(f'invoking {n.id}: {type(n)}')
o = n.invoke(InvocationContext(services, "1"))
g.complete(n.id, o)
@ -51,7 +36,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
def test_graph_state_executes_in_order(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()
@ -88,11 +73,11 @@ def test_graph_state_expands_iterator(mock_services):
graph.add_edge(create_edge("0", "collection", "1", "collection"))
graph.add_edge(create_edge("1", "item", "2", "a"))
graph.add_edge(create_edge("2", "a", "3", "a"))
g = GraphExecutionState(graph = graph)
while not g.is_complete():
invoke_next(g, mock_services)
prepared_add_nodes = g.source_prepared_mapping['3']
results = set([g.results[n].a for n in prepared_add_nodes])
expected = set([1, 11, 21])
@ -109,7 +94,7 @@ def test_graph_state_collects(mock_services):
graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt"))
graph.add_edge(create_edge("3", "prompt", "4", "item"))
g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)

View File

@ -1,13 +1,12 @@
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invoker import Invoker
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from .test_nodes import (ErrorInvocation, ImageTestInvocation,
PromptTestInvocation, create_edge, wait_until)
@pytest.fixture
def simple_graph():
@ -17,25 +16,6 @@ def simple_graph():
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(
@ -57,6 +37,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g)
@ -72,6 +53,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all = True)
@ -87,6 +69,7 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete()
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id = "1"))