mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): improve node schema parsing and add outputType
to templates
This commit is contained in:
parent
089ada8cd1
commit
2497aa5cd8
@ -53,6 +53,10 @@ export type InvocationTemplate = {
|
||||
* Array of the invocation outputs
|
||||
*/
|
||||
outputs: Record<string, OutputFieldTemplate>;
|
||||
/**
|
||||
* The type of this node's output
|
||||
*/
|
||||
outputType: string; // TODO: generate a union of output types
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
@ -521,14 +525,15 @@ export type InvocationBaseSchemaObject = Omit<
|
||||
export type InvocationOutputSchemaObject = Omit<
|
||||
OpenAPIV3.SchemaObject,
|
||||
'properties'
|
||||
> &
|
||||
OpenAPIV3.SchemaObject['properties'] & {
|
||||
> & {
|
||||
properties: OpenAPIV3.SchemaObject['properties'] & {
|
||||
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
|
||||
default: AnyInvocationType;
|
||||
default: string;
|
||||
};
|
||||
} & {
|
||||
class: 'output';
|
||||
};
|
||||
};
|
||||
|
||||
export type InvocationFieldSchema = OpenAPIV3.SchemaObject & _InputField;
|
||||
|
||||
|
@ -14,13 +14,34 @@ import {
|
||||
} from '../types/types';
|
||||
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
||||
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'metadata'];
|
||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata'];
|
||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||
|
||||
const invocationDenylist: AnyInvocationType[] = [
|
||||
'graph',
|
||||
'metadata_accumulator',
|
||||
];
|
||||
|
||||
const isAllowedInputField = (nodeType: string, fieldName: string) => {
|
||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||
return false;
|
||||
}
|
||||
if (nodeType === 'collect' && fieldName === 'collection') {
|
||||
return false;
|
||||
}
|
||||
if (nodeType === 'iterate' && fieldName === 'index') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const isAllowedOutputField = (nodeType: string, fieldName: string) => {
|
||||
if (RESERVED_OUTPUT_FIELD_NAMES.includes(fieldName)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const isNotInDenylist = (schema: InvocationSchemaObject) =>
|
||||
!invocationDenylist.includes(schema.properties.type.default);
|
||||
|
||||
@ -42,17 +63,28 @@ export const parseSchema = (
|
||||
const inputs = reduce(
|
||||
schema.properties,
|
||||
(inputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
!RESERVED_FIELD_NAMES.includes(propertyName) &&
|
||||
isInvocationFieldSchema(property) &&
|
||||
!property.ui_hidden
|
||||
) {
|
||||
if (!isAllowedInputField(type, propertyName)) {
|
||||
logger('nodes').trace(
|
||||
{ type, propertyName, property: parseify(property) },
|
||||
'Skipped reserved input field'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (!isInvocationFieldSchema(property)) {
|
||||
logger('nodes').warn(
|
||||
{ type, propertyName, property: parseify(property) },
|
||||
'Unhandled input property'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
const field = buildInputFieldTemplate(schema, property, propertyName);
|
||||
|
||||
if (field) {
|
||||
inputsAccumulator[propertyName] = field;
|
||||
}
|
||||
}
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldTemplate>
|
||||
@ -82,14 +114,27 @@ export const parseSchema = (
|
||||
throw 'Invalid output schema';
|
||||
}
|
||||
|
||||
const outputType = outputSchema.properties.type.default;
|
||||
|
||||
const outputs = reduce(
|
||||
outputSchema.properties as OpenAPIV3.SchemaObject,
|
||||
outputSchema.properties,
|
||||
(outputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
!['type', 'id'].includes(propertyName) &&
|
||||
!['object'].includes(property.type) && // TODO: handle objects?
|
||||
isInvocationFieldSchema(property)
|
||||
) {
|
||||
if (!isAllowedOutputField(type, propertyName)) {
|
||||
logger('nodes').trace(
|
||||
{ type, propertyName, property: parseify(property) },
|
||||
'Skipped reserved output field'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
if (!isInvocationFieldSchema(property)) {
|
||||
logger('nodes').warn(
|
||||
{ type, propertyName, property: parseify(property) },
|
||||
'Unhandled output property'
|
||||
);
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
const fieldType = getFieldType(property);
|
||||
outputsAccumulator[propertyName] = {
|
||||
fieldKind: 'output',
|
||||
@ -98,9 +143,6 @@ export const parseSchema = (
|
||||
description: property.description ?? '',
|
||||
type: fieldType,
|
||||
};
|
||||
} else {
|
||||
logger('nodes').warn({ property }, 'Unhandled output property');
|
||||
}
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
@ -114,6 +156,7 @@ export const parseSchema = (
|
||||
description,
|
||||
inputs,
|
||||
outputs,
|
||||
outputType,
|
||||
};
|
||||
|
||||
Object.assign(acc, { [type]: invocation });
|
||||
|
Loading…
Reference in New Issue
Block a user