mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): reset missing images, boards and models when loading workflows
These fields are reset back to `undefined` if not accessible. A warning toast is showing, and in the JS console, the full warning message is logged.
This commit is contained in:
parent
7badaab17d
commit
38320a5100
@ -897,7 +897,10 @@
|
||||
"zoomInNodes": "Zoom In",
|
||||
"zoomOutNodes": "Zoom Out",
|
||||
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
|
||||
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
|
||||
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
|
||||
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
|
||||
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
|
||||
"modelAccessError": "Unable to find model {{key}}, resetting to default"
|
||||
},
|
||||
"parameters": {
|
||||
"aspect": "Aspect",
|
||||
|
@ -11,20 +11,21 @@ import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks';
|
||||
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
|
||||
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
|
||||
const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => {
|
||||
if (data.workflow) {
|
||||
// Prefer to load the workflow if it's available - it has more information
|
||||
const parsed = JSON.parse(data.workflow);
|
||||
return validateWorkflow(parsed, templates);
|
||||
return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
|
||||
} else if (data.graph) {
|
||||
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
|
||||
const parsed = JSON.parse(data.graph);
|
||||
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
|
||||
return validateWorkflow(workflow, templates);
|
||||
return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
|
||||
} else {
|
||||
throw new Error('No workflow or graph provided');
|
||||
}
|
||||
@ -33,13 +34,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
|
||||
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: async (action, { dispatch }) => {
|
||||
const log = logger('nodes');
|
||||
const { data, asCopy } = action.payload;
|
||||
const nodeTemplates = $templates.get();
|
||||
|
||||
try {
|
||||
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
|
||||
const { workflow, warnings } = await getWorkflow(data, nodeTemplates);
|
||||
|
||||
if (asCopy) {
|
||||
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
|
||||
|
@ -1,6 +1,11 @@
|
||||
import type { JSONObject } from 'common/types';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isImageFieldInputInstance,
|
||||
isModelIdentifierFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { isWorkflowInvocationNode } from 'features/nodes/types/workflow';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
@ -20,6 +25,18 @@ type ValidateWorkflowResult = {
|
||||
warnings: WorkflowWarning[];
|
||||
};
|
||||
|
||||
const MODEL_FIELD_TYPES = [
|
||||
'ModelIdentifier',
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VAEModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'IPAdapterModelField',
|
||||
'T2IAdapterModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
* Parses and validates a workflow:
|
||||
* - Parses the workflow schema, and migrates it to the latest version if necessary.
|
||||
@ -27,11 +44,17 @@ type ValidateWorkflowResult = {
|
||||
* - Attempts to update nodes which have a mismatched version.
|
||||
* - Removes edges which are invalid.
|
||||
* @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow))
|
||||
* @param invocationTemplates The node templates to validate against.
|
||||
* @param templates The node templates to validate against.
|
||||
* @throws {WorkflowVersionError} If the workflow version is not recognized.
|
||||
* @throws {z.ZodError} If there is a validation error.
|
||||
*/
|
||||
export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => {
|
||||
export const validateWorkflow = async (
|
||||
workflow: unknown,
|
||||
templates: Templates,
|
||||
checkImageAccess: (name: string) => Promise<boolean>,
|
||||
checkBoardAccess: (id: string) => Promise<boolean>,
|
||||
checkModelAccess: (key: string) => Promise<boolean>
|
||||
): Promise<ValidateWorkflowResult> => {
|
||||
// Parse the raw workflow data & migrate it to the latest version
|
||||
const _workflow = parseAndMigrateWorkflow(workflow);
|
||||
|
||||
@ -50,8 +73,8 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat
|
||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||
|
||||
invocationNodes.forEach((node) => {
|
||||
const template = invocationTemplates[node.data.type];
|
||||
for (const node of Object.values(invocationNodes)) {
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
// This node's type template does not exist
|
||||
const message = t('nodes.missingTemplate', {
|
||||
@ -62,7 +85,7 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat
|
||||
message,
|
||||
data: parseify(node),
|
||||
});
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (getNeedsUpdate(node.data, template)) {
|
||||
@ -75,15 +98,56 @@ export const validateWorkflow = (workflow: unknown, invocationTemplates: Templat
|
||||
message,
|
||||
data: parseify({ node, nodeTemplate: template }),
|
||||
});
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
|
||||
for (const input of Object.values(node.data.inputs)) {
|
||||
const fieldTemplate = template.inputs[input.name];
|
||||
|
||||
if (!fieldTemplate) {
|
||||
const message = t('nodes.missingFieldTemplate');
|
||||
warnings.push({
|
||||
message,
|
||||
data: parseify({ node, nodeTemplate: template, input }),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (fieldTemplate.type.name === 'ImageField' && isImageFieldInputInstance(input) && input.value) {
|
||||
const hasAccess = await checkImageAccess(input.value.image_name);
|
||||
if (!hasAccess) {
|
||||
const message = t('nodes.imageAccessError', { image_name: input.value.image_name });
|
||||
warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) });
|
||||
input.value = undefined;
|
||||
}
|
||||
}
|
||||
if (fieldTemplate.type.name === 'BoardField' && isBoardFieldInputInstance(input) && input.value) {
|
||||
const hasAccess = await checkBoardAccess(input.value.board_id);
|
||||
if (!hasAccess) {
|
||||
const message = t('nodes.boardAccessError', { board_id: input.value.board_id });
|
||||
warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) });
|
||||
input.value = undefined;
|
||||
}
|
||||
}
|
||||
if (
|
||||
MODEL_FIELD_TYPES.includes(fieldTemplate.type.name) &&
|
||||
isModelIdentifierFieldInputInstance(input) &&
|
||||
input.value
|
||||
) {
|
||||
const hasAccess = await checkModelAccess(input.value.key);
|
||||
if (!hasAccess) {
|
||||
const message = t('nodes.modelAccessError', { key: input.value.key });
|
||||
warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) });
|
||||
input.value = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edges.forEach((edge, i) => {
|
||||
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
|
||||
const sourceNode = keyedNodes[edge.source];
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
const sourceTemplate = sourceNode ? invocationTemplates[sourceNode.data.type] : undefined;
|
||||
const targetTemplate = targetNode ? invocationTemplates[targetNode.data.type] : undefined;
|
||||
const sourceTemplate = sourceNode ? templates[sourceNode.data.type] : undefined;
|
||||
const targetTemplate = targetNode ? templates[targetNode.data.type] : undefined;
|
||||
const issues: string[] = [];
|
||||
|
||||
if (!sourceNode) {
|
||||
|
55
invokeai/frontend/web/src/services/api/hooks/accessChecks.ts
Normal file
55
invokeai/frontend/web/src/services/api/hooks/accessChecks.ts
Normal file
@ -0,0 +1,55 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
/**
|
||||
* Checks if the client has access to a model.
|
||||
* @param key The model key.
|
||||
* @returns A promise that resolves to true if the client has access, else false.
|
||||
*/
|
||||
export const checkModelAccess = async (key: string): Promise<boolean> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
|
||||
req.unsubscribe();
|
||||
const result = await req.unwrap();
|
||||
return Boolean(result);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if the client has access to an image.
|
||||
* @param name The image name.
|
||||
* @returns A promise that resolves to true if the client has access, else false.
|
||||
*/
|
||||
export const checkImageAccess = async (name: string): Promise<boolean> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(imagesApi.endpoints.getImageDTO.initiate(name));
|
||||
req.unsubscribe();
|
||||
const result = await req.unwrap();
|
||||
return Boolean(result);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if the client has access to a board.
|
||||
* @param id The board id.
|
||||
* @returns A promise that resolves to true if the client has access, else false.
|
||||
*/
|
||||
export const checkBoardAccess = async (id: string): Promise<boolean> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate());
|
||||
req.unsubscribe();
|
||||
const result = await req.unwrap();
|
||||
return result.some((b) => b.board_id === id);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
Loading…
Reference in New Issue
Block a user