Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/controlnet_extras

commit 862bfa2c36
user1 committed 2023-06-26 16:39:31 -07:00

64 changed files with 1628 additions and 983 deletions


@ -1,10 +1,16 @@
name: Test invoke.py pip name: Test invoke.py pip
# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable
on: on:
pull_request: pull_request:
paths: paths:
- '**' - '**'
- '!pyproject.toml' - '!pyproject.toml'
- '!invokeai/**' - '!invokeai/**'
- '!tests/**'
- 'invokeai/frontend/web/**' - 'invokeai/frontend/web/**'
merge_group: merge_group:
workflow_dispatch: workflow_dispatch:
@ -19,48 +25,26 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
# - '3.9'
- '3.10' - '3.10'
pytorch: pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7 - linux-cuda-11_7
- linux-rocm-5_2 - linux-rocm-5_2
- linux-cpu - linux-cpu
- macos-default - macos-default
- windows-cpu - windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include: include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7 - pytorch: linux-cuda-11_7
os: ubuntu-22.04 os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2 - pytorch: linux-rocm-5_2
os: ubuntu-22.04 os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu - pytorch: linux-cpu
os: ubuntu-22.04 os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default - pytorch: macos-default
os: macOS-12 os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu - pytorch: windows-cpu
os: windows-2022 os: windows-2022
github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }} name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- run: 'echo "No build required"' - name: skip
run: echo "no build required"


@ -11,6 +11,7 @@ on:
paths: paths:
- 'pyproject.toml' - 'pyproject.toml'
- 'invokeai/**' - 'invokeai/**'
- 'tests/**'
- '!invokeai/frontend/web/**' - '!invokeai/frontend/web/**'
types: types:
- 'ready_for_review' - 'ready_for_review'
@ -32,19 +33,12 @@ jobs:
# - '3.9' # - '3.9'
- '3.10' - '3.10'
pytorch: pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7 - linux-cuda-11_7
- linux-rocm-5_2 - linux-rocm-5_2
- linux-cpu - linux-cpu
- macos-default - macos-default
- windows-cpu - windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include: include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7 - pytorch: linux-cuda-11_7
os: ubuntu-22.04 os: ubuntu-22.04
github-env: $GITHUB_ENV github-env: $GITHUB_ENV
@ -62,14 +56,6 @@ jobs:
- pytorch: windows-cpu - pytorch: windows-cpu
os: windows-2022 os: windows-2022
github-env: $env:GITHUB_ENV github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }} name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
env: env:
@ -100,40 +86,38 @@ jobs:
id: run-pytest id: run-pytest
run: pytest run: pytest
- name: run invokeai-configure # - name: run invokeai-configure
id: run-preload-models # env:
env: # HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }} # run: >
run: > # invokeai-configure
invokeai-configure # --yes
--yes # --default_only
--default_only # --full-precision
--full-precision # # can't use fp16 weights without a GPU
# can't use fp16 weights without a GPU
- name: run invokeai # - name: run invokeai
id: run-invokeai # id: run-invokeai
env: # env:
# Set offline mode to make sure configure preloaded successfully. # # Set offline mode to make sure configure preloaded successfully.
HF_HUB_OFFLINE: 1 # HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1 # HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1 # TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results # INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: > # run: >
invokeai # invokeai
--no-patchmatch # --no-patchmatch
--no-nsfw_checker # --no-nsfw_checker
--precision=float32 # --precision=float32
--always_use_cpu # --always_use_cpu
--use_memory_db # --use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }} # --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }} # --from_file ${{ env.TEST_PROMPTS }}
- name: Archive results # - name: Archive results
id: archive-results # env:
env: # INVOKEAI_OUTDIR: ${{ github.workspace }}/results
INVOKEAI_OUTDIR: ${{ github.workspace }}/results # uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v3 # with:
with: # name: results
name: results # path: ${{ env.INVOKEAI_OUTDIR }}
path: ${{ env.INVOKEAI_OUTDIR }}


@ -38,6 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
echo. echo.
echo See %INSTRUCTIONS% for more details. echo See %INSTRUCTIONS% for more details.
echo. echo.
echo "For the best user experience we suggest enlarging or maximizing this window now."
pause pause
@rem ---------------------------- check Python version --------------- @rem ---------------------------- check Python version ---------------


@ -26,6 +26,7 @@ done
if [ -z "$PYTHON" ]; then if [ -z "$PYTHON" ]; then
echo "A suitable Python interpreter could not be found" echo "A suitable Python interpreter could not be found"
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help." echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
echo "For the best user experience we suggest enlarging or maximizing this window now."
read -p "Press any key to exit" read -p "Press any key to exit"
exit -1 exit -1
fi fi


@ -293,6 +293,8 @@ def introduction() -> None:
"3. Create initial configuration files.", "3. Create initial configuration files.",
"", "",
"[i]At any point you may interrupt this program and resume later.", "[i]At any point you may interrupt this program and resume later.",
"",
"[b]For the best user experience, please enlarge or maximize this window",
), ),
) )
) )


@ -279,8 +279,8 @@ def _convert_ckpt_and_cache(
raise Exception(f"Model variant {model_config.variant} not supported for {version}") raise Exception(f"Model variant {model_config.variant} not supported for {version}")
weights = app_config.root_dir / model_config.path weights = app_config.root_path / model_config.path
config_file = app_config.root_dir / model_config.config config_file = app_config.root_path / model_config.config
output_path = Path(output_path) output_path = Path(output_path)
if version == BaseModelType.StableDiffusion1: if version == BaseModelType.StableDiffusion1:


@ -965,13 +965,15 @@ def main():
logger.error( logger.error(
"Insufficient vertical space for the interface. Please make your window taller and try again" "Insufficient vertical space for the interface. Please make your window taller and try again"
) )
elif str(e).startswith("addwstr"): input('Press any key to continue...')
except Exception as e:
if str(e).startswith("addwstr"):
logger.error( logger.error(
"Insufficient horizontal space for the interface. Please make your window wider and try again." "Insufficient horizontal space for the interface. Please make your window wider and try again."
) )
except Exception as e: else:
print(f'An exception has occurred: {str(e)} Details:') print(f'An exception has occurred: {str(e)} Details:')
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
input('Press any key to continue...') input('Press any key to continue...')


@ -42,6 +42,18 @@ def set_terminal_size(columns: int, lines: int, launch_command: str=None):
elif OS in ["Darwin", "Linux"]: elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height) _set_terminal_size_unix(width,height)
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print('\033[1mThis window is too narrow for the user interface. Please make it wider.\033[0m')
pause = True
if ts.lines < lines:
print('\033[1mThis window is too short for the user interface. Please make it taller.\033[0m')
pause = True
if pause:
input('Press any key to continue..')
def _set_terminal_size_powershell(width: int, height: int): def _set_terminal_size_powershell(width: int, height: int):
script=f''' script=f'''
$pshost = get-host $pshost = get-host


@ -0,0 +1,14 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { nodePolyfills } from 'vite-plugin-node-polyfills';
export const commonPlugins: UserConfig['plugins'] = [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
nodePolyfills(),
];


@ -1,17 +1,9 @@
import react from '@vitejs/plugin-react-swc'; import { UserConfig } from 'vite';
import { visualizer } from 'rollup-plugin-visualizer'; import { commonPlugins } from './common';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
export const appConfig: UserConfig = { export const appConfig: UserConfig = {
base: './', base: './',
plugins: [ plugins: [...commonPlugins],
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
],
build: { build: {
chunkSizeWarningLimit: 1500, chunkSizeWarningLimit: 1500,
}, },


@ -1,19 +1,13 @@
import react from '@vitejs/plugin-react-swc';
import path from 'path'; import path from 'path';
import { visualizer } from 'rollup-plugin-visualizer'; import { UserConfig } from 'vite';
import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts'; import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
import { commonPlugins } from './common';
export const packageConfig: UserConfig = { export const packageConfig: UserConfig = {
base: './', base: './',
plugins: [ plugins: [
react(), ...commonPlugins,
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
dts({ dts({
insertTypesEntry: true, insertTypesEntry: true,
}), }),


@ -53,6 +53,7 @@
] ]
}, },
"dependencies": { "dependencies": {
"@apidevtools/swagger-parser": "^10.1.0",
"@chakra-ui/anatomy": "^2.1.1", "@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/icons": "^2.0.19", "@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.7.1", "@chakra-ui/react": "^2.7.1",
@ -154,6 +155,7 @@
"vite-plugin-css-injected-by-js": "^3.1.1", "vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0", "vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1", "vite-plugin-eslint": "^1.8.1",
"vite-plugin-node-polyfills": "^0.9.0",
"vite-tsconfig-paths": "^4.2.0", "vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19" "yarn": "^1.22.19"
} }


@ -15,7 +15,7 @@ import { ImageDTO } from 'services/api/types';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelecter } from 'features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
@ -30,7 +30,7 @@ export const selectImageUsage = createSelector(
[ [
generationSelector, generationSelector,
canvasSelector, canvasSelector,
nodesSelecter, nodesSelector,
controlNetSelector, controlNetSelector,
(state: RootState, image_name?: string) => image_name, (state: RootState, image_name?: string) => image_name,
], ],


@ -1,6 +1,7 @@
import { AnyAction } from '@reduxjs/toolkit'; import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { Graph } from 'services/api/types'; import { Graph } from 'services/api/types';
export const actionSanitizer = <A extends AnyAction>(action: A): A => { export const actionSanitizer = <A extends AnyAction>(action: A): A => {
@ -8,17 +9,6 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (action.payload.nodes) { if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {}; const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return { return {
...action, ...action,
payload: { ...action.payload, nodes: sanitizedNodes }, payload: { ...action.payload, nodes: sanitizedNodes },
@ -26,5 +16,19 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
} }
} }
if (receivedOpenAPISchema.fulfilled.match(action)) {
return {
...action,
payload: '<OpenAPI schema omitted>',
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
return action; return action;
}; };


@ -82,6 +82,7 @@ import {
addImageRemovedFromBoardFulfilledListener, addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener, addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard'; } from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -205,3 +206,6 @@ addImageAddedToBoardRejectedListener();
addImageRemovedFromBoardFulfilledListener(); addImageRemovedFromBoardFulfilledListener();
addImageRemovedFromBoardRejectedListener(); addImageRemovedFromBoardRejectedListener();
addBoardIdSelectedListener(); addBoardIdSelectedListener();
// Node schemas
addReceivedOpenAPISchemaListener();


@ -0,0 +1,35 @@
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { parseSchema } from 'features/nodes/util/parseSchema';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { size } from 'lodash-es';
const schemaLog = log.child({ namespace: 'schema' });
export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch, getState }) => {
const schemaJSON = action.payload;
schemaLog.info({ data: { schemaJSON } }, 'Dereferenced OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
schemaLog.info(
{ data: { nodeTemplates } },
`Built ${size(nodeTemplates)} node templates`
);
dispatch(nodeTemplatesBuilt(nodeTemplates));
},
});
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: (action, { dispatch, getState }) => {
schemaLog.error('Problem dereferencing OpenAPI Schema');
},
});
};


@ -3,7 +3,7 @@ import { startAppListening } from '..';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { nodesSelecter } from 'features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { forEach, uniqBy } from 'lodash-es'; import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
@ -16,7 +16,7 @@ const selectAllUsedImages = createSelector(
[ [
generationSelector, generationSelector,
canvasSelector, canvasSelector,
nodesSelecter, nodesSelector,
controlNetSelector, controlNetSelector,
selectImagesEntities, selectImagesEntities,
], ],


@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
@ -48,6 +49,7 @@ const allReducers = {
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
// session: sessionReducer, // session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system', 'system',
'ui', 'ui',
'controlNet', 'controlNet',
'dynamicPrompts',
// 'boards', // 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',
@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>; export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>; export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch; export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;


@ -171,6 +171,14 @@ export type AppConfig = {
fineStep: number; fineStep: number;
coarseStep: number; coarseStep: number;
}; };
dynamicPrompts: {
maxPrompts: {
initial: number;
min: number;
sliderMax: number;
inputMax: number;
};
};
}; };
}; };


@ -27,7 +27,6 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
borderWidth: '2px', borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)', borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)', color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24, paddingRight: 24,
fontWeight: 600, fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' }, '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },


@ -34,6 +34,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus': { '&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)', borderColor: 'var(--invokeai-colors-accent-600)',
}, },
'&:disabled': {
backgroundColor: 'var(--invokeai-colors-base-700)',
color: 'var(--invokeai-colors-base-400)',
},
}, },
dropdown: { dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)', backgroundColor: 'var(--invokeai-colors-base-800)',
@ -64,7 +68,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
}, },
}, },
rightSection: { rightSection: {
width: 24, width: 32,
}, },
})} })}
{...rest} {...rest}


@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
{...formControlProps} {...formControlProps}
> >
{label && ( {label && (
<FormLabel my={1} flexGrow={1} {...formLabelProps}> <FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
}}
{...formLabelProps}
>
{label} {label}
</FormLabel> </FormLabel>
)} )}


@ -9,10 +9,12 @@ type IAICanvasImageProps = {
}; };
const IAICanvasImage = (props: IAICanvasImageProps) => { const IAICanvasImage = (props: IAICanvasImageProps) => {
const { width, height, x, y, imageName } = props.canvasImage; const { width, height, x, y, imageName } = props.canvasImage;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); const { currentData: imageDTO, isError } = useGetImageDTOQuery(
imageName ?? skipToken
);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous'); const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) { if (isError) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />; return <Rect x={x} y={y} width={width} height={height} fill="red" />;
} }


@ -174,7 +174,10 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1', aspectRatio: '1/1',
}} }}
> >
<ControlNetImagePreview controlNet={props.controlNet} /> <ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex> </Flex>
)} )}
</Flex> </Flex>


@ -0,0 +1,45 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial />
</Flex>
</IAICollapse>
);
};
export default ParamDynamicPromptsCollapse;


@ -0,0 +1,36 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import IAISwitch from 'common/components/IAISwitch';
const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;
return { combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(combinatorialToggled());
}, [dispatch]);
return (
<IAISwitch
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}
/>
);
};
export default ParamDynamicPromptsCombinatorial;


@ -0,0 +1,53 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax };
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => {
dispatch(maxPromptsChanged(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(maxPromptsReset());
}, [dispatch]);
return (
<IAISlider
label="Max Prompts"
min={min}
max={sliderMax}
value={maxPrompts}
onChange={handleChange}
sliderNumberInputProps={{ max: inputMax }}
withSliderMarks
withInput
inputReadOnly
withReset
handleReset={handleReset}
/>
);
};
export default ParamDynamicPromptsMaxPrompts;


@ -0,0 +1,50 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
export interface DynamicPromptsState {
isEnabled: boolean;
maxPrompts: number;
combinatorial: boolean;
}
export const initialDynamicPromptsState: DynamicPromptsState = {
isEnabled: false,
maxPrompts: 100,
combinatorial: true,
};
const initialState: DynamicPromptsState = initialDynamicPromptsState;
export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
},
maxPromptsReset: (state) => {
state.maxPrompts = initialDynamicPromptsState.maxPrompts;
},
combinatorialToggled: (state) => {
state.combinatorial = !state.combinatorial;
},
isEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
},
extraReducers: (builder) => {
//
},
});
export const {
isEnabledToggled,
maxPromptsChanged,
maxPromptsReset,
combinatorialToggled,
} = dynamicPromptsSlice.actions;
export default dynamicPromptsSlice.reducer;
export const dynamicPromptsSelector = (state: RootState) =>
state.dynamicPrompts;


@ -1,28 +1,41 @@
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
import { memo, useCallback } from 'react'; import { useCallback, forwardRef } from 'react';
import { import { Flex, Text } from '@chakra-ui/react';
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
} from '@chakra-ui/react';
import { FaEllipsisV } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeAdded } from '../store/nodesSlice'; import { nodeAdded, nodesSelector } from '../store/nodesSlice';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { RootState } from 'app/store/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation'; import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { AnyInvocationType } from 'services/events/types'; import { AnyInvocationType } from 'services/events/types';
import IAIIconButton from 'common/components/IAIIconButton';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
type NodeTemplate = {
label: string;
value: string;
description: string;
};
const selector = createSelector(
nodesSelector,
(nodes) => {
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
};
});
return { data };
},
defaultSelectorOptions
);
const AddNodeMenu = () => { const AddNodeMenu = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
const buildInvocation = useBuildInvocation(); const buildInvocation = useBuildInvocation();
@ -46,23 +59,52 @@ const AddNodeMenu = () => {
); );
return ( return (
<Menu isLazy> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<MenuButton <IAIMantineMultiSelect
as={IAIIconButton} selectOnBlur={false}
aria-label="Add Node" placeholder="Add Node"
icon={<FaEllipsisV />} value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching nodes"
itemComponent={SelectItem}
filter={(value, selected, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.description.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={(v) => {
v[0] && addNode(v[0] as AnyInvocationType);
}}
sx={{
width: '18rem',
}}
/> />
<MenuList overflowY="scroll" height={400}> </Flex>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
); );
}; };
export default memo(AddNodeMenu); interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
<Text size="xs" color="base.600">
{description}
</Text>
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default AddNodeMenu;


@ -1,10 +1,10 @@
import { memo } from 'react'; import { memo } from 'react';
import { Panel } from 'reactflow'; import { Panel } from 'reactflow';
import NodeSearch from '../search/NodeSearch'; import AddNodeMenu from '../AddNodeMenu';
const TopLeftPanel = () => ( const TopLeftPanel = () => (
<Panel position="top-left"> <Panel position="top-left">
<NodeSearch /> <AddNodeMenu />
</Panel> </Panel>
); );


@ -14,9 +14,6 @@ import {
import { ImageField } from 'services/api/types'; import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { InvocationTemplate, InvocationValue } from '../types/types'; import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
@ -78,25 +75,17 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => { shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload; state.shouldShowGraphOverlay = action.payload;
}, },
parsedOpenAPISchema: (state, action: PayloadAction<OpenAPIV3.Document>) => { nodeTemplatesBuilt: (
try { state,
const parsedSchema = parseSchema(action.payload); action: PayloadAction<Record<string, InvocationTemplate>>
) => {
// TODO: Achtung! Side effect in a reducer! state.invocationTemplates = action.payload;
log.info(
{ namespace: 'schema', nodes: parsedSchema },
`Parsed ${size(parsedSchema)} nodes`
);
state.invocationTemplates = parsedSchema;
} catch (err) {
console.error(err);
}
}, },
nodeEditorReset: () => { nodeEditorReset: () => {
return { ...initialNodesState }; return { ...initialNodesState };
}, },
}, },
extraReducers(builder) { extraReducers: (builder) => {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload; state.schema = action.payload;
}); });
@ -112,10 +101,10 @@ export const {
connectionStarted, connectionStarted,
connectionEnded, connectionEnded,
shouldShowGraphOverlayChanged, shouldShowGraphOverlayChanged,
parsedOpenAPISchema, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
} = nodesSlice.actions; } = nodesSlice.actions;
export default nodesSlice.reducer; export default nodesSlice.reducer;
export const nodesSelecter = (state: RootState) => state.nodes; export const nodesSelector = (state: RootState) => state.nodes;


@ -34,12 +34,10 @@ export type InvocationTemplate = {
* Array of invocation inputs * Array of invocation inputs
*/ */
inputs: Record<string, InputFieldTemplate>; inputs: Record<string, InputFieldTemplate>;
// inputs: InputField[];
/** /**
* Array of the invocation outputs * Array of the invocation outputs
*/ */
outputs: Record<string, OutputFieldTemplate>; outputs: Record<string, OutputFieldTemplate>;
// outputs: OutputField[];
}; };
export type FieldUIConfig = { export type FieldUIConfig = {
@ -335,7 +333,7 @@ export type TypeHints = {
}; };
export type InvocationSchemaExtra = { export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation output: OpenAPIV3.SchemaObject; // the output of the invocation
ui?: { ui?: {
tags?: string[]; tags?: string[];
type_hints?: TypeHints; type_hints?: TypeHints;


@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { filter, forEach, size } from 'lodash-es'; import { filter } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types'; import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
(c.processorType === 'none' && Boolean(c.controlImage))) (c.processorType === 'none' && Boolean(c.controlImage)))
); );
// Add ControlNet if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (isControlNetEnabled && validControlNets.length > 0) { if (validControlNets.length > 1) {
if (size(controlNets) > 1) { // We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,
type: 'collect', type: 'collect',
@ -36,10 +36,9 @@ export const addControlNetToLinearGraph = (
}); });
} }
forEach(controlNets, (controlNet) => { validControlNets.forEach((controlNet) => {
const { const {
controlNetId, controlNetId,
isEnabled,
controlImage, controlImage,
processedControlImage, processedControlImage,
beginStepPct, beginStepPct,
@ -50,11 +49,6 @@ export const addControlNetToLinearGraph = (
weight, weight,
} = controlNet; } = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = { const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`, id: `control_net_${controlNetId}`,
type: 'controlnet', type: 'controlnet',
@ -82,7 +76,8 @@ export const addControlNetToLinearGraph = (
graph.nodes[controlNetNode.id] = controlNetNode; graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) { if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {
@ -91,6 +86,7 @@ export const addControlNetToLinearGraph = (
}, },
}); });
} else { } else {
// otherwise, link directly to the base node
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {


@ -349,21 +349,11 @@ export const getFieldType = (
if (typeHints && name in typeHints) { if (typeHints && name in typeHints) {
rawFieldType = typeHints[name]; rawFieldType = typeHints[name];
} else if (!schemaObject.type) { } else if (!schemaObject.type && schemaObject.allOf) {
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf // if schemaObject has no type, then it should have one of allOf
if (schemaObject.allOf) { rawFieldType =
rawFieldType = refObjectToFieldType( (schemaObject.allOf[0] as OpenAPIV3.SchemaObject).title ??
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject 'Missing Field Type';
);
} else if (schemaObject.anyOf) {
rawFieldType = refObjectToFieldType(
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
rawFieldType = refObjectToFieldType(
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
}
} else if (schemaObject.enum) { } else if (schemaObject.enum) {
rawFieldType = 'enum'; rawFieldType = 'enum';
} else if (schemaObject.type) { } else if (schemaObject.type) {


@ -0,0 +1,153 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DynamicPromptInvocation,
IterateInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
DYNAMIC_PROMPT,
ITERATE,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
import { unset } from 'lodash-es';
export const addDynamicPromptsToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
state.generation;
const {
combinatorial,
isEnabled: isDynamicPromptsEnabled,
maxPrompts,
} = state.dynamicPrompts;
if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
const dynamicPromptNode: DynamicPromptInvocation = {
id: DYNAMIC_PROMPT,
type: 'dynamic_prompt',
max_prompts: maxPrompts,
combinatorial,
prompt: positivePrompt,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[DYNAMIC_PROMPT] = dynamicPromptNode;
graph.nodes[ITERATE] = iterateNode;
// connect dynamic prompts to compel nodes
graph.edges.push(
{
source: {
node_id: DYNAMIC_PROMPT,
field: 'prompt_collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'prompt',
},
}
);
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect the random int node to the noise node's seed so generation starts from a random seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: NOISE, field: 'seed' },
});
} else {
// User specified seed, so set the noise node's seed directly
(graph.nodes[NOISE] as NoiseInvocation).seed = seed;
}
} else {
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
size: iterations,
step: 1,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[ITERATE] = iterateNode;
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graph.edges.push({
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
});
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
rangeOfSizeNode.start = seed;
}
}
};


@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
import { import {
ImageDTO, ImageDTO,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
@ -10,7 +11,7 @@ import { log } from 'app/logging/useLogger';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -24,6 +25,7 @@ import {
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -75,31 +77,19 @@ export const buildCanvasImageToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
id: LATENTS_TO_LATENTS, id: LATENTS_TO_LATENTS,
@ -120,7 +110,7 @@ export const buildCanvasImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -130,7 +120,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -140,7 +130,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -148,26 +138,6 @@ export const buildCanvasImageToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -200,7 +170,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -210,7 +180,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -241,26 +211,6 @@ export const buildCanvasImageToImageGraph = (
], ],
}; };
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit` // handle `fit`
if (initialImage.width !== width || initialImage.height !== height) { if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
@ -306,9 +256,9 @@ export const buildCanvasImageToImageGraph = (
}); });
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.image_name, image_name: initialImage.image_name,
}); };
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node
graph.edges.push({ graph.edges.push({
@ -327,7 +277,10 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
// add controlnet // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;


@ -9,7 +9,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { import {
ITERATE, ITERATE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {


@ -4,7 +4,7 @@ import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -15,6 +15,7 @@ import {
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -62,13 +63,6 @@ export const buildCanvasTextToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
@ -82,19 +76,15 @@ export const buildCanvasTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
}, },
edges: [ edges: [
{ {
@ -119,7 +109,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -129,7 +119,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -139,7 +129,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -159,7 +149,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -167,26 +157,6 @@ export const buildCanvasTextToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -200,27 +170,10 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
// handle seed // add dynamic prompts, mutating `graph`
if (shouldRandomizeSeed) { addDynamicPromptsToGraph(graph, state);
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode; // add controlnet, mutating `graph`
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph; return graph;


@ -1,28 +1,24 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ImageResizeInvocation, ImageResizeInvocation,
RandomIntInvocation, ImageToLatentsInvocation,
RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS, IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
RESIZE, RESIZE,
} from './constants'; } from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -44,9 +40,6 @@ export const buildLinearImageToImageGraph = (
shouldFitToWidthHeight, shouldFitToWidthHeight,
width, width,
height, height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation; } = state.generation;
/** /**
@ -79,31 +72,19 @@ export const buildLinearImageToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
id: LATENTS_TO_LATENTS, id: LATENTS_TO_LATENTS,
@ -124,7 +105,7 @@ export const buildLinearImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -134,7 +115,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -144,7 +125,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -152,26 +133,6 @@ export const buildLinearImageToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -204,7 +165,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -214,7 +175,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -245,26 +206,6 @@ export const buildLinearImageToImageGraph = (
], ],
}; };
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit` // handle `fit`
if ( if (
shouldFitToWidthHeight && shouldFitToWidthHeight &&
@ -313,9 +254,9 @@ export const buildLinearImageToImageGraph = (
}); });
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName, image_name: initialImage.imageName,
}); };
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node
graph.edges.push({ graph.edges.push({
@ -334,7 +275,10 @@ export const buildLinearImageToImageGraph = (
}); });
} }
// add controlnet // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;
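// For context on the removal above: the deleted nodes and edges hand-wired the
// batch/seed plumbing (range_of_size -> iterate -> noise.seed, optionally seeded
// by rand_int). Below is a minimal, self-contained sketch of that wiring,
// reconstructed from the deleted code; this responsibility now belongs to
// addDynamicPromptsToGraph, whose internals are not shown in this diff.
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import { ITERATE, NOISE, RANDOM_INT, RANGE_OF_SIZE } from './constants';

const addSeedWiringSketch = (graph: NonNullableGraph, state: RootState): void => {
  const { iterations, seed, shouldRandomizeSeed } = state.generation;

  // one seed per iteration: a range of size `iterations`, iterated into the noise node
  graph.nodes[RANGE_OF_SIZE] = {
    type: 'range_of_size',
    id: RANGE_OF_SIZE,
    size: iterations,
    step: 1,
  };
  graph.nodes[ITERATE] = { type: 'iterate', id: ITERATE };

  graph.edges.push(
    {
      source: { node_id: RANGE_OF_SIZE, field: 'collection' },
      destination: { node_id: ITERATE, field: 'collection' },
    },
    {
      source: { node_id: ITERATE, field: 'item' },
      destination: { node_id: NOISE, field: 'seed' },
    }
  );

  if (shouldRandomizeSeed) {
    // a random int supplies the first seed of the range
    const randomIntNode: RandomIntInvocation = { id: RANDOM_INT, type: 'rand_int' };
    graph.nodes[RANDOM_INT] = randomIntNode;
    graph.edges.push({
      source: { node_id: RANDOM_INT, field: 'a' },
      destination: { node_id: RANGE_OF_SIZE, field: 'start' },
    });
  } else {
    // user-specified seed: start the range at the seed
    (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
  }
};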


@ -1,33 +1,20 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
type TextToImageGraphOverrides = {
width: number;
height: number;
};
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState, state: RootState
overrides?: TextToImageGraphOverrides
): NonNullableGraph => { ): NonNullableGraph => {
const { const {
positivePrompt, positivePrompt,
@ -38,9 +25,6 @@ export const buildLinearTextToImageGraph = (
steps, steps,
width, width,
height, height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToPipelineModelField(modelId);
@ -68,18 +52,11 @@ export const buildLinearTextToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
width: overrides?.width || width, width,
height: overrides?.height || height, height,
}, },
[TEXT_TO_LATENTS]: { [TEXT_TO_LATENTS]: {
type: 't2l', type: 't2l',
@ -88,19 +65,15 @@ export const buildLinearTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
}, },
edges: [ edges: [
{ {
@ -125,7 +98,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -135,7 +108,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -145,7 +118,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -165,7 +138,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -173,26 +146,6 @@ export const buildLinearTextToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -206,27 +159,10 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
// handle seed // add dynamic prompts, mutating `graph`
if (shouldRandomizeSeed) { addDynamicPromptsToGraph(graph, state);
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode; // add controlnet, mutating `graph`
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph; return graph;
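// Both linear builders now end with a pair of graph-mutating passes
// (addDynamicPromptsToGraph, addControlNetToLinearGraph). A small illustrative
// sketch of that pattern follows; `GraphPass` and `applyPasses` are illustrative
// names, not part of this codebase.
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';

type GraphPass = (graph: NonNullableGraph, state: RootState) => void;

const applyPasses = (
  graph: NonNullableGraph,
  state: RootState,
  passes: GraphPass[]
): NonNullableGraph => {
  // each pass mutates `graph` in place, in order
  passes.forEach((pass) => pass(graph, state));
  return graph;
};

// e.g.:
// applyPasses(graph, state, [
//   addDynamicPromptsToGraph,
//   (g, s) => addControlNetToLinearGraph(g, TEXT_TO_LATENTS, s),
// ]);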


@ -7,12 +7,13 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MODEL_LOADER = 'pipeline_model_loader'; export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';
export const INPAINT = 'inpaint'; export const INPAINT = 'inpaint';
export const CONTROL_NET_COLLECT = 'control_net_collect'; export const CONTROL_NET_COLLECT = 'control_net_collect';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';


@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { CompelInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildCompelNode = (
prompt: string,
state: RootState,
overrides: O.Partial<CompelInvocation, 'deep'> = {}
): CompelInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const { model } = generation;
const compelNode: CompelInvocation = {
id: nodeId,
type: 'compel',
prompt,
model,
};
Object.assign(compelNode, overrides);
return compelNode;
};


@ -1,107 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api/types';
import { O } from 'ts-toolbelt';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
export const buildImg2ImgNode = (
state: RootState,
overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
): ImageToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const activeTabName = activeTabNameSelector(state);
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
initialImage,
} = generation;
// const initialImage = initialImageSelector(state);
const imageToImageNode: ImageToImageInvocation = {
id: nodeId,
type: 'img2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
// on the Canvas tab, we do not manually specify the init image
if (activeTabName !== 'unifiedCanvas') {
if (!initialImage) {
// TODO: handle this better
throw 'no initial image';
}
imageToImageNode.image = {
image_name: initialImage.imageName,
};
}
if (!shouldRandomizeSeed) {
imageToImageNode.seed = seed;
}
Object.assign(imageToImageNode, overrides);
return imageToImageNode;
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
fit: true,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};


@ -1,48 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { InpaintInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildInpaintNode = (
state: RootState,
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
): InpaintInvocation => {
const nodeId = uuidv4();
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = state.generation;
const inpaintNode: InpaintInvocation = {
id: nodeId,
type: 'inpaint',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
if (!shouldRandomizeSeed) {
inpaintNode.seed = seed;
}
Object.assign(inpaintNode, overrides);
return inpaintNode;
};


@ -1,13 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { IterateInvocation } from 'services/api/types';
export const buildIterateNode = (): IterateInvocation => {
const nodeId = uuidv4();
return {
id: nodeId,
type: 'iterate',
// collection: [],
// index: 0,
};
};


@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { RandomRangeInvocation, RangeInvocation } from 'services/api/types';
export const buildRangeNode = (
state: RootState
): RangeInvocation | RandomRangeInvocation => {
const nodeId = uuidv4();
const { shouldRandomizeSeed, iterations, seed } = state.generation;
if (shouldRandomizeSeed) {
return {
id: nodeId,
type: 'random_range',
size: iterations,
};
}
return {
id: nodeId,
type: 'range',
start: seed,
stop: seed + iterations,
};
};


@ -1,45 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { TextToImageInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildTxt2ImgNode = (
state: RootState,
overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
): TextToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
scheduler,
shouldRandomizeSeed,
model,
} = generation;
const textToImageNode: NonNullable<TextToImageInvocation> = {
id: nodeId,
type: 'txt2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale,
scheduler,
model,
};
if (!shouldRandomizeSeed) {
textToImageNode.seed = seed;
}
Object.assign(textToImageNode, overrides);
return textToImageNode;
};


@ -5,127 +5,154 @@ import {
InputFieldTemplate, InputFieldTemplate,
InvocationSchemaObject, InvocationSchemaObject,
InvocationTemplate, InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate, OutputFieldTemplate,
} from '../types/types'; } from '../types/types';
import { import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
buildInputFieldTemplate, import { O } from 'ts-toolbelt';
buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; // recursively exclude all properties of type U from T
type DeepExclude<T, U> = T extends U
? never
: T extends object
? {
[K in keyof T]: DeepExclude<T[K], U>;
}
: T;
// The schema from swagger-parser is dereferenced, and we know `components` and `components.schemas` exist
type DereferencedOpenAPIDocument = DeepExclude<
O.Required<OpenAPIV3.Document, 'schemas' | 'components', 'deep'>,
OpenAPIV3.ReferenceObject
>;
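// A toy illustration of DeepExclude (the type names below are illustrative only):
// every member assignable to U collapses to `never`, so union members and nested
// properties that could still be $ref objects disappear from the resulting type.
type Ref = { $ref: string };
type MaybeDereferenced = {
  schema: Ref | { title: string };
  nested: { child: Ref | number };
};

// Resolved['schema'] is { title: string }; Resolved['nested']['child'] is number.
type Resolved = DeepExclude<MaybeDereferenced, Ref>;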
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
const invocationDenylist = ['Graph', 'InvocationMeta']; const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { const nodeFilter = (
schema: DereferencedOpenAPIDocument['components']['schemas'][string],
key: string
) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem));
export const parseSchema = (openAPI: DereferencedOpenAPIDocument) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now
const filteredSchemas = filter( const filteredSchemas = filter(openAPI.components.schemas, nodeFilter);
openAPI.components!.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem))
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const invocations = filteredSchemas.reduce< const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate> Record<string, InvocationTemplate>
>((acc, schema) => { >((acc, s) => {
// only want SchemaObjects // cast to InvocationSchemaObject, we know the shape
if (isInvocationSchemaObject(schema)) { const schema = s as InvocationSchemaObject;
const type = schema.properties.type.default;
const title = schema.ui?.title ?? schema.title.replace('Invocation', ''); const type = schema.properties.type.default;
const typeHints = schema.ui?.type_hints; const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
const inputs: Record<string, InputFieldTemplate> = {}; const typeHints = schema.ui?.type_hints;
if (type === 'collect') { const inputs: Record<string, InputFieldTemplate> = {};
const itemProperty = schema.properties[
'item'
] as InvocationSchemaObject;
// Handle the special Collect node
inputs.item = {
type: 'item',
name: 'item',
description: itemProperty.description ?? '',
title: 'Collection Item',
inputKind: 'connection',
inputRequirement: 'always',
default: undefined,
};
} else if (type === 'iterate') {
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
inputs.collection = { if (type === 'collect') {
type: 'array', // Special handling for the Collect node
name: 'collection', const itemProperty = schema.properties['item'] as InvocationSchemaObject;
title: itemProperty.title ?? '', inputs.item = {
default: [], type: 'item',
description: itemProperty.description ?? '', name: 'item',
inputRequirement: 'always', description: itemProperty.description ?? '',
inputKind: 'connection', title: 'Collection Item',
}; inputKind: 'connection',
} else { inputRequirement: 'always',
// All other nodes default: undefined,
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
}
const rawOutput = (schema as InvocationSchemaObject).output;
let outputs: Record<string, OutputFieldTemplate>;
// some special handling is needed for collect, iterate and range nodes
if (type === 'iterate') {
// this is guaranteed to be a SchemaObject
const iterationOutput = openAPI.components!.schemas![
'IterateInvocationOutput'
] as OpenAPIV3.SchemaObject;
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
}; };
} else if (type === 'iterate') {
// Special handling for the Iterate node
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
Object.assign(acc, { [type]: invocation }); inputs.collection = {
type: 'array',
name: 'collection',
title: itemProperty.title ?? '',
default: [],
description: itemProperty.description ?? '',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
} }
let outputs: Record<string, OutputFieldTemplate>;
if (type === 'iterate') {
// Special handling for the Iterate node output
const iterationOutput =
openAPI.components.schemas['IterateInvocationOutput'];
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
// All other node outputs
outputs = reduce(
schema.output.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (!['type', 'id'].includes(propertyName)) {
const fieldType = getFieldType(property, propertyName, typeHints);
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
}
return outputsAccumulator;
},
{} as Record<string, OutputFieldTemplate>
);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
Object.assign(acc, { [type]: invocation });
return acc; return acc;
}, {}); }, {});
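// The reduce above produces a Record keyed by node type. A sketch of one entry's
// shape (field values here are illustrative, not taken from the real schema):
const exampleTemplates: Record<string, InvocationTemplate> = {
  noise: {
    title: 'Noise',
    type: 'noise',
    tags: [],
    description: 'Generates latent noise.',
    inputs: {}, // one InputFieldTemplate per schema property
    outputs: {}, // one OutputFieldTemplate per output property
  },
};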


@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
@ -10,27 +11,26 @@ import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector([stateSelector], (state) => {
[generationSelector, configSelector, uiSelector, hotkeysSelector], const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
(generation, config, ui, hotkeys) => { state.config.sd.iterations;
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = const { iterations } = state.generation;
config.sd.iterations; const { shouldUseSliders } = state.ui;
const { iterations } = generation; const isDisabled = state.dynamicPrompts.isEnabled;
const { shouldUseSliders } = ui;
const step = hotkeys.shift ? fineStep : coarseStep; const step = state.hotkeys.shift ? fineStep : coarseStep;
return { return {
iterations, iterations,
initial, initial,
min, min,
sliderMax, sliderMax,
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
}; isDisabled,
} };
); });
const ParamIterations = () => { const ParamIterations = () => {
const { const {
@ -41,6 +41,7 @@ const ParamIterations = () => {
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
isDisabled,
} = useAppSelector(selector); } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -58,6 +59,7 @@ const ParamIterations = () => {
return shouldUseSliders ? ( return shouldUseSliders ? (
<IAISlider <IAISlider
isDisabled={isDisabled}
label={t('parameters.images')} label={t('parameters.images')}
step={step} step={step}
min={min} min={min}
@ -72,6 +74,7 @@ const ParamIterations = () => {
/> />
) : ( ) : (
<IAINumberInput <IAINumberInput
isDisabled={isDisabled}
label={t('parameters.images')} label={t('parameters.images')}
step={step} step={step}
min={min} min={min}


@ -60,6 +60,14 @@ export const initialConfigState: AppConfig = {
fineStep: 0.01, fineStep: 0.01,
coarseStep: 0.05, coarseStep: 0.05,
}, },
dynamicPrompts: {
maxPrompts: {
initial: 100,
min: 1,
sliderMax: 1000,
inputMax: 10000,
},
},
}, },
}; };


@ -4,7 +4,6 @@ import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { import {
@ -26,6 +25,7 @@ import {
} from 'services/api/thunks/session'; } from 'services/api/thunks/session';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker'; import { LANGUAGES } from '../components/LanguagePicker';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -382,7 +382,7 @@ export const systemSlice = createSlice({
/** /**
* OpenAPI schema was parsed * OpenAPI schema was parsed
*/ */
builder.addCase(parsedOpenAPISchema, (state) => { builder.addCase(nodeTemplatesBuilt, (state) => {
state.wasSchemaParsed = true; state.wasSchemaParsed = true;
}); });


@ -8,6 +8,7 @@ import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Sym
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const ImageToImageTabParameters = () => { const ImageToImageTabParameters = () => {
return ( return (
@ -16,6 +17,7 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />


@ -9,6 +9,7 @@ import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const TextToImageTabParameters = () => { const TextToImageTabParameters = () => {
return ( return (
@ -17,6 +18,7 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />


@ -8,6 +8,7 @@ import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const UnifiedCanvasParameters = () => { const UnifiedCanvasParameters = () => {
return ( return (
@ -16,6 +17,7 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />


@ -15,6 +15,7 @@ export const imagesApi = api.injectEndpoints({
} }
return tags; return tags;
}, },
keepUnusedDataFor: 86400, // 24 hours
}), }),
}), }),
}); });


@ -2917,7 +2917,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -4177,18 +4177,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;


@ -1,20 +1,45 @@
import SwaggerParser from '@apidevtools/swagger-parser';
import { createAsyncThunk } from '@reduxjs/toolkit'; import { createAsyncThunk } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
const schemaLog = log.child({ namespace: 'schema' }); const schemaLog = log.child({ namespace: 'schema' });
function getCircularReplacer() {
const ancestors: Record<string, any>[] = [];
return function (key: string, value: any) {
if (typeof value !== 'object' || value === null) {
return value;
}
// `this` is the object that value is contained in,
// i.e., its direct parent.
// @ts-ignore
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
ancestors.pop();
}
if (ancestors.includes(value)) {
return '[Circular]';
}
ancestors.push(value);
return value;
};
}
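// How the replacer behaves with JSON.stringify: the closure tracks the chain of
// ancestors currently being serialized, so any back-reference is emitted as the
// string '[Circular]' instead of throwing a TypeError.
const circularExample: Record<string, unknown> = { name: 'root' };
circularExample.self = circularExample; // introduce a cycle

// Without the replacer, JSON.stringify(circularExample) would throw
// "Converting circular structure to JSON".
const serialized = JSON.stringify(circularExample, getCircularReplacer());
// serialized === '{"name":"root","self":"[Circular]"}'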
export const receivedOpenAPISchema = createAsyncThunk( export const receivedOpenAPISchema = createAsyncThunk(
'nodes/receivedOpenAPISchema', 'nodes/receivedOpenAPISchema',
async (_, { dispatch }): Promise<OpenAPIV3.Document> => { async (_, { dispatch, rejectWithValue }) => {
const response = await fetch(`openapi.json`); try {
const openAPISchema = await response.json(); const dereferencedSchema = (await SwaggerParser.dereference(
'openapi.json'
)) as OpenAPIV3.Document;
schemaLog.info({ openAPISchema }, 'Received OpenAPI schema'); const schemaJSON = JSON.parse(
JSON.stringify(dereferencedSchema, getCircularReplacer())
);
dispatch(parsedOpenAPISchema(openAPISchema as OpenAPIV3.Document)); return schemaJSON;
} catch (error) {
return openAPISchema; return rejectWithValue({ error });
}
} }
); );


@ -1,7 +1,14 @@
import { O } from 'ts-toolbelt';
import { components } from './schema'; import { components } from './schema';
type schemas = components['schemas']; type schemas = components['schemas'];
/**
* Helper type to extract the invocation type from the schema.
* Also flags the `type` property as required.
*/
type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
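// A small sketch of what the helper buys: `type` becomes mandatory on the
// otherwise-generated shape, so node literals must spell out their node type.
const exampleNoiseNode: Invocation<'NoiseInvocation'> = {
  type: 'noise', // required by the O.Required<..., 'type'> wrapper
  id: 'example_noise',
  width: 512,
  height: 512,
};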
/** /**
* Types from the API, re-exported from the types generated by `openapi-typescript`. * Types from the API, re-exported from the types generated by `openapi-typescript`.
*/ */
@ -31,42 +38,51 @@ export type Edge = schemas['Edge'];
export type GraphExecutionState = schemas['GraphExecutionState']; export type GraphExecutionState = schemas['GraphExecutionState'];
// General nodes // General nodes
export type CollectInvocation = schemas['CollectInvocation']; export type CollectInvocation = Invocation<'CollectInvocation'>;
export type IterateInvocation = schemas['IterateInvocation']; export type IterateInvocation = Invocation<'IterateInvocation'>;
export type RangeInvocation = schemas['RangeInvocation']; export type RangeInvocation = Invocation<'RangeInvocation'>;
export type RandomRangeInvocation = schemas['RandomRangeInvocation']; export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
export type RangeOfSizeInvocation = schemas['RangeOfSizeInvocation']; export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
export type InpaintInvocation = schemas['InpaintInvocation']; export type InpaintInvocation = Invocation<'InpaintInvocation'>;
export type ImageResizeInvocation = schemas['ImageResizeInvocation']; export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
export type RandomIntInvocation = schemas['RandomIntInvocation']; export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
export type CompelInvocation = schemas['CompelInvocation']; export type CompelInvocation = Invocation<'CompelInvocation'>;
export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>;
export type NoiseInvocation = Invocation<'NoiseInvocation'>;
export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation =
Invocation<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation =
Invocation<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = schemas['ControlNetInvocation']; export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
export type CannyImageProcessorInvocation = export type CannyImageProcessorInvocation =
schemas['CannyImageProcessorInvocation']; Invocation<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation = export type ContentShuffleImageProcessorInvocation =
schemas['ContentShuffleImageProcessorInvocation']; Invocation<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation = export type HedImageProcessorInvocation =
schemas['HedImageProcessorInvocation']; Invocation<'HedImageProcessorInvocation'>;
export type LineartAnimeImageProcessorInvocation = export type LineartAnimeImageProcessorInvocation =
schemas['LineartAnimeImageProcessorInvocation']; Invocation<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation = export type LineartImageProcessorInvocation =
schemas['LineartImageProcessorInvocation']; Invocation<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation = export type MediapipeFaceProcessorInvocation =
schemas['MediapipeFaceProcessorInvocation']; Invocation<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation = export type MidasDepthImageProcessorInvocation =
schemas['MidasDepthImageProcessorInvocation']; Invocation<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation = export type MlsdImageProcessorInvocation =
schemas['MlsdImageProcessorInvocation']; Invocation<'MlsdImageProcessorInvocation'>;
export type NormalbaeImageProcessorInvocation = export type NormalbaeImageProcessorInvocation =
schemas['NormalbaeImageProcessorInvocation']; Invocation<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation = export type OpenposeImageProcessorInvocation =
schemas['OpenposeImageProcessorInvocation']; Invocation<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation = export type PidiImageProcessorInvocation =
schemas['PidiImageProcessorInvocation']; Invocation<'PidiImageProcessorInvocation'>;
export type ZoeDepthImageProcessorInvocation = export type ZoeDepthImageProcessorInvocation =
schemas['ZoeDepthImageProcessorInvocation']; Invocation<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs // Node Outputs
export type ImageOutput = schemas['ImageOutput']; export type ImageOutput = schemas['ImageOutput'];


@ -9,6 +9,7 @@
"vite.config.ts", "vite.config.ts",
"./config/vite.app.config.ts", "./config/vite.app.config.ts",
"./config/vite.package.config.ts", "./config/vite.package.config.ts",
"./config/vite.common.config.ts" "./config/vite.common.config.ts",
"./config/common.ts"
] ]
} }

File diff suppressed because it is too large

tests/conftest.py (new file, 30 lines)

@ -0,0 +1,30 @@
import pytest
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
from invokeai.app.services.processor import DefaultInvocationProcessor
# Ignore these files as they need to be rewritten following the model manager refactor
collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
@pytest.fixture(scope="session", autouse=True)
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
board_images=None, # type: ignore
boards=None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)


@ -1,14 +1,18 @@
from .test_invoker import create_edge import pytest
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.graph import (CollectInvocation, Graph,
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory GraphExecutionState,
from invokeai.app.services.invocation_queue import MemoryInvocationQueue IterateInvocation)
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest from .test_invoker import create_edge
from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
PromptTestInvocation)
@pytest.fixture @pytest.fixture
@ -19,30 +23,11 @@ def simple_graph():
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next() n = g.next()
if n is None: if n is None:
return (None, None) return (None, None)
print(f'invoking {n.id}: {type(n)}') print(f'invoking {n.id}: {type(n)}')
o = n.invoke(InvocationContext(services, "1")) o = n.invoke(InvocationContext(services, "1"))
g.complete(n.id, o) g.complete(n.id, o)
@ -51,7 +36,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
def test_graph_state_executes_in_order(simple_graph, mock_services): def test_graph_state_executes_in_order(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph) g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services) n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services)
n3 = g.next() n3 = g.next()
@ -88,11 +73,11 @@ def test_graph_state_expands_iterator(mock_services):
graph.add_edge(create_edge("0", "collection", "1", "collection")) graph.add_edge(create_edge("0", "collection", "1", "collection"))
graph.add_edge(create_edge("1", "item", "2", "a")) graph.add_edge(create_edge("1", "item", "2", "a"))
graph.add_edge(create_edge("2", "a", "3", "a")) graph.add_edge(create_edge("2", "a", "3", "a"))
g = GraphExecutionState(graph = graph) g = GraphExecutionState(graph = graph)
while not g.is_complete(): while not g.is_complete():
invoke_next(g, mock_services) invoke_next(g, mock_services)
prepared_add_nodes = g.source_prepared_mapping['3'] prepared_add_nodes = g.source_prepared_mapping['3']
results = set([g.results[n].a for n in prepared_add_nodes]) results = set([g.results[n].a for n in prepared_add_nodes])
expected = set([1, 11, 21]) expected = set([1, 11, 21])
@ -109,7 +94,7 @@ def test_graph_state_collects(mock_services):
graph.add_edge(create_edge("1", "collection", "2", "collection")) graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt")) graph.add_edge(create_edge("2", "item", "3", "prompt"))
graph.add_edge(create_edge("3", "prompt", "4", "item")) graph.add_edge(create_edge("3", "prompt", "4", "item"))
g = GraphExecutionState(graph = graph) g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services) n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services)


@ -1,13 +1,12 @@
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invoker import Invoker
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest import pytest
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from .test_nodes import (ErrorInvocation, ImageTestInvocation,
PromptTestInvocation, create_edge, wait_until)
@pytest.fixture @pytest.fixture
def simple_graph(): def simple_graph():
@ -17,25 +16,6 @@ def simple_graph():
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
@pytest.fixture() @pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker: def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker( return Invoker(
@ -57,6 +37,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
assert isinstance(g, GraphExecutionState) assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph assert g.graph == simple_graph
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph) g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g) invocation_id = mock_invoker.invoke(g)
@ -72,6 +53,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0 assert len(g.executed) > 0
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph) g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all = True) invocation_id = mock_invoker.invoke(g, invoke_all = True)
@ -87,6 +69,7 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete() assert g.is_complete()
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker): def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state() g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id = "1")) g.graph.add_node(ErrorInvocation(id = "1"))