feat(ui): use connection validationResults directly in components

This commit is contained in:
psychedelicious 2024-05-19 17:07:55 +10:00
parent 26d0d55d97
commit 89b0e9e4de
6 changed files with 30 additions and 30 deletions

View File

@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { HandleType } from 'reactflow';
import { Handle, Position } from 'reactflow';
@ -14,11 +16,12 @@ type FieldHandleProps = {
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
connectionError?: string;
validationResult: ValidationResult;
};
const FieldHandle = (props: FieldHandleProps) => {
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props;
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
const { t } = useTranslation();
const { name } = fieldTemplate;
const type = fieldTemplate.type;
const fieldTypeName = useFieldTypeName(type);
@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => {
s.insetInlineEnd = '-1rem';
}
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
s.filter = 'opacity(0.4) grayscale(0.7)';
}
if (isConnectionInProgress && connectionError) {
if (isConnectionInProgress && !validationResult.isValid) {
if (isConnectionStartField) {
s.cursor = 'grab';
} else {
@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => {
}
return s;
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]);
}, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && connectionError) {
return connectionError;
if (isConnectionInProgress && validationResult.messageTKey) {
return t(validationResult.messageTKey);
}
return fieldTypeName;
}, [connectionError, fieldTypeName, isConnectionInProgress]);
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
return (
<Tooltip

View File

@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
const isMissingInput = useMemo(() => {
@ -88,7 +88,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
validationResult={validationResult}
/>
</InputFieldWrapper>
);
@ -126,7 +126,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
validationResult={validationResult}
/>
)}
</InputFieldWrapper>

View File

@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
if (!fieldTemplate) {
@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
handleType="source"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
validationResult={validationResult}
/>
</OutputFieldWrapper>
);

View File

@ -31,7 +31,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
[fieldName, kind, nodeId]
);
const selectConnectionError = useMemo(
const selectValidationResult = useMemo(
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
[templates, nodeId, fieldName, kind]
);
@ -48,18 +48,18 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate));
const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
const shouldDim = useMemo(
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
[connectionError, isConnectionInProgress, isConnectionStartField]
() => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
[validationResult, isConnectionInProgress, isConnectionStartField]
);
return {
isConnected,
isConnectionInProgress,
isConnectionStartField,
connectionError,
validationResult,
shouldDim,
};
};

View File

@ -2,8 +2,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import i18n from 'i18next';
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
import type { Edge, HandleType } from 'reactflow';
/**
@ -33,14 +32,14 @@ export const makeConnectionErrorSelector = (
const { nodes, edges } = nodesSlice;
if (!pendingConnection) {
return i18n.t('nodes.noConnectionInProgress');
return buildRejectResult('nodes.noConnectionInProgress');
}
if (handleType === pendingConnection.handleType) {
if (handleType === 'source') {
return i18n.t('nodes.cannotConnectOutputToOutput');
return buildRejectResult('nodes.cannotConnectOutputToOutput');
}
return i18n.t('nodes.cannotConnectInputToInput');
return buildRejectResult('nodes.cannotConnectInputToInput');
}
// we have to figure out which is the target and which is the source
@ -62,9 +61,7 @@ export const makeConnectionErrorSelector = (
edgePendingUpdate
);
if (!validationResult.isValid) {
return i18n.t(validationResult.messageTKey);
}
return validationResult;
}
);
};

View File

@ -9,7 +9,7 @@ import type { O } from 'ts-toolbelt';
type Connection = O.NonNullable<NullableConnection>;
type ValidateConnectionResult =
export type ValidationResult =
| {
isValid: true;
messageTKey?: string;
@ -26,7 +26,7 @@ type ValidateConnectionFunc = (
templates: Templates,
ignoreEdge: Edge | null,
strict?: boolean
) => ValidateConnectionResult;
) => ValidationResult;
const getEqualityPredicate =
(c: Connection) =>
@ -45,8 +45,8 @@ const getTargetEqualityPredicate =
return e.target === c.target && e.targetHandle === c.targetHandle;
};
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
if (c.source === c.target) {