feat(ui): improve node schema parsing and add outputType to templates

This commit is contained in:
psychedelicious 2023-08-22 15:31:18 +10:00
parent 089ada8cd1
commit 2497aa5cd8
2 changed files with 78 additions and 30 deletions

View File

@ -53,6 +53,10 @@ export type InvocationTemplate = {
* Array of the invocation outputs
*/
outputs: Record<string, OutputFieldTemplate>;
/**
* The type of this node's output
*/
outputType: string; // TODO: generate a union of output types
};
export type FieldUIConfig = {
@ -521,14 +525,15 @@ export type InvocationBaseSchemaObject = Omit<
export type InvocationOutputSchemaObject = Omit<
OpenAPIV3.SchemaObject,
'properties'
> &
OpenAPIV3.SchemaObject['properties'] & {
> & {
properties: OpenAPIV3.SchemaObject['properties'] & {
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
default: AnyInvocationType;
default: string;
};
} & {
class: 'output';
};
};
export type InvocationFieldSchema = OpenAPIV3.SchemaObject & _InputField;

View File

@ -14,13 +14,34 @@ import {
} from '../types/types';
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'metadata'];
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata'];
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
const invocationDenylist: AnyInvocationType[] = [
'graph',
'metadata_accumulator',
];
const isAllowedInputField = (nodeType: string, fieldName: string) => {
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
return false;
}
if (nodeType === 'collect' && fieldName === 'collection') {
return false;
}
if (nodeType === 'iterate' && fieldName === 'index') {
return false;
}
return true;
};
const isAllowedOutputField = (nodeType: string, fieldName: string) => {
if (RESERVED_OUTPUT_FIELD_NAMES.includes(fieldName)) {
return false;
}
return true;
};
const isNotInDenylist = (schema: InvocationSchemaObject) =>
!invocationDenylist.includes(schema.properties.type.default);
@ -42,17 +63,28 @@ export const parseSchema = (
const inputs = reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isInvocationFieldSchema(property) &&
!property.ui_hidden
) {
if (!isAllowedInputField(type, propertyName)) {
logger('nodes').trace(
{ type, propertyName, property: parseify(property) },
'Skipped reserved input field'
);
return inputsAccumulator;
}
if (!isInvocationFieldSchema(property)) {
logger('nodes').warn(
{ type, propertyName, property: parseify(property) },
'Unhandled input property'
);
return inputsAccumulator;
}
const field = buildInputFieldTemplate(schema, property, propertyName);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
{} as Record<string, InputFieldTemplate>
@ -82,14 +114,27 @@ export const parseSchema = (
throw 'Invalid output schema';
}
const outputType = outputSchema.properties.type.default;
const outputs = reduce(
outputSchema.properties as OpenAPIV3.SchemaObject,
outputSchema.properties,
(outputsAccumulator, property, propertyName) => {
if (
!['type', 'id'].includes(propertyName) &&
!['object'].includes(property.type) && // TODO: handle objects?
isInvocationFieldSchema(property)
) {
if (!isAllowedOutputField(type, propertyName)) {
logger('nodes').trace(
{ type, propertyName, property: parseify(property) },
'Skipped reserved output field'
);
return outputsAccumulator;
}
if (!isInvocationFieldSchema(property)) {
logger('nodes').warn(
{ type, propertyName, property: parseify(property) },
'Unhandled output property'
);
return outputsAccumulator;
}
const fieldType = getFieldType(property);
outputsAccumulator[propertyName] = {
fieldKind: 'output',
@ -98,9 +143,6 @@ export const parseSchema = (
description: property.description ?? '',
type: fieldType,
};
} else {
logger('nodes').warn({ property }, 'Unhandled output property');
}
return outputsAccumulator;
},
@ -114,6 +156,7 @@ export const parseSchema = (
description,
inputs,
outputs,
outputType,
};
Object.assign(acc, { [type]: invocation });