feat(ui): reset missing images, boards and models when loading workflows

These fields are reset back to `undefined` if not accessible. A warning toast is showing, and in the JS console, the full warning message is logged.
This commit is contained in:
psychedelicious 2024-05-21 09:54:13 +10:00 committed by Kent Keirsey
parent 7badaab17d
commit 38320a5100
4 changed files with 138 additions and 15 deletions

View File

@ -897,7 +897,10 @@
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default"
},
"parameters": {
"aspect": "Aspect",

View File

@ -11,20 +11,21 @@ import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates);
return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates);
return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
} else {
throw new Error('No workflow or graph provided');
}
@ -33,13 +34,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
effect: async (action, { dispatch }) => {
const log = logger('nodes');
const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get();
try {
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
const { workflow, warnings } = await getWorkflow(data, nodeTemplates);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow

View File

@ -1,6 +1,11 @@
import type { JSONObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import type { Templates } from 'features/nodes/store/types';
import {
isBoardFieldInputInstance,
isImageFieldInputInstance,
isModelIdentifierFieldInputInstance,
} from 'features/nodes/types/field';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { isWorkflowInvocationNode } from 'features/nodes/types/workflow';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
@ -20,6 +25,18 @@ type ValidateWorkflowResult = {
warnings: WorkflowWarning[];
};
const MODEL_FIELD_TYPES = [
'ModelIdentifier',
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',
'ControlNetModelField',
'IPAdapterModelField',
'T2IAdapterModelField',
];
/**
* Parses and validates a workflow:
* - Parses the workflow schema, and migrates it to the latest version if necessary.
@ -27,11 +44,17 @@ type ValidateWorkflowResult = {
* - Attempts to update nodes which have a mismatched version.
* - Removes edges which are invalid.
* @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow))
* @param invocationTemplates The node templates to validate against.
* @param templates The node templates to validate against.
* @throws {WorkflowVersionError} If the workflow version is not recognized.
* @throws {z.ZodError} If there is a validation error.
*/
export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => {
export const validateWorkflow = async (
workflow: unknown,
templates: Templates,
checkImageAccess: (name: string) => Promise<boolean>,
checkBoardAccess: (id: string) => Promise<boolean>,
checkModelAccess: (key: string) => Promise<boolean>
): Promise<ValidateWorkflowResult> => {
// Parse the raw workflow data & migrate it to the latest version
const _workflow = parseAndMigrateWorkflow(workflow);
@ -50,8 +73,8 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id');
invocationNodes.forEach((node) => {
const template = invocationTemplates[node.data.type];
for (const node of Object.values(invocationNodes)) {
const template = templates[node.data.type];
if (!template) {
// This node's type template does not exist
const message = t('nodes.missingTemplate', {
@ -62,7 +85,7 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat
message,
data: parseify(node),
});
return;
continue;
}
if (getNeedsUpdate(node.data, template)) {
@ -75,15 +98,56 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat
message,
data: parseify({ node, nodeTemplate: template }),
});
return;
continue;
}
});
for (const input of Object.values(node.data.inputs)) {
const fieldTemplate = template.inputs[input.name];
if (!fieldTemplate) {
const message = t('nodes.missingFieldTemplate');
warnings.push({
message,
data: parseify({ node, nodeTemplate: template, input }),
});
continue;
}
if (fieldTemplate.type.name === 'ImageField' && isImageFieldInputInstance(input) && input.value) {
const hasAccess = await checkImageAccess(input.value.image_name);
if (!hasAccess) {
const message = t('nodes.imageAccessError', { image_name: input.value.image_name });
warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) });
input.value = undefined;
}
}
if (fieldTemplate.type.name === 'BoardField' && isBoardFieldInputInstance(input) && input.value) {
const hasAccess = await checkBoardAccess(input.value.board_id);
if (!hasAccess) {
const message = t('nodes.boardAccessError', { board_id: input.value.board_id });
warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) });
input.value = undefined;
}
}
if (
MODEL_FIELD_TYPES.includes(fieldTemplate.type.name) &&
isModelIdentifierFieldInputInstance(input) &&
input.value
) {
const hasAccess = await checkModelAccess(input.value.key);
if (!hasAccess) {
const message = t('nodes.modelAccessError', { key: input.value.key });
warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) });
input.value = undefined;
}
}
}
}
edges.forEach((edge, i) => {
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target];
const sourceTemplate = sourceNode ? invocationTemplates[sourceNode.data.type] : undefined;
const targetTemplate = targetNode ? invocationTemplates[targetNode.data.type] : undefined;
const sourceTemplate = sourceNode ? templates[sourceNode.data.type] : undefined;
const targetTemplate = targetNode ? templates[targetNode.data.type] : undefined;
const issues: string[] = [];
if (!sourceNode) {

View File

@ -0,0 +1,55 @@
import { getStore } from 'app/store/nanostores/store';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
/**
* Checks if the client has access to a model.
* @param key The model key.
* @returns A promise that resolves to true if the client has access, else false.
*/
export const checkModelAccess = async (key: string): Promise<boolean> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
req.unsubscribe();
const result = await req.unwrap();
return Boolean(result);
} catch {
return false;
}
};
/**
* Checks if the client has access to an image.
* @param name The image name.
* @returns A promise that resolves to true if the client has access, else false.
*/
export const checkImageAccess = async (name: string): Promise<boolean> => {
const { dispatch } = getStore();
try {
const req = dispatch(imagesApi.endpoints.getImageDTO.initiate(name));
req.unsubscribe();
const result = await req.unwrap();
return Boolean(result);
} catch {
return false;
}
};
/**
* Checks if the client has access to a board.
* @param id The board id.
* @returns A promise that resolves to true if the client has access, else false.
*/
export const checkBoardAccess = async (id: string): Promise<boolean> => {
const { dispatch } = getStore();
try {
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate());
req.unsubscribe();
const result = await req.unwrap();
return result.some((b) => b.board_id === id);
} catch {
return false;
}
};