mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore(ui): lint
This commit is contained in:
parent
a8b042177d
commit
6791b4eaa8
@ -70,7 +70,14 @@ export const useConnection = () => {
|
|||||||
}
|
}
|
||||||
const candidateTemplate = templates[candidateNode.data.type];
|
const candidateTemplate = templates[candidateNode.data.type];
|
||||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
const connection = getFirstValidConnection(
|
||||||
|
templates,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
pendingConnection,
|
||||||
|
candidateNode,
|
||||||
|
candidateTemplate
|
||||||
|
);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
dispatch(connectionMade(connection));
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
import { deepClone } from 'common/util/deepClone';
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import { $copiedEdges,$copiedNodes,$cursorPos, selectionPasted, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import {
|
||||||
|
$copiedEdges,
|
||||||
|
$copiedNodes,
|
||||||
|
$cursorPos,
|
||||||
|
selectionPasted,
|
||||||
|
selectNodesSlice,
|
||||||
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/in
|
|||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||||
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
||||||
return node;
|
return node;
|
||||||
|
@ -1,120 +1,13 @@
|
|||||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
|
||||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
|
||||||
import { differenceWith, isEqual, map } from 'lodash-es';
|
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
import type { Connection } from 'reactflow';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||||
|
|
||||||
const isValidConnection = (
|
|
||||||
edges: Edge[],
|
|
||||||
handleCurrentType: HandleType,
|
|
||||||
handleCurrentFieldType: FieldType,
|
|
||||||
node: Node,
|
|
||||||
handle: FieldInputTemplate | FieldOutputTemplate
|
|
||||||
) => {
|
|
||||||
let isValidConnection = true;
|
|
||||||
if (handleCurrentType === 'source') {
|
|
||||||
if (
|
|
||||||
edges.find((edge) => {
|
|
||||||
return edge.target === node.id && edge.targetHandle === handle.name;
|
|
||||||
})
|
|
||||||
) {
|
|
||||||
isValidConnection = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (
|
|
||||||
edges.find((edge) => {
|
|
||||||
return edge.source === node.id && edge.sourceHandle === handle.name;
|
|
||||||
})
|
|
||||||
) {
|
|
||||||
isValidConnection = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
|
|
||||||
isValidConnection = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return isValidConnection;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const findConnectionToValidHandle = (
|
|
||||||
node: AnyNode,
|
|
||||||
nodes: AnyNode[],
|
|
||||||
edges: InvocationNodeEdge[],
|
|
||||||
templates: Templates,
|
|
||||||
handleCurrentNodeId: string,
|
|
||||||
handleCurrentName: string,
|
|
||||||
handleCurrentType: HandleType,
|
|
||||||
handleCurrentFieldType: FieldType
|
|
||||||
): Connection | null => {
|
|
||||||
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const template = templates[node.data.type];
|
|
||||||
|
|
||||||
if (!template) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
|
|
||||||
|
|
||||||
//Prioritize handles whos name matches the node we're coming from
|
|
||||||
const handle = handles[handleCurrentName];
|
|
||||||
|
|
||||||
if (handle) {
|
|
||||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
|
||||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
|
||||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
|
||||||
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
|
|
||||||
|
|
||||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
|
||||||
|
|
||||||
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
|
|
||||||
|
|
||||||
if (isGraphAcyclic && valid) {
|
|
||||||
return {
|
|
||||||
source: sourceID,
|
|
||||||
sourceHandle: sourceHandle,
|
|
||||||
target: targetID,
|
|
||||||
targetHandle: targetHandle,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const handleName in handles) {
|
|
||||||
const handle = handles[handleName];
|
|
||||||
if (!handle) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
|
||||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
|
||||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
|
||||||
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
|
|
||||||
|
|
||||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
|
||||||
|
|
||||||
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
|
|
||||||
|
|
||||||
if (isGraphAcyclic && valid) {
|
|
||||||
return {
|
|
||||||
source: sourceID,
|
|
||||||
sourceHandle: sourceHandle,
|
|
||||||
target: targetID,
|
|
||||||
targetHandle: targetHandle,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const getFirstValidConnection = (
|
export const getFirstValidConnection = (
|
||||||
templates: Templates,
|
templates: Templates,
|
||||||
nodes: AnyNode[],
|
nodes: AnyNode[],
|
||||||
|
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
|||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { SelectionMode } from 'reactflow';
|
import { SelectionMode } from 'reactflow';
|
||||||
|
|
||||||
export type WorkflowSettingsState = {
|
type WorkflowSettingsState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
shouldShowMinimapPanel: boolean;
|
shouldShowMinimapPanel: boolean;
|
||||||
shouldValidateGraph: boolean;
|
shouldValidateGraph: boolean;
|
||||||
|
@ -31,10 +31,7 @@ type ValidateWorkflowResult = {
|
|||||||
* @throws {WorkflowVersionError} If the workflow version is not recognized.
|
* @throws {WorkflowVersionError} If the workflow version is not recognized.
|
||||||
* @throws {z.ZodError} If there is a validation error.
|
* @throws {z.ZodError} If there is a validation error.
|
||||||
*/
|
*/
|
||||||
export const validateWorkflow = (
|
export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => {
|
||||||
workflow: unknown,
|
|
||||||
invocationTemplates: Templates
|
|
||||||
): ValidateWorkflowResult => {
|
|
||||||
// Parse the raw workflow data & migrate it to the latest version
|
// Parse the raw workflow data & migrate it to the latest version
|
||||||
const _workflow = parseAndMigrateWorkflow(workflow);
|
const _workflow = parseAndMigrateWorkflow(workflow);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user