From 2497aa5cd801180ddd375fa4eef24a6a7698bfd8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 22 Aug 2023 15:31:18 +1000 Subject: [PATCH] feat(ui): improve node schema parsing and add `outputType` to templates --- .../web/src/features/nodes/types/types.ts | 11 ++- .../src/features/nodes/util/parseSchema.ts | 97 +++++++++++++------ 2 files changed, 78 insertions(+), 30 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 2cc14bc618..af6b80c5d5 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -53,6 +53,10 @@ export type InvocationTemplate = { * Array of the invocation outputs */ outputs: Record; + /** + * The type of this node's output + */ + outputType: string; // TODO: generate a union of output types }; export type FieldUIConfig = { @@ -521,14 +525,15 @@ export type InvocationBaseSchemaObject = Omit< export type InvocationOutputSchemaObject = Omit< OpenAPIV3.SchemaObject, 'properties' -> & - OpenAPIV3.SchemaObject['properties'] & { +> & { + properties: OpenAPIV3.SchemaObject['properties'] & { type: Omit & { - default: AnyInvocationType; + default: string; }; } & { class: 'output'; }; +}; export type InvocationFieldSchema = OpenAPIV3.SchemaObject & _InputField; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index e237ecbfe4..c2f49d205d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -14,13 +14,34 @@ import { } from '../types/types'; import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders'; -const RESERVED_FIELD_NAMES = ['id', 'type', 'metadata']; +const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata']; +const RESERVED_OUTPUT_FIELD_NAMES = ['type']; const invocationDenylist: AnyInvocationType[] = [ 'graph', 'metadata_accumulator', ]; +const isAllowedInputField = (nodeType: string, fieldName: string) => { + if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) { + return false; + } + if (nodeType === 'collect' && fieldName === 'collection') { + return false; + } + if (nodeType === 'iterate' && fieldName === 'index') { + return false; + } + return true; +}; + +const isAllowedOutputField = (nodeType: string, fieldName: string) => { + if (RESERVED_OUTPUT_FIELD_NAMES.includes(fieldName)) { + return false; + } + return true; +}; + const isNotInDenylist = (schema: InvocationSchemaObject) => !invocationDenylist.includes(schema.properties.type.default); @@ -42,17 +63,28 @@ export const parseSchema = ( const inputs = reduce( schema.properties, (inputsAccumulator, property, propertyName) => { - if ( - !RESERVED_FIELD_NAMES.includes(propertyName) && - isInvocationFieldSchema(property) && - !property.ui_hidden - ) { - const field = buildInputFieldTemplate(schema, property, propertyName); - - if (field) { - inputsAccumulator[propertyName] = field; - } + if (!isAllowedInputField(type, propertyName)) { + logger('nodes').trace( + { type, propertyName, property: parseify(property) }, + 'Skipped reserved input field' + ); + return inputsAccumulator; } + + if (!isInvocationFieldSchema(property)) { + logger('nodes').warn( + { type, propertyName, property: parseify(property) }, + 'Unhandled input property' + ); + return inputsAccumulator; + } + + const field = buildInputFieldTemplate(schema, property, propertyName); + + if (field) { + inputsAccumulator[propertyName] = field; + } + return inputsAccumulator; }, {} as Record @@ -82,26 +114,36 @@ export const parseSchema = ( throw 'Invalid output schema'; } + const outputType = outputSchema.properties.type.default; + const outputs = reduce( - outputSchema.properties as OpenAPIV3.SchemaObject, + outputSchema.properties, (outputsAccumulator, property, propertyName) => { - if ( - !['type', 'id'].includes(propertyName) && - !['object'].includes(property.type) && // TODO: handle objects? - isInvocationFieldSchema(property) - ) { - const fieldType = getFieldType(property); - outputsAccumulator[propertyName] = { - fieldKind: 'output', - name: propertyName, - title: property.title ?? '', - description: property.description ?? '', - type: fieldType, - }; - } else { - logger('nodes').warn({ property }, 'Unhandled output property'); + if (!isAllowedOutputField(type, propertyName)) { + logger('nodes').trace( + { type, propertyName, property: parseify(property) }, + 'Skipped reserved output field' + ); + return outputsAccumulator; } + if (!isInvocationFieldSchema(property)) { + logger('nodes').warn( + { type, propertyName, property: parseify(property) }, + 'Unhandled output property' + ); + return outputsAccumulator; + } + + const fieldType = getFieldType(property); + outputsAccumulator[propertyName] = { + fieldKind: 'output', + name: propertyName, + title: property.title ?? '', + description: property.description ?? '', + type: fieldType, + }; + return outputsAccumulator; }, {} as Record @@ -114,6 +156,7 @@ export const parseSchema = ( description, inputs, outputs, + outputType, }; Object.assign(acc, { [type]: invocation });