mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use connection validationResults directly in components
This commit is contained in:
parent
26d0d55d97
commit
89b0e9e4de
@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library';
|
|||||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||||
|
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
|
||||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||||
import type { CSSProperties } from 'react';
|
import type { CSSProperties } from 'react';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { HandleType } from 'reactflow';
|
import type { HandleType } from 'reactflow';
|
||||||
import { Handle, Position } from 'reactflow';
|
import { Handle, Position } from 'reactflow';
|
||||||
|
|
||||||
@ -14,11 +16,12 @@ type FieldHandleProps = {
|
|||||||
handleType: HandleType;
|
handleType: HandleType;
|
||||||
isConnectionInProgress: boolean;
|
isConnectionInProgress: boolean;
|
||||||
isConnectionStartField: boolean;
|
isConnectionStartField: boolean;
|
||||||
connectionError?: string;
|
validationResult: ValidationResult;
|
||||||
};
|
};
|
||||||
|
|
||||||
const FieldHandle = (props: FieldHandleProps) => {
|
const FieldHandle = (props: FieldHandleProps) => {
|
||||||
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props;
|
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
|
||||||
|
const { t } = useTranslation();
|
||||||
const { name } = fieldTemplate;
|
const { name } = fieldTemplate;
|
||||||
const type = fieldTemplate.type;
|
const type = fieldTemplate.type;
|
||||||
const fieldTypeName = useFieldTypeName(type);
|
const fieldTypeName = useFieldTypeName(type);
|
||||||
@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => {
|
|||||||
s.insetInlineEnd = '-1rem';
|
s.insetInlineEnd = '-1rem';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
|
if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
|
||||||
s.filter = 'opacity(0.4) grayscale(0.7)';
|
s.filter = 'opacity(0.4) grayscale(0.7)';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isConnectionInProgress && connectionError) {
|
if (isConnectionInProgress && !validationResult.isValid) {
|
||||||
if (isConnectionStartField) {
|
if (isConnectionStartField) {
|
||||||
s.cursor = 'grab';
|
s.cursor = 'grab';
|
||||||
} else {
|
} else {
|
||||||
@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return s;
|
return s;
|
||||||
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]);
|
}, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
|
||||||
|
|
||||||
const tooltip = useMemo(() => {
|
const tooltip = useMemo(() => {
|
||||||
if (isConnectionInProgress && connectionError) {
|
if (isConnectionInProgress && validationResult.messageTKey) {
|
||||||
return connectionError;
|
return t(validationResult.messageTKey);
|
||||||
}
|
}
|
||||||
return fieldTypeName;
|
return fieldTypeName;
|
||||||
}, [connectionError, fieldTypeName, isConnectionInProgress]);
|
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip
|
<Tooltip
|
||||||
|
@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||||
|
|
||||||
const isMissingInput = useMemo(() => {
|
const isMissingInput = useMemo(() => {
|
||||||
@ -88,7 +88,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
handleType="target"
|
handleType="target"
|
||||||
isConnectionInProgress={isConnectionInProgress}
|
isConnectionInProgress={isConnectionInProgress}
|
||||||
isConnectionStartField={isConnectionStartField}
|
isConnectionStartField={isConnectionStartField}
|
||||||
connectionError={connectionError}
|
validationResult={validationResult}
|
||||||
/>
|
/>
|
||||||
</InputFieldWrapper>
|
</InputFieldWrapper>
|
||||||
);
|
);
|
||||||
@ -126,7 +126,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
handleType="target"
|
handleType="target"
|
||||||
isConnectionInProgress={isConnectionInProgress}
|
isConnectionInProgress={isConnectionInProgress}
|
||||||
isConnectionStartField={isConnectionStartField}
|
isConnectionStartField={isConnectionStartField}
|
||||||
connectionError={connectionError}
|
validationResult={validationResult}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</InputFieldWrapper>
|
</InputFieldWrapper>
|
||||||
|
@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
||||||
|
|
||||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||||
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
||||||
|
|
||||||
if (!fieldTemplate) {
|
if (!fieldTemplate) {
|
||||||
@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
handleType="source"
|
handleType="source"
|
||||||
isConnectionInProgress={isConnectionInProgress}
|
isConnectionInProgress={isConnectionInProgress}
|
||||||
isConnectionStartField={isConnectionStartField}
|
isConnectionStartField={isConnectionStartField}
|
||||||
connectionError={connectionError}
|
validationResult={validationResult}
|
||||||
/>
|
/>
|
||||||
</OutputFieldWrapper>
|
</OutputFieldWrapper>
|
||||||
);
|
);
|
||||||
|
@ -31,7 +31,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
[fieldName, kind, nodeId]
|
[fieldName, kind, nodeId]
|
||||||
);
|
);
|
||||||
|
|
||||||
const selectConnectionError = useMemo(
|
const selectValidationResult = useMemo(
|
||||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
|
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
|
||||||
[templates, nodeId, fieldName, kind]
|
[templates, nodeId, fieldName, kind]
|
||||||
);
|
);
|
||||||
@ -48,18 +48,18 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||||
);
|
);
|
||||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||||
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate));
|
const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
|
||||||
|
|
||||||
const shouldDim = useMemo(
|
const shouldDim = useMemo(
|
||||||
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
|
() => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
|
||||||
[connectionError, isConnectionInProgress, isConnectionStartField]
|
[validationResult, isConnectionInProgress, isConnectionStartField]
|
||||||
);
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isConnected,
|
isConnected,
|
||||||
isConnectionInProgress,
|
isConnectionInProgress,
|
||||||
isConnectionStartField,
|
isConnectionStartField,
|
||||||
connectionError,
|
validationResult,
|
||||||
shouldDim,
|
shouldDim,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -2,8 +2,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||||
import i18n from 'i18next';
|
|
||||||
import type { Edge, HandleType } from 'reactflow';
|
import type { Edge, HandleType } from 'reactflow';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -33,14 +32,14 @@ export const makeConnectionErrorSelector = (
|
|||||||
const { nodes, edges } = nodesSlice;
|
const { nodes, edges } = nodesSlice;
|
||||||
|
|
||||||
if (!pendingConnection) {
|
if (!pendingConnection) {
|
||||||
return i18n.t('nodes.noConnectionInProgress');
|
return buildRejectResult('nodes.noConnectionInProgress');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (handleType === pendingConnection.handleType) {
|
if (handleType === pendingConnection.handleType) {
|
||||||
if (handleType === 'source') {
|
if (handleType === 'source') {
|
||||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
return buildRejectResult('nodes.cannotConnectOutputToOutput');
|
||||||
}
|
}
|
||||||
return i18n.t('nodes.cannotConnectInputToInput');
|
return buildRejectResult('nodes.cannotConnectInputToInput');
|
||||||
}
|
}
|
||||||
|
|
||||||
// we have to figure out which is the target and which is the source
|
// we have to figure out which is the target and which is the source
|
||||||
@ -62,9 +61,7 @@ export const makeConnectionErrorSelector = (
|
|||||||
edgePendingUpdate
|
edgePendingUpdate
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!validationResult.isValid) {
|
return validationResult;
|
||||||
return i18n.t(validationResult.messageTKey);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -9,7 +9,7 @@ import type { O } from 'ts-toolbelt';
|
|||||||
|
|
||||||
type Connection = O.NonNullable<NullableConnection>;
|
type Connection = O.NonNullable<NullableConnection>;
|
||||||
|
|
||||||
type ValidateConnectionResult =
|
export type ValidationResult =
|
||||||
| {
|
| {
|
||||||
isValid: true;
|
isValid: true;
|
||||||
messageTKey?: string;
|
messageTKey?: string;
|
||||||
@ -26,7 +26,7 @@ type ValidateConnectionFunc = (
|
|||||||
templates: Templates,
|
templates: Templates,
|
||||||
ignoreEdge: Edge | null,
|
ignoreEdge: Edge | null,
|
||||||
strict?: boolean
|
strict?: boolean
|
||||||
) => ValidateConnectionResult;
|
) => ValidationResult;
|
||||||
|
|
||||||
const getEqualityPredicate =
|
const getEqualityPredicate =
|
||||||
(c: Connection) =>
|
(c: Connection) =>
|
||||||
@ -45,8 +45,8 @@ const getTargetEqualityPredicate =
|
|||||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
|
||||||
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
|
||||||
|
|
||||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
||||||
if (c.source === c.target) {
|
if (c.source === c.target) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user