From 38320a5100ff9a15056adf8532c3bdd91f365c5a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 21 May 2024 09:54:13 +1000 Subject: [PATCH] feat(ui): reset missing images, boards and models when loading workflows These fields are reset back to `undefined` if not accessible. A warning toast is showing, and in the JS console, the full warning message is logged. --- invokeai/frontend/web/public/locales/en.json | 5 +- .../listeners/workflowLoadRequested.ts | 11 +-- .../nodes/util/workflow/validateWorkflow.ts | 82 +++++++++++++++++-- .../src/services/api/hooks/accessChecks.ts | 55 +++++++++++++ 4 files changed, 138 insertions(+), 15 deletions(-) create mode 100644 invokeai/frontend/web/src/services/api/hooks/accessChecks.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1d41a1de63..da8d69c91e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -897,7 +897,10 @@ "zoomInNodes": "Zoom In", "zoomOutNodes": "Zoom Out", "betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.", - "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time." + "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.", + "imageAccessError": "Unable to find image {{image_name}}, resetting to default", + "boardAccessError": "Unable to find board {{board_id}}, resetting to default", + "modelAccessError": "Unable to find model {{key}}, resetting to default" }, "parameters": { "aspect": "Aspect", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index a680bbca97..9f9e70fc01 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -11,20 +11,21 @@ import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow' import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; +import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks'; import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types'; import { z } from 'zod'; import { fromZodError } from 'zod-validation-error'; -const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { +const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => { if (data.workflow) { // Prefer to load the workflow if it's available - it has more information const parsed = JSON.parse(data.workflow); - return validateWorkflow(parsed, templates); + return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess); } else if (data.graph) { // Else we fall back on the graph, using the graphToWorkflow function to convert and do layout const parsed = JSON.parse(data.graph); const workflow = graphToWorkflow(parsed as NonNullableGraph, true); - return validateWorkflow(workflow, templates); + return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess); } else { throw new Error('No workflow or graph provided'); } @@ -33,13 +34,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: workflowLoadRequested, - effect: (action, { dispatch }) => { + effect: async (action, { dispatch }) => { const log = logger('nodes'); const { data, asCopy } = action.payload; const nodeTemplates = $templates.get(); try { - const { workflow, warnings } = getWorkflow(data, nodeTemplates); + const { workflow, warnings } = await getWorkflow(data, nodeTemplates); if (asCopy) { // If we're loading a copy, we need to remove the ID so that the backend will create a new workflow diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index d2d3d64cb0..546537e275 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,11 @@ import type { JSONObject } from 'common/types'; import { parseify } from 'common/util/serialize'; import type { Templates } from 'features/nodes/store/types'; +import { + isBoardFieldInputInstance, + isImageFieldInputInstance, + isModelIdentifierFieldInputInstance, +} from 'features/nodes/types/field'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; @@ -20,6 +25,18 @@ type ValidateWorkflowResult = { warnings: WorkflowWarning[]; }; +const MODEL_FIELD_TYPES = [ + 'ModelIdentifier', + 'MainModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VAEModelField', + 'LoRAModelField', + 'ControlNetModelField', + 'IPAdapterModelField', + 'T2IAdapterModelField', +]; + /** * Parses and validates a workflow: * - Parses the workflow schema, and migrates it to the latest version if necessary. @@ -27,11 +44,17 @@ type ValidateWorkflowResult = { * - Attempts to update nodes which have a mismatched version. * - Removes edges which are invalid. * @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow)) - * @param invocationTemplates The node templates to validate against. + * @param templates The node templates to validate against. * @throws {WorkflowVersionError} If the workflow version is not recognized. * @throws {z.ZodError} If there is a validation error. */ -export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => { +export const validateWorkflow = async ( + workflow: unknown, + templates: Templates, + checkImageAccess: (name: string) => Promise, + checkBoardAccess: (id: string) => Promise, + checkModelAccess: (key: string) => Promise +): Promise => { // Parse the raw workflow data & migrate it to the latest version const _workflow = parseAndMigrateWorkflow(workflow); @@ -50,8 +73,8 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat const invocationNodes = nodes.filter(isWorkflowInvocationNode); const keyedNodes = keyBy(invocationNodes, 'id'); - invocationNodes.forEach((node) => { - const template = invocationTemplates[node.data.type]; + for (const node of Object.values(invocationNodes)) { + const template = templates[node.data.type]; if (!template) { // This node's type template does not exist const message = t('nodes.missingTemplate', { @@ -62,7 +85,7 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat message, data: parseify(node), }); - return; + continue; } if (getNeedsUpdate(node.data, template)) { @@ -75,15 +98,56 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat message, data: parseify({ node, nodeTemplate: template }), }); - return; + continue; } - }); + + for (const input of Object.values(node.data.inputs)) { + const fieldTemplate = template.inputs[input.name]; + + if (!fieldTemplate) { + const message = t('nodes.missingFieldTemplate'); + warnings.push({ + message, + data: parseify({ node, nodeTemplate: template, input }), + }); + continue; + } + if (fieldTemplate.type.name === 'ImageField' && isImageFieldInputInstance(input) && input.value) { + const hasAccess = await checkImageAccess(input.value.image_name); + if (!hasAccess) { + const message = t('nodes.imageAccessError', { image_name: input.value.image_name }); + warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) }); + input.value = undefined; + } + } + if (fieldTemplate.type.name === 'BoardField' && isBoardFieldInputInstance(input) && input.value) { + const hasAccess = await checkBoardAccess(input.value.board_id); + if (!hasAccess) { + const message = t('nodes.boardAccessError', { board_id: input.value.board_id }); + warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) }); + input.value = undefined; + } + } + if ( + MODEL_FIELD_TYPES.includes(fieldTemplate.type.name) && + isModelIdentifierFieldInputInstance(input) && + input.value + ) { + const hasAccess = await checkModelAccess(input.value.key); + if (!hasAccess) { + const message = t('nodes.modelAccessError', { key: input.value.key }); + warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) }); + input.value = undefined; + } + } + } + } edges.forEach((edge, i) => { // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. const sourceNode = keyedNodes[edge.source]; const targetNode = keyedNodes[edge.target]; - const sourceTemplate = sourceNode ? invocationTemplates[sourceNode.data.type] : undefined; - const targetTemplate = targetNode ? invocationTemplates[targetNode.data.type] : undefined; + const sourceTemplate = sourceNode ? templates[sourceNode.data.type] : undefined; + const targetTemplate = targetNode ? templates[targetNode.data.type] : undefined; const issues: string[] = []; if (!sourceNode) { diff --git a/invokeai/frontend/web/src/services/api/hooks/accessChecks.ts b/invokeai/frontend/web/src/services/api/hooks/accessChecks.ts new file mode 100644 index 0000000000..00e27d49c6 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/accessChecks.ts @@ -0,0 +1,55 @@ +import { getStore } from 'app/store/nanostores/store'; +import { boardsApi } from 'services/api/endpoints/boards'; +import { imagesApi } from 'services/api/endpoints/images'; +import { modelsApi } from 'services/api/endpoints/models'; + +/** + * Checks if the client has access to a model. + * @param key The model key. + * @returns A promise that resolves to true if the client has access, else false. + */ +export const checkModelAccess = async (key: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key)); + req.unsubscribe(); + const result = await req.unwrap(); + return Boolean(result); + } catch { + return false; + } +}; + +/** + * Checks if the client has access to an image. + * @param name The image name. + * @returns A promise that resolves to true if the client has access, else false. + */ +export const checkImageAccess = async (name: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(imagesApi.endpoints.getImageDTO.initiate(name)); + req.unsubscribe(); + const result = await req.unwrap(); + return Boolean(result); + } catch { + return false; + } +}; + +/** + * Checks if the client has access to a board. + * @param id The board id. + * @returns A promise that resolves to true if the client has access, else false. + */ +export const checkBoardAccess = async (id: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(boardsApi.endpoints.listAllBoards.initiate()); + req.unsubscribe(); + const result = await req.unwrap(); + return result.some((b) => b.board_id === id); + } catch { + return false; + } +};