mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore(ui): lint
This commit is contained in:
parent
a8b042177d
commit
6791b4eaa8
@ -70,7 +70,14 @@ export const useConnection = () => {
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||
const connection = getFirstValidConnection(
|
||||
templates,
|
||||
nodes,
|
||||
edges,
|
||||
pendingConnection,
|
||||
candidateNode,
|
||||
candidateTemplate
|
||||
);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $copiedEdges,$copiedNodes,$cursorPos, selectionPasted, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
$copiedEdges,
|
||||
$copiedNodes,
|
||||
$cursorPos,
|
||||
selectionPasted,
|
||||
selectNodesSlice,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
|
@ -4,7 +4,7 @@ import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/in
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||
const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
||||
return node;
|
||||
|
@ -1,120 +1,13 @@
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
import type { Connection } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
const isValidConnection = (
|
||||
edges: Edge[],
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
node: Node,
|
||||
handle: FieldInputTemplate | FieldOutputTemplate
|
||||
) => {
|
||||
let isValidConnection = true;
|
||||
if (handleCurrentType === 'source') {
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === node.id && edge.targetHandle === handle.name;
|
||||
})
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
} else {
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.source === node.id && edge.sourceHandle === handle.name;
|
||||
})
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
|
||||
return isValidConnection;
|
||||
};
|
||||
|
||||
export const findConnectionToValidHandle = (
|
||||
node: AnyNode,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
templates: Templates,
|
||||
handleCurrentNodeId: string,
|
||||
handleCurrentName: string,
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType
|
||||
): Connection | null => {
|
||||
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
//Prioritize handles whos name matches the node we're coming from
|
||||
const handle = handles[handleCurrentName];
|
||||
|
||||
if (handle) {
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
||||
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||
|
||||
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
|
||||
|
||||
if (isGraphAcyclic && valid) {
|
||||
return {
|
||||
source: sourceID,
|
||||
sourceHandle: sourceHandle,
|
||||
target: targetID,
|
||||
targetHandle: targetHandle,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
for (const handleName in handles) {
|
||||
const handle = handles[handleName];
|
||||
if (!handle) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
||||
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||
|
||||
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
|
||||
|
||||
if (isGraphAcyclic && valid) {
|
||||
return {
|
||||
source: sourceID,
|
||||
sourceHandle: sourceHandle,
|
||||
target: targetID,
|
||||
targetHandle: targetHandle,
|
||||
};
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
export const getFirstValidConnection = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
|
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { SelectionMode } from 'reactflow';
|
||||
|
||||
export type WorkflowSettingsState = {
|
||||
type WorkflowSettingsState = {
|
||||
_version: 1;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
shouldValidateGraph: boolean;
|
||||
|
@ -31,10 +31,7 @@ type ValidateWorkflowResult = {
|
||||
* @throws {WorkflowVersionError} If the workflow version is not recognized.
|
||||
* @throws {z.ZodError} If there is a validation error.
|
||||
*/
|
||||
export const validateWorkflow = (
|
||||
workflow: unknown,
|
||||
invocationTemplates: Templates
|
||||
): ValidateWorkflowResult => {
|
||||
export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => {
|
||||
// Parse the raw workflow data & migrate it to the latest version
|
||||
const _workflow = parseAndMigrateWorkflow(workflow);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user