mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): better workflow validation and parsing
Checks for the existence of nodes for each edge - does not yet check the types.
This commit is contained in:
parent
68fd07a606
commit
9a2c0554de
@ -111,6 +111,7 @@
|
|||||||
"roarr": "^7.15.1",
|
"roarr": "^7.15.1",
|
||||||
"serialize-error": "^11.0.1",
|
"serialize-error": "^11.0.1",
|
||||||
"socket.io-client": "^4.7.2",
|
"socket.io-client": "^4.7.2",
|
||||||
|
"type-fest": "^4.2.0",
|
||||||
"use-debounce": "^9.0.4",
|
"use-debounce": "^9.0.4",
|
||||||
"use-image": "^1.1.1",
|
"use-image": "^1.1.1",
|
||||||
"uuid": "^9.0.0",
|
"uuid": "^9.0.0",
|
||||||
|
@ -3,7 +3,7 @@ import { useLogger } from 'app/logging/useLogger';
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||||
import { zWorkflow } from 'features/nodes/types/types';
|
import { zValidatedWorkflow } from 'features/nodes/types/types';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -24,52 +24,65 @@ export const useLoadWorkflowFromFile = () => {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const parsedJSON = JSON.parse(String(rawJSON));
|
const parsedJSON = JSON.parse(String(rawJSON));
|
||||||
const result = zWorkflow.safeParse(parsedJSON);
|
const result = zValidatedWorkflow.safeParse(parsedJSON);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
const message = fromZodError(result.error, {
|
const { message } = fromZodError(result.error, {
|
||||||
prefix: 'Workflow Validation Error',
|
prefix: 'Workflow Validation Error',
|
||||||
}).toString();
|
});
|
||||||
|
|
||||||
logger.error({ error: parseify(result.error) }, message);
|
logger.error({ error: parseify(result.error) }, message);
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
makeToast({
|
makeToast({
|
||||||
title: 'Unable to Validate Workflow',
|
title: 'Unable to Validate Workflow',
|
||||||
description: (
|
|
||||||
<WorkflowValidationErrorContent error={result.error} />
|
|
||||||
),
|
|
||||||
status: 'error',
|
status: 'error',
|
||||||
duration: 5000,
|
duration: 5000,
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
reader.abort();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
dispatch(workflowLoaded(result.data.workflow));
|
||||||
|
|
||||||
dispatch(workflowLoaded(result.data));
|
if (!result.data.warnings.length) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Workflow Loaded',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
reader.abort();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
makeToast({
|
makeToast({
|
||||||
title: 'Workflow Loaded',
|
title: 'Workflow Loaded with Warnings',
|
||||||
status: 'success',
|
status: 'warning',
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
result.data.warnings.forEach(({ message, ...rest }) => {
|
||||||
|
logger.warn(rest, message);
|
||||||
|
});
|
||||||
|
|
||||||
reader.abort();
|
reader.abort();
|
||||||
} catch (error) {
|
} catch {
|
||||||
// file reader error
|
// file reader error
|
||||||
if (error) {
|
dispatch(
|
||||||
dispatch(
|
addToast(
|
||||||
addToast(
|
makeToast({
|
||||||
makeToast({
|
title: 'Unable to Load Workflow',
|
||||||
title: 'Unable to Load Workflow',
|
status: 'error',
|
||||||
status: 'error',
|
})
|
||||||
})
|
)
|
||||||
)
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -589,7 +589,7 @@ const nodesSlice = createSlice({
|
|||||||
nodeEditorReset: (state) => {
|
nodeEditorReset: (state) => {
|
||||||
state.nodes = [];
|
state.nodes = [];
|
||||||
state.edges = [];
|
state.edges = [];
|
||||||
state.workflow.exposedFields = [];
|
state.workflow = cloneDeep(initialWorkflow);
|
||||||
},
|
},
|
||||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldValidateGraph = action.payload;
|
state.shouldValidateGraph = action.payload;
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import { store } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
zBaseModel,
|
zBaseModel,
|
||||||
@ -5,9 +6,11 @@ import {
|
|||||||
zSDXLRefinerModel,
|
zSDXLRefinerModel,
|
||||||
zScheduler,
|
zScheduler,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { keyBy } from 'lodash-es';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { Node } from 'reactflow';
|
import { Node } from 'reactflow';
|
||||||
|
import { JsonObject } from 'type-fest';
|
||||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
AnyInvocationType,
|
AnyInvocationType,
|
||||||
@ -224,7 +227,7 @@ export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
|
|||||||
|
|
||||||
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('integer'),
|
type: z.literal('integer'),
|
||||||
value: z.number().optional(),
|
value: z.number().int().optional(),
|
||||||
});
|
});
|
||||||
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
||||||
|
|
||||||
@ -825,28 +828,38 @@ export const zNotesNodeData = z.object({
|
|||||||
|
|
||||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||||
|
|
||||||
|
const zPosition = z
|
||||||
|
.object({
|
||||||
|
x: z.number(),
|
||||||
|
y: z.number(),
|
||||||
|
})
|
||||||
|
.default({ x: 0, y: 0 });
|
||||||
|
|
||||||
|
const zDimension = z.number().gt(0).nullish();
|
||||||
|
|
||||||
export const zWorkflowInvocationNode = z.object({
|
export const zWorkflowInvocationNode = z.object({
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
type: z.literal('invocation'),
|
type: z.literal('invocation'),
|
||||||
data: zInvocationNodeData,
|
data: zInvocationNodeData,
|
||||||
width: z.number().gt(0),
|
width: zDimension,
|
||||||
height: z.number().gt(0),
|
height: zDimension,
|
||||||
position: z.object({
|
position: zPosition,
|
||||||
x: z.number(),
|
|
||||||
y: z.number(),
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
|
||||||
|
|
||||||
|
export const isWorkflowInvocationNode = (
|
||||||
|
val: unknown
|
||||||
|
): val is WorkflowInvocationNode =>
|
||||||
|
zWorkflowInvocationNode.safeParse(val).success;
|
||||||
|
|
||||||
export const zWorkflowNotesNode = z.object({
|
export const zWorkflowNotesNode = z.object({
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
type: z.literal('notes'),
|
type: z.literal('notes'),
|
||||||
data: zNotesNodeData,
|
data: zNotesNodeData,
|
||||||
width: z.number().gt(0),
|
width: zDimension,
|
||||||
height: z.number().gt(0),
|
height: zDimension,
|
||||||
position: z.object({
|
position: zPosition,
|
||||||
x: z.number(),
|
|
||||||
y: z.number(),
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export const zWorkflowNode = z.discriminatedUnion('type', [
|
export const zWorkflowNode = z.discriminatedUnion('type', [
|
||||||
@ -886,20 +899,75 @@ export const zSemVer = z.string().refine((val) => {
|
|||||||
|
|
||||||
export type SemVer = z.infer<typeof zSemVer>;
|
export type SemVer = z.infer<typeof zSemVer>;
|
||||||
|
|
||||||
|
export type WorkflowWarning = {
|
||||||
|
message: string;
|
||||||
|
issues: string[];
|
||||||
|
data: JsonObject;
|
||||||
|
};
|
||||||
|
|
||||||
export const zWorkflow = z.object({
|
export const zWorkflow = z.object({
|
||||||
name: z.string(),
|
name: z.string().default(''),
|
||||||
author: z.string(),
|
author: z.string().default(''),
|
||||||
description: z.string(),
|
description: z.string().default(''),
|
||||||
version: z.string(),
|
version: z.string().default(''),
|
||||||
contact: z.string(),
|
contact: z.string().default(''),
|
||||||
tags: z.string(),
|
tags: z.string().default(''),
|
||||||
notes: z.string(),
|
notes: z.string().default(''),
|
||||||
nodes: z.array(zWorkflowNode),
|
nodes: z.array(zWorkflowNode).default([]),
|
||||||
edges: z.array(zWorkflowEdge),
|
edges: z.array(zWorkflowEdge).default([]),
|
||||||
exposedFields: z.array(zFieldIdentifier),
|
exposedFields: z.array(zFieldIdentifier).default([]),
|
||||||
meta: z.object({
|
meta: z
|
||||||
version: zSemVer,
|
.object({
|
||||||
}),
|
version: zSemVer,
|
||||||
|
})
|
||||||
|
.default({ version: '1.0.0' }),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||||
|
const nodeTemplates = store.getState().nodes.nodeTemplates;
|
||||||
|
const { nodes, edges } = workflow;
|
||||||
|
const warnings: WorkflowWarning[] = [];
|
||||||
|
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||||
|
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||||
|
invocationNodes.forEach((node, i) => {
|
||||||
|
const nodeTemplate = nodeTemplates[node.data.type];
|
||||||
|
if (!nodeTemplate) {
|
||||||
|
warnings.push({
|
||||||
|
message: `Node "${node.data.label || node.data.id}" skipped`,
|
||||||
|
issues: [`Unable to find template for type "${node.data.type}"`],
|
||||||
|
data: node,
|
||||||
|
});
|
||||||
|
delete nodes[i];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
edges.forEach((edge, i) => {
|
||||||
|
const sourceNode = keyedNodes[edge.source];
|
||||||
|
const targetNode = keyedNodes[edge.target];
|
||||||
|
const issues: string[] = [];
|
||||||
|
if (!sourceNode) {
|
||||||
|
issues.push(`Output node ${edge.source} does not exist`);
|
||||||
|
} else if (!(edge.sourceHandle in sourceNode.data.outputs)) {
|
||||||
|
issues.push(
|
||||||
|
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!targetNode) {
|
||||||
|
issues.push(`Input node ${edge.target} does not exist`);
|
||||||
|
} else if (!(edge.targetHandle in targetNode.data.inputs)) {
|
||||||
|
issues.push(
|
||||||
|
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (issues.length) {
|
||||||
|
delete edges[i];
|
||||||
|
warnings.push({
|
||||||
|
message: `Edge "${edge.sourceHandle} -> ${edge.targetHandle}" skipped`,
|
||||||
|
issues,
|
||||||
|
data: edge,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return { workflow, warnings };
|
||||||
});
|
});
|
||||||
|
|
||||||
export type Workflow = z.infer<typeof zWorkflow>;
|
export type Workflow = z.infer<typeof zWorkflow>;
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import * as png from '@stevebel/png';
|
import * as png from '@stevebel/png';
|
||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import {
|
import {
|
||||||
ImageMetadataAndWorkflow,
|
ImageMetadataAndWorkflow,
|
||||||
zCoreMetadata,
|
zCoreMetadata,
|
||||||
@ -11,27 +10,24 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
|||||||
image: Blob
|
image: Blob
|
||||||
): Promise<ImageMetadataAndWorkflow> => {
|
): Promise<ImageMetadataAndWorkflow> => {
|
||||||
const data: ImageMetadataAndWorkflow = {};
|
const data: ImageMetadataAndWorkflow = {};
|
||||||
try {
|
const buffer = await image.arrayBuffer();
|
||||||
const buffer = await image.arrayBuffer();
|
const text = png.decode(buffer).text;
|
||||||
const text = png.decode(buffer).text;
|
|
||||||
const rawMetadata = get(text, 'invokeai_metadata');
|
const rawMetadata = get(text, 'invokeai_metadata');
|
||||||
const rawWorkflow = get(text, 'invokeai_workflow');
|
if (rawMetadata) {
|
||||||
if (rawMetadata) {
|
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||||
try {
|
if (metadataResult.success) {
|
||||||
data.metadata = zCoreMetadata.parse(JSON.parse(rawMetadata));
|
data.metadata = metadataResult.data;
|
||||||
} catch {
|
|
||||||
// no-op
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (rawWorkflow) {
|
|
||||||
try {
|
|
||||||
data.workflow = zWorkflow.parse(JSON.parse(rawWorkflow));
|
|
||||||
} catch {
|
|
||||||
// no-op
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
logger('nodes').warn('Unable to parse image');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const rawWorkflow = get(text, 'invokeai_workflow');
|
||||||
|
if (rawWorkflow) {
|
||||||
|
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||||
|
if (workflowResult.success) {
|
||||||
|
data.workflow = workflowResult.data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
};
|
};
|
||||||
|
@ -68,7 +68,7 @@ export const parseSchema = (
|
|||||||
|
|
||||||
const invocations = filteredSchemas.reduce<
|
const invocations = filteredSchemas.reduce<
|
||||||
Record<string, InvocationTemplate>
|
Record<string, InvocationTemplate>
|
||||||
>((acc, schema) => {
|
>((invocationsAccumulator, schema) => {
|
||||||
const type = schema.properties.type.default;
|
const type = schema.properties.type.default;
|
||||||
const title = schema.title.replace('Invocation', '');
|
const title = schema.title.replace('Invocation', '');
|
||||||
const tags = schema.tags ?? [];
|
const tags = schema.tags ?? [];
|
||||||
@ -133,7 +133,7 @@ export const parseSchema = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
if (!field) {
|
if (!field) {
|
||||||
logger('nodes').warn(
|
logger('nodes').debug(
|
||||||
{
|
{
|
||||||
node: type,
|
node: type,
|
||||||
fieldName: propertyName,
|
fieldName: propertyName,
|
||||||
@ -154,17 +154,17 @@ export const parseSchema = (
|
|||||||
const outputSchemaName = schema.output.$ref.split('/').pop();
|
const outputSchemaName = schema.output.$ref.split('/').pop();
|
||||||
|
|
||||||
if (!outputSchemaName) {
|
if (!outputSchemaName) {
|
||||||
logger('nodes').error(
|
logger('nodes').warn(
|
||||||
{ outputRefObject: parseify(schema.output) },
|
{ outputRefObject: parseify(schema.output) },
|
||||||
'No output schema name found in ref object'
|
'No output schema name found in ref object'
|
||||||
);
|
);
|
||||||
throw 'No output schema name found in ref object';
|
return invocationsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const outputSchema = openAPI.components?.schemas?.[outputSchemaName];
|
const outputSchema = openAPI.components?.schemas?.[outputSchemaName];
|
||||||
if (!outputSchema) {
|
if (!outputSchema) {
|
||||||
logger('nodes').error({ outputSchemaName }, 'Output schema not found');
|
logger('nodes').warn({ outputSchemaName }, 'Output schema not found');
|
||||||
throw 'Output schema not found';
|
return invocationsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isInvocationOutputSchemaObject(outputSchema)) {
|
if (!isInvocationOutputSchemaObject(outputSchema)) {
|
||||||
@ -172,7 +172,7 @@ export const parseSchema = (
|
|||||||
{ outputSchema: parseify(outputSchema) },
|
{ outputSchema: parseify(outputSchema) },
|
||||||
'Invalid output schema'
|
'Invalid output schema'
|
||||||
);
|
);
|
||||||
throw 'Invalid output schema';
|
return invocationsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const outputType = outputSchema.properties.type.default;
|
const outputType = outputSchema.properties.type.default;
|
||||||
@ -203,19 +203,20 @@ export const parseSchema = (
|
|||||||
{ fieldName: propertyName, fieldType, field: parseify(property) },
|
{ fieldName: propertyName, fieldType, field: parseify(property) },
|
||||||
'Skipping unknown output field type'
|
'Skipping unknown output field type'
|
||||||
);
|
);
|
||||||
} else {
|
return outputsAccumulator;
|
||||||
outputsAccumulator[propertyName] = {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: propertyName,
|
|
||||||
title: property.title ?? '',
|
|
||||||
description: property.description ?? '',
|
|
||||||
type: fieldType,
|
|
||||||
ui_hidden: property.ui_hidden ?? false,
|
|
||||||
ui_type: property.ui_type,
|
|
||||||
ui_order: property.ui_order,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outputsAccumulator[propertyName] = {
|
||||||
|
fieldKind: 'output',
|
||||||
|
name: propertyName,
|
||||||
|
title: property.title ?? '',
|
||||||
|
description: property.description ?? '',
|
||||||
|
type: fieldType,
|
||||||
|
ui_hidden: property.ui_hidden ?? false,
|
||||||
|
ui_type: property.ui_type,
|
||||||
|
ui_order: property.ui_order,
|
||||||
|
};
|
||||||
|
|
||||||
return outputsAccumulator;
|
return outputsAccumulator;
|
||||||
},
|
},
|
||||||
{} as Record<string, OutputFieldTemplate>
|
{} as Record<string, OutputFieldTemplate>
|
||||||
@ -231,9 +232,9 @@ export const parseSchema = (
|
|||||||
outputType,
|
outputType,
|
||||||
};
|
};
|
||||||
|
|
||||||
Object.assign(acc, { [type]: invocation });
|
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||||
|
|
||||||
return acc;
|
return invocationsAccumulator;
|
||||||
}, {});
|
}, {});
|
||||||
|
|
||||||
return invocations;
|
return invocations;
|
||||||
|
@ -6668,6 +6668,11 @@ type-fest@^2.12.2:
|
|||||||
resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-2.19.0.tgz#88068015bb33036a598b952e55e9311a60fd3a9b"
|
resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-2.19.0.tgz#88068015bb33036a598b952e55e9311a60fd3a9b"
|
||||||
integrity sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==
|
integrity sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==
|
||||||
|
|
||||||
|
type-fest@^4.2.0:
|
||||||
|
version "4.2.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-4.2.0.tgz#e259430307710e77721ecf6f545840acad72195f"
|
||||||
|
integrity sha512-5zknd7Dss75pMSED270A1RQS3KloqRJA9XbXLe0eCxyw7xXFb3rd+9B0UQ/0E+LQT6lnrLviEolYORlRWamn4w==
|
||||||
|
|
||||||
typed-array-buffer@^1.0.0:
|
typed-array-buffer@^1.0.0:
|
||||||
version "1.0.0"
|
version "1.0.0"
|
||||||
resolved "https://registry.yarnpkg.com/typed-array-buffer/-/typed-array-buffer-1.0.0.tgz#18de3e7ed7974b0a729d3feecb94338d1472cd60"
|
resolved "https://registry.yarnpkg.com/typed-array-buffer/-/typed-array-buffer-1.0.0.tgz#18de3e7ed7974b0a729d3feecb94338d1472cd60"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user