mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): improved node parsing
- use `swagger-parser` to dereference openapi schema - tidy vite plugins - use mantine select for node add menu
This commit is contained in:
parent
922468b836
commit
862bf7546c
14
invokeai/frontend/web/config/common.ts
Normal file
14
invokeai/frontend/web/config/common.ts
Normal file
@ -0,0 +1,14 @@
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import { PluginOption, UserConfig } from 'vite';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
import { nodePolyfills } from 'vite-plugin-node-polyfills';
|
||||
|
||||
export const commonPlugins: UserConfig['plugins'] = [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
nodePolyfills(),
|
||||
];
|
@ -1,17 +1,9 @@
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import { PluginOption, UserConfig } from 'vite';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
import { UserConfig } from 'vite';
|
||||
import { commonPlugins } from './common';
|
||||
|
||||
export const appConfig: UserConfig = {
|
||||
base: './',
|
||||
plugins: [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
],
|
||||
plugins: [...commonPlugins],
|
||||
build: {
|
||||
chunkSizeWarningLimit: 1500,
|
||||
},
|
||||
|
@ -1,19 +1,13 @@
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import path from 'path';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import { PluginOption, UserConfig } from 'vite';
|
||||
import { UserConfig } from 'vite';
|
||||
import dts from 'vite-plugin-dts';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||
import { commonPlugins } from './common';
|
||||
|
||||
export const packageConfig: UserConfig = {
|
||||
base: './',
|
||||
plugins: [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
...commonPlugins,
|
||||
dts({
|
||||
insertTypesEntry: true,
|
||||
}),
|
||||
|
@ -53,6 +53,7 @@
|
||||
]
|
||||
},
|
||||
"dependencies": {
|
||||
"@apidevtools/swagger-parser": "^10.1.0",
|
||||
"@chakra-ui/anatomy": "^2.1.1",
|
||||
"@chakra-ui/icons": "^2.0.19",
|
||||
"@chakra-ui/react": "^2.7.1",
|
||||
@ -154,6 +155,7 @@
|
||||
"vite-plugin-css-injected-by-js": "^3.1.1",
|
||||
"vite-plugin-dts": "^2.3.0",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-plugin-node-polyfills": "^0.9.0",
|
||||
"vite-tsconfig-paths": "^4.2.0",
|
||||
"yarn": "^1.22.19"
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ import { ImageDTO } from 'services/api/types';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesSelector } from 'features/nodes/store/nodesSlice';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { some } from 'lodash-es';
|
||||
|
||||
@ -30,7 +30,7 @@ export const selectImageUsage = createSelector(
|
||||
[
|
||||
generationSelector,
|
||||
canvasSelector,
|
||||
nodesSelecter,
|
||||
nodesSelector,
|
||||
controlNetSelector,
|
||||
(state: RootState, image_name?: string) => image_name,
|
||||
],
|
||||
|
@ -82,6 +82,7 @@ import {
|
||||
addImageRemovedFromBoardFulfilledListener,
|
||||
addImageRemovedFromBoardRejectedListener,
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -205,3 +206,6 @@ addImageAddedToBoardRejectedListener();
|
||||
addImageRemovedFromBoardFulfilledListener();
|
||||
addImageRemovedFromBoardRejectedListener();
|
||||
addBoardIdSelectedListener();
|
||||
|
||||
// Node schemas
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
@ -0,0 +1,35 @@
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { parseSchema } from 'features/nodes/util/parseSchema';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { size } from 'lodash-es';
|
||||
|
||||
const schemaLog = log.child({ namespace: 'schema' });
|
||||
|
||||
export const addReceivedOpenAPISchemaListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedOpenAPISchema.fulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const schemaJSON = action.payload;
|
||||
|
||||
schemaLog.info({ data: { schemaJSON } }, 'Dereferenced OpenAPI schema');
|
||||
|
||||
const nodeTemplates = parseSchema(schemaJSON);
|
||||
|
||||
schemaLog.info(
|
||||
{ data: { nodeTemplates } },
|
||||
`Built ${size(nodeTemplates)} node templates`
|
||||
);
|
||||
|
||||
dispatch(nodeTemplatesBuilt(nodeTemplates));
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: receivedOpenAPISchema.rejected,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
schemaLog.error('Problem dereferencing OpenAPI Schema');
|
||||
},
|
||||
});
|
||||
};
|
@ -3,7 +3,7 @@ import { startAppListening } from '..';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesSelector } from 'features/nodes/store/nodesSlice';
|
||||
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||
import { forEach, uniqBy } from 'lodash-es';
|
||||
import { imageUrlsReceived } from 'services/api/thunks/image';
|
||||
@ -16,7 +16,7 @@ const selectAllUsedImages = createSelector(
|
||||
[
|
||||
generationSelector,
|
||||
canvasSelector,
|
||||
nodesSelecter,
|
||||
nodesSelector,
|
||||
controlNetSelector,
|
||||
selectImagesEntities,
|
||||
],
|
||||
|
@ -27,7 +27,6 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
borderWidth: '2px',
|
||||
borderColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
padding: 10,
|
||||
paddingRight: 24,
|
||||
fontWeight: 600,
|
||||
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
|
||||
|
@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 24,
|
||||
width: 32,
|
||||
},
|
||||
})}
|
||||
{...rest}
|
||||
|
@ -1,28 +1,41 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
import { memo, useCallback } from 'react';
|
||||
import {
|
||||
Tooltip,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaEllipsisV } from 'react-icons/fa';
|
||||
import { useCallback, forwardRef } from 'react';
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { nodeAdded, nodesSelector } from '../store/nodesSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
|
||||
type NodeTemplate = {
|
||||
label: string;
|
||||
value: string;
|
||||
description: string;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
nodesSelector,
|
||||
(nodes) => {
|
||||
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
|
||||
return {
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
description: template.description,
|
||||
};
|
||||
});
|
||||
|
||||
return { data };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const AddNodeMenu = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const invocationTemplates = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocationTemplates
|
||||
);
|
||||
const { data } = useAppSelector(selector);
|
||||
|
||||
const buildInvocation = useBuildInvocation();
|
||||
|
||||
@ -46,23 +59,52 @@ const AddNodeMenu = () => {
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu isLazy>
|
||||
<MenuButton
|
||||
as={IAIIconButton}
|
||||
aria-label="Add Node"
|
||||
icon={<FaEllipsisV />}
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<IAIMantineMultiSelect
|
||||
selectOnBlur={false}
|
||||
placeholder="Add Node"
|
||||
value={[]}
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No matching nodes"
|
||||
itemComponent={SelectItem}
|
||||
filter={(value, selected, item: NodeTemplate) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.description.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={(v) => {
|
||||
v[0] && addNode(v[0] as AnyInvocationType);
|
||||
}}
|
||||
sx={{
|
||||
width: '18rem',
|
||||
}}
|
||||
/>
|
||||
<MenuList overflowY="scroll" height={400}>
|
||||
{map(invocationTemplates, ({ title, description, type }, key) => {
|
||||
return (
|
||||
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddNodeMenu);
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
||||
|
||||
export default AddNodeMenu;
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import NodeSearch from '../search/NodeSearch';
|
||||
import AddNodeMenu from '../AddNodeMenu';
|
||||
|
||||
const TopLeftPanel = () => (
|
||||
<Panel position="top-left">
|
||||
<NodeSearch />
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
);
|
||||
|
||||
|
@ -14,9 +14,6 @@ import {
|
||||
import { ImageField } from 'services/api/types';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||
import { parseSchema } from '../util/parseSchema';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { size } from 'lodash-es';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
@ -78,25 +75,17 @@ const nodesSlice = createSlice({
|
||||
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowGraphOverlay = action.payload;
|
||||
},
|
||||
parsedOpenAPISchema: (state, action: PayloadAction<OpenAPIV3.Document>) => {
|
||||
try {
|
||||
const parsedSchema = parseSchema(action.payload);
|
||||
|
||||
// TODO: Achtung! Side effect in a reducer!
|
||||
log.info(
|
||||
{ namespace: 'schema', nodes: parsedSchema },
|
||||
`Parsed ${size(parsedSchema)} nodes`
|
||||
);
|
||||
state.invocationTemplates = parsedSchema;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
}
|
||||
nodeTemplatesBuilt: (
|
||||
state,
|
||||
action: PayloadAction<Record<string, InvocationTemplate>>
|
||||
) => {
|
||||
state.invocationTemplates = action.payload;
|
||||
},
|
||||
nodeEditorReset: () => {
|
||||
return { ...initialNodesState };
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||
state.schema = action.payload;
|
||||
});
|
||||
@ -112,10 +101,10 @@ export const {
|
||||
connectionStarted,
|
||||
connectionEnded,
|
||||
shouldShowGraphOverlayChanged,
|
||||
parsedOpenAPISchema,
|
||||
nodeTemplatesBuilt,
|
||||
nodeEditorReset,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
||||
export const nodesSelecter = (state: RootState) => state.nodes;
|
||||
export const nodesSelector = (state: RootState) => state.nodes;
|
||||
|
@ -34,12 +34,10 @@ export type InvocationTemplate = {
|
||||
* Array of invocation inputs
|
||||
*/
|
||||
inputs: Record<string, InputFieldTemplate>;
|
||||
// inputs: InputField[];
|
||||
/**
|
||||
* Array of the invocation outputs
|
||||
*/
|
||||
outputs: Record<string, OutputFieldTemplate>;
|
||||
// outputs: OutputField[];
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
@ -335,7 +333,7 @@ export type TypeHints = {
|
||||
};
|
||||
|
||||
export type InvocationSchemaExtra = {
|
||||
output: OpenAPIV3.ReferenceObject; // the output of the invocation
|
||||
output: OpenAPIV3.SchemaObject; // the output of the invocation
|
||||
ui?: {
|
||||
tags?: string[];
|
||||
type_hints?: TypeHints;
|
||||
|
@ -349,21 +349,11 @@ export const getFieldType = (
|
||||
|
||||
if (typeHints && name in typeHints) {
|
||||
rawFieldType = typeHints[name];
|
||||
} else if (!schemaObject.type) {
|
||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||
if (schemaObject.allOf) {
|
||||
rawFieldType = refObjectToFieldType(
|
||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.anyOf) {
|
||||
rawFieldType = refObjectToFieldType(
|
||||
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.oneOf) {
|
||||
rawFieldType = refObjectToFieldType(
|
||||
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
}
|
||||
} else if (!schemaObject.type && schemaObject.allOf) {
|
||||
// if schemaObject has no type, then it should have one of allOf
|
||||
rawFieldType =
|
||||
(schemaObject.allOf[0] as OpenAPIV3.SchemaObject).title ??
|
||||
'Missing Field Type';
|
||||
} else if (schemaObject.enum) {
|
||||
rawFieldType = 'enum';
|
||||
} else if (schemaObject.type) {
|
||||
|
@ -5,127 +5,154 @@ import {
|
||||
InputFieldTemplate,
|
||||
InvocationSchemaObject,
|
||||
InvocationTemplate,
|
||||
isInvocationSchemaObject,
|
||||
OutputFieldTemplate,
|
||||
} from '../types/types';
|
||||
import {
|
||||
buildInputFieldTemplate,
|
||||
buildOutputFieldTemplates,
|
||||
} from './fieldTemplateBuilders';
|
||||
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
// recursively exclude all properties of type U from T
|
||||
type DeepExclude<T, U> = T extends U
|
||||
? never
|
||||
: T extends object
|
||||
? {
|
||||
[K in keyof T]: DeepExclude<T[K], U>;
|
||||
}
|
||||
: T;
|
||||
|
||||
// The schema from swagger-parser is dereferenced, and we know `components` and `components.schemas` exist
|
||||
type DereferencedOpenAPIDocument = DeepExclude<
|
||||
O.Required<OpenAPIV3.Document, 'schemas' | 'components', 'deep'>,
|
||||
OpenAPIV3.ReferenceObject
|
||||
>;
|
||||
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
|
||||
|
||||
const invocationDenylist = ['Graph', 'InvocationMeta'];
|
||||
|
||||
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
const nodeFilter = (
|
||||
schema: DereferencedOpenAPIDocument['components']['schemas'][string],
|
||||
key: string
|
||||
) =>
|
||||
key.includes('Invocation') &&
|
||||
!key.includes('InvocationOutput') &&
|
||||
!invocationDenylist.some((denylistItem) => key.includes(denylistItem));
|
||||
|
||||
export const parseSchema = (openAPI: DereferencedOpenAPIDocument) => {
|
||||
// filter out non-invocation schemas, plus some tricky invocations for now
|
||||
const filteredSchemas = filter(
|
||||
openAPI.components!.schemas,
|
||||
(schema, key) =>
|
||||
key.includes('Invocation') &&
|
||||
!key.includes('InvocationOutput') &&
|
||||
!invocationDenylist.some((denylistItem) => key.includes(denylistItem))
|
||||
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
|
||||
const filteredSchemas = filter(openAPI.components.schemas, nodeFilter);
|
||||
|
||||
const invocations = filteredSchemas.reduce<
|
||||
Record<string, InvocationTemplate>
|
||||
>((acc, schema) => {
|
||||
// only want SchemaObjects
|
||||
if (isInvocationSchemaObject(schema)) {
|
||||
const type = schema.properties.type.default;
|
||||
>((acc, s) => {
|
||||
// cast to InvocationSchemaObject, we know the shape
|
||||
const schema = s as InvocationSchemaObject;
|
||||
|
||||
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
|
||||
const type = schema.properties.type.default;
|
||||
|
||||
const typeHints = schema.ui?.type_hints;
|
||||
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
|
||||
|
||||
const inputs: Record<string, InputFieldTemplate> = {};
|
||||
const typeHints = schema.ui?.type_hints;
|
||||
|
||||
if (type === 'collect') {
|
||||
const itemProperty = schema.properties[
|
||||
'item'
|
||||
] as InvocationSchemaObject;
|
||||
// Handle the special Collect node
|
||||
inputs.item = {
|
||||
type: 'item',
|
||||
name: 'item',
|
||||
description: itemProperty.description ?? '',
|
||||
title: 'Collection Item',
|
||||
inputKind: 'connection',
|
||||
inputRequirement: 'always',
|
||||
default: undefined,
|
||||
};
|
||||
} else if (type === 'iterate') {
|
||||
const itemProperty = schema.properties[
|
||||
'collection'
|
||||
] as InvocationSchemaObject;
|
||||
const inputs: Record<string, InputFieldTemplate> = {};
|
||||
|
||||
inputs.collection = {
|
||||
type: 'array',
|
||||
name: 'collection',
|
||||
title: itemProperty.title ?? '',
|
||||
default: [],
|
||||
description: itemProperty.description ?? '',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'connection',
|
||||
};
|
||||
} else {
|
||||
// All other nodes
|
||||
reduce(
|
||||
schema.properties,
|
||||
(inputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
// `type` and `id` are not valid inputs/outputs
|
||||
!RESERVED_FIELD_NAMES.includes(propertyName) &&
|
||||
isSchemaObject(property)
|
||||
) {
|
||||
const field: InputFieldTemplate | undefined =
|
||||
buildInputFieldTemplate(property, propertyName, typeHints);
|
||||
|
||||
if (field) {
|
||||
inputsAccumulator[propertyName] = field;
|
||||
}
|
||||
}
|
||||
return inputsAccumulator;
|
||||
},
|
||||
inputs
|
||||
);
|
||||
}
|
||||
|
||||
const rawOutput = (schema as InvocationSchemaObject).output;
|
||||
|
||||
let outputs: Record<string, OutputFieldTemplate>;
|
||||
|
||||
// some special handling is needed for collect, iterate and range nodes
|
||||
if (type === 'iterate') {
|
||||
// this is guaranteed to be a SchemaObject
|
||||
const iterationOutput = openAPI.components!.schemas![
|
||||
'IterateInvocationOutput'
|
||||
] as OpenAPIV3.SchemaObject;
|
||||
|
||||
outputs = {
|
||||
item: {
|
||||
name: 'item',
|
||||
title: iterationOutput.title ?? '',
|
||||
description: iterationOutput.description ?? '',
|
||||
type: 'array',
|
||||
},
|
||||
};
|
||||
} else {
|
||||
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
|
||||
}
|
||||
|
||||
const invocation: InvocationTemplate = {
|
||||
title,
|
||||
type,
|
||||
tags: schema.ui?.tags ?? [],
|
||||
description: schema.description ?? '',
|
||||
inputs,
|
||||
outputs,
|
||||
if (type === 'collect') {
|
||||
// Special handling for the Collect node
|
||||
const itemProperty = schema.properties['item'] as InvocationSchemaObject;
|
||||
inputs.item = {
|
||||
type: 'item',
|
||||
name: 'item',
|
||||
description: itemProperty.description ?? '',
|
||||
title: 'Collection Item',
|
||||
inputKind: 'connection',
|
||||
inputRequirement: 'always',
|
||||
default: undefined,
|
||||
};
|
||||
} else if (type === 'iterate') {
|
||||
// Special handling for the Iterate node
|
||||
const itemProperty = schema.properties[
|
||||
'collection'
|
||||
] as InvocationSchemaObject;
|
||||
|
||||
Object.assign(acc, { [type]: invocation });
|
||||
inputs.collection = {
|
||||
type: 'array',
|
||||
name: 'collection',
|
||||
title: itemProperty.title ?? '',
|
||||
default: [],
|
||||
description: itemProperty.description ?? '',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'connection',
|
||||
};
|
||||
} else {
|
||||
// All other nodes
|
||||
reduce(
|
||||
schema.properties,
|
||||
(inputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
// `type` and `id` are not valid inputs/outputs
|
||||
!RESERVED_FIELD_NAMES.includes(propertyName) &&
|
||||
isSchemaObject(property)
|
||||
) {
|
||||
const field: InputFieldTemplate | undefined =
|
||||
buildInputFieldTemplate(property, propertyName, typeHints);
|
||||
|
||||
if (field) {
|
||||
inputsAccumulator[propertyName] = field;
|
||||
}
|
||||
}
|
||||
return inputsAccumulator;
|
||||
},
|
||||
inputs
|
||||
);
|
||||
}
|
||||
|
||||
let outputs: Record<string, OutputFieldTemplate>;
|
||||
|
||||
if (type === 'iterate') {
|
||||
// Special handling for the Iterate node output
|
||||
const iterationOutput =
|
||||
openAPI.components.schemas['IterateInvocationOutput'];
|
||||
|
||||
outputs = {
|
||||
item: {
|
||||
name: 'item',
|
||||
title: iterationOutput.title ?? '',
|
||||
description: iterationOutput.description ?? '',
|
||||
type: 'array',
|
||||
},
|
||||
};
|
||||
} else {
|
||||
// All other node outputs
|
||||
outputs = reduce(
|
||||
schema.output.properties as OpenAPIV3.SchemaObject,
|
||||
(outputsAccumulator, property, propertyName) => {
|
||||
if (!['type', 'id'].includes(propertyName)) {
|
||||
const fieldType = getFieldType(property, propertyName, typeHints);
|
||||
|
||||
outputsAccumulator[propertyName] = {
|
||||
name: propertyName,
|
||||
title: property.title ?? '',
|
||||
description: property.description ?? '',
|
||||
type: fieldType,
|
||||
};
|
||||
}
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldTemplate>
|
||||
);
|
||||
}
|
||||
|
||||
const invocation: InvocationTemplate = {
|
||||
title,
|
||||
type,
|
||||
tags: schema.ui?.tags ?? [],
|
||||
description: schema.description ?? '',
|
||||
inputs,
|
||||
outputs,
|
||||
};
|
||||
|
||||
Object.assign(acc, { [type]: invocation });
|
||||
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
|
@ -4,7 +4,6 @@ import * as InvokeAI from 'app/types/invokeai';
|
||||
|
||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||
import { TFuncKey, t } from 'i18next';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import {
|
||||
@ -26,6 +25,7 @@ import {
|
||||
} from 'services/api/thunks/session';
|
||||
import { makeToast } from '../../../app/components/Toaster';
|
||||
import { LANGUAGES } from '../components/LanguagePicker';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
|
||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||
|
||||
@ -382,7 +382,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* OpenAPI schema was parsed
|
||||
*/
|
||||
builder.addCase(parsedOpenAPISchema, (state) => {
|
||||
builder.addCase(nodeTemplatesBuilt, (state) => {
|
||||
state.wasSchemaParsed = true;
|
||||
});
|
||||
|
||||
|
@ -2917,7 +2917,7 @@ export type components = {
|
||||
/** ModelsList */
|
||||
ModelsList: {
|
||||
/** Models */
|
||||
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
|
||||
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
|
||||
};
|
||||
/**
|
||||
* MultiplyInvocation
|
||||
@ -4177,18 +4177,18 @@ export type components = {
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
@ -1,20 +1,45 @@
|
||||
import SwaggerParser from '@apidevtools/swagger-parser';
|
||||
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
|
||||
const schemaLog = log.child({ namespace: 'schema' });
|
||||
|
||||
function getCircularReplacer() {
|
||||
const ancestors: Record<string, any>[] = [];
|
||||
return function (key: string, value: any) {
|
||||
if (typeof value !== 'object' || value === null) {
|
||||
return value;
|
||||
}
|
||||
// `this` is the object that value is contained in,
|
||||
// i.e., its direct parent.
|
||||
// @ts-ignore
|
||||
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
|
||||
ancestors.pop();
|
||||
}
|
||||
if (ancestors.includes(value)) {
|
||||
return '[Circular]';
|
||||
}
|
||||
ancestors.push(value);
|
||||
return value;
|
||||
};
|
||||
}
|
||||
|
||||
export const receivedOpenAPISchema = createAsyncThunk(
|
||||
'nodes/receivedOpenAPISchema',
|
||||
async (_, { dispatch }): Promise<OpenAPIV3.Document> => {
|
||||
const response = await fetch(`openapi.json`);
|
||||
const openAPISchema = await response.json();
|
||||
async (_, { dispatch, rejectWithValue }) => {
|
||||
try {
|
||||
const dereferencedSchema = (await SwaggerParser.dereference(
|
||||
'openapi.json'
|
||||
)) as OpenAPIV3.Document;
|
||||
|
||||
schemaLog.info({ openAPISchema }, 'Received OpenAPI schema');
|
||||
const schemaJSON = JSON.parse(
|
||||
JSON.stringify(dereferencedSchema, getCircularReplacer())
|
||||
);
|
||||
|
||||
dispatch(parsedOpenAPISchema(openAPISchema as OpenAPIV3.Document));
|
||||
|
||||
return openAPISchema;
|
||||
return schemaJSON;
|
||||
} catch (error) {
|
||||
return rejectWithValue({ error });
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -1,7 +1,14 @@
|
||||
import { O } from 'ts-toolbelt';
|
||||
import { components } from './schema';
|
||||
|
||||
type schemas = components['schemas'];
|
||||
|
||||
/**
|
||||
* Helper type to extract the invocation type from the schema.
|
||||
* Also flags the `type` property as required.
|
||||
*/
|
||||
type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
|
||||
|
||||
/**
|
||||
* Types from the API, re-exported from the types generated by `openapi-typescript`.
|
||||
*/
|
||||
@ -31,42 +38,42 @@ export type Edge = schemas['Edge'];
|
||||
export type GraphExecutionState = schemas['GraphExecutionState'];
|
||||
|
||||
// General nodes
|
||||
export type CollectInvocation = schemas['CollectInvocation'];
|
||||
export type IterateInvocation = schemas['IterateInvocation'];
|
||||
export type RangeInvocation = schemas['RangeInvocation'];
|
||||
export type RandomRangeInvocation = schemas['RandomRangeInvocation'];
|
||||
export type RangeOfSizeInvocation = schemas['RangeOfSizeInvocation'];
|
||||
export type InpaintInvocation = schemas['InpaintInvocation'];
|
||||
export type ImageResizeInvocation = schemas['ImageResizeInvocation'];
|
||||
export type RandomIntInvocation = schemas['RandomIntInvocation'];
|
||||
export type CompelInvocation = schemas['CompelInvocation'];
|
||||
export type CollectInvocation = Invocation<'CollectInvocation'>;
|
||||
export type IterateInvocation = Invocation<'IterateInvocation'>;
|
||||
export type RangeInvocation = Invocation<'RangeInvocation'>;
|
||||
export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
|
||||
export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
|
||||
export type InpaintInvocation = Invocation<'InpaintInvocation'>;
|
||||
export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
|
||||
export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
|
||||
export type CompelInvocation = Invocation<'CompelInvocation'>;
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = schemas['ControlNetInvocation'];
|
||||
export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
|
||||
export type CannyImageProcessorInvocation =
|
||||
schemas['CannyImageProcessorInvocation'];
|
||||
Invocation<'CannyImageProcessorInvocation'>;
|
||||
export type ContentShuffleImageProcessorInvocation =
|
||||
schemas['ContentShuffleImageProcessorInvocation'];
|
||||
Invocation<'ContentShuffleImageProcessorInvocation'>;
|
||||
export type HedImageProcessorInvocation =
|
||||
schemas['HedImageProcessorInvocation'];
|
||||
Invocation<'HedImageProcessorInvocation'>;
|
||||
export type LineartAnimeImageProcessorInvocation =
|
||||
schemas['LineartAnimeImageProcessorInvocation'];
|
||||
Invocation<'LineartAnimeImageProcessorInvocation'>;
|
||||
export type LineartImageProcessorInvocation =
|
||||
schemas['LineartImageProcessorInvocation'];
|
||||
Invocation<'LineartImageProcessorInvocation'>;
|
||||
export type MediapipeFaceProcessorInvocation =
|
||||
schemas['MediapipeFaceProcessorInvocation'];
|
||||
Invocation<'MediapipeFaceProcessorInvocation'>;
|
||||
export type MidasDepthImageProcessorInvocation =
|
||||
schemas['MidasDepthImageProcessorInvocation'];
|
||||
Invocation<'MidasDepthImageProcessorInvocation'>;
|
||||
export type MlsdImageProcessorInvocation =
|
||||
schemas['MlsdImageProcessorInvocation'];
|
||||
Invocation<'MlsdImageProcessorInvocation'>;
|
||||
export type NormalbaeImageProcessorInvocation =
|
||||
schemas['NormalbaeImageProcessorInvocation'];
|
||||
Invocation<'NormalbaeImageProcessorInvocation'>;
|
||||
export type OpenposeImageProcessorInvocation =
|
||||
schemas['OpenposeImageProcessorInvocation'];
|
||||
Invocation<'OpenposeImageProcessorInvocation'>;
|
||||
export type PidiImageProcessorInvocation =
|
||||
schemas['PidiImageProcessorInvocation'];
|
||||
Invocation<'PidiImageProcessorInvocation'>;
|
||||
export type ZoeDepthImageProcessorInvocation =
|
||||
schemas['ZoeDepthImageProcessorInvocation'];
|
||||
Invocation<'ZoeDepthImageProcessorInvocation'>;
|
||||
|
||||
// Node Outputs
|
||||
export type ImageOutput = schemas['ImageOutput'];
|
||||
|
@ -9,6 +9,7 @@
|
||||
"vite.config.ts",
|
||||
"./config/vite.app.config.ts",
|
||||
"./config/vite.package.config.ts",
|
||||
"./config/vite.common.config.ts"
|
||||
"./config/vite.common.config.ts",
|
||||
"./config/common.ts"
|
||||
]
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user