feat(ui): use connection validationResults directly in components

This commit is contained in:
psychedelicious 2024-05-19 17:07:55 +10:00
parent 26d0d55d97
commit 89b0e9e4de
6 changed files with 30 additions and 30 deletions

View File

@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react'; import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { HandleType } from 'reactflow'; import type { HandleType } from 'reactflow';
import { Handle, Position } from 'reactflow'; import { Handle, Position } from 'reactflow';
@ -14,11 +16,12 @@ type FieldHandleProps = {
handleType: HandleType; handleType: HandleType;
isConnectionInProgress: boolean; isConnectionInProgress: boolean;
isConnectionStartField: boolean; isConnectionStartField: boolean;
connectionError?: string; validationResult: ValidationResult;
}; };
const FieldHandle = (props: FieldHandleProps) => { const FieldHandle = (props: FieldHandleProps) => {
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props; const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
const { t } = useTranslation();
const { name } = fieldTemplate; const { name } = fieldTemplate;
const type = fieldTemplate.type; const type = fieldTemplate.type;
const fieldTypeName = useFieldTypeName(type); const fieldTypeName = useFieldTypeName(type);
@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => {
s.insetInlineEnd = '-1rem'; s.insetInlineEnd = '-1rem';
} }
if (isConnectionInProgress && !isConnectionStartField && connectionError) { if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
s.filter = 'opacity(0.4) grayscale(0.7)'; s.filter = 'opacity(0.4) grayscale(0.7)';
} }
if (isConnectionInProgress && connectionError) { if (isConnectionInProgress && !validationResult.isValid) {
if (isConnectionStartField) { if (isConnectionStartField) {
s.cursor = 'grab'; s.cursor = 'grab';
} else { } else {
@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => {
} }
return s; return s;
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]); }, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
const tooltip = useMemo(() => { const tooltip = useMemo(() => {
if (isConnectionInProgress && connectionError) { if (isConnectionInProgress && validationResult.messageTKey) {
return connectionError; return t(validationResult.messageTKey);
} }
return fieldTypeName; return fieldTypeName;
}, [connectionError, fieldTypeName, isConnectionInProgress]); }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
return ( return (
<Tooltip <Tooltip

View File

@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false); const [isHovered, setIsHovered] = useState(false);
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'inputs' }); useConnectionState({ nodeId, fieldName, kind: 'inputs' });
const isMissingInput = useMemo(() => { const isMissingInput = useMemo(() => {
@ -88,7 +88,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target" handleType="target"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField} isConnectionStartField={isConnectionStartField}
connectionError={connectionError} validationResult={validationResult}
/> />
</InputFieldWrapper> </InputFieldWrapper>
); );
@ -126,7 +126,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target" handleType="target"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField} isConnectionStartField={isConnectionStartField}
connectionError={connectionError} validationResult={validationResult}
/> />
)} )}
</InputFieldWrapper> </InputFieldWrapper>

View File

@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'outputs' }); useConnectionState({ nodeId, fieldName, kind: 'outputs' });
if (!fieldTemplate) { if (!fieldTemplate) {
@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
handleType="source" handleType="source"
isConnectionInProgress={isConnectionInProgress} isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField} isConnectionStartField={isConnectionStartField}
connectionError={connectionError} validationResult={validationResult}
/> />
</OutputFieldWrapper> </OutputFieldWrapper>
); );

View File

@ -31,7 +31,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
[fieldName, kind, nodeId] [fieldName, kind, nodeId]
); );
const selectConnectionError = useMemo( const selectValidationResult = useMemo(
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'), () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
[templates, nodeId, fieldName, kind] [templates, nodeId, fieldName, kind]
); );
@ -48,18 +48,18 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
); );
}, [fieldName, kind, nodeId, pendingConnection]); }, [fieldName, kind, nodeId, pendingConnection]);
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate)); const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
const shouldDim = useMemo( const shouldDim = useMemo(
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), () => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
[connectionError, isConnectionInProgress, isConnectionStartField] [validationResult, isConnectionInProgress, isConnectionStartField]
); );
return { return {
isConnected, isConnected,
isConnectionInProgress, isConnectionInProgress,
isConnectionStartField, isConnectionStartField,
connectionError, validationResult,
shouldDim, shouldDim,
}; };
}; };

View File

@ -2,8 +2,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
import i18n from 'i18next';
import type { Edge, HandleType } from 'reactflow'; import type { Edge, HandleType } from 'reactflow';
/** /**
@ -33,14 +32,14 @@ export const makeConnectionErrorSelector = (
const { nodes, edges } = nodesSlice; const { nodes, edges } = nodesSlice;
if (!pendingConnection) { if (!pendingConnection) {
return i18n.t('nodes.noConnectionInProgress'); return buildRejectResult('nodes.noConnectionInProgress');
} }
if (handleType === pendingConnection.handleType) { if (handleType === pendingConnection.handleType) {
if (handleType === 'source') { if (handleType === 'source') {
return i18n.t('nodes.cannotConnectOutputToOutput'); return buildRejectResult('nodes.cannotConnectOutputToOutput');
} }
return i18n.t('nodes.cannotConnectInputToInput'); return buildRejectResult('nodes.cannotConnectInputToInput');
} }
// we have to figure out which is the target and which is the source // we have to figure out which is the target and which is the source
@ -62,9 +61,7 @@ export const makeConnectionErrorSelector = (
edgePendingUpdate edgePendingUpdate
); );
if (!validationResult.isValid) { return validationResult;
return i18n.t(validationResult.messageTKey);
}
} }
); );
}; };

View File

@ -9,7 +9,7 @@ import type { O } from 'ts-toolbelt';
type Connection = O.NonNullable<NullableConnection>; type Connection = O.NonNullable<NullableConnection>;
type ValidateConnectionResult = export type ValidationResult =
| { | {
isValid: true; isValid: true;
messageTKey?: string; messageTKey?: string;
@ -26,7 +26,7 @@ type ValidateConnectionFunc = (
templates: Templates, templates: Templates,
ignoreEdge: Edge | null, ignoreEdge: Edge | null,
strict?: boolean strict?: boolean
) => ValidateConnectionResult; ) => ValidationResult;
const getEqualityPredicate = const getEqualityPredicate =
(c: Connection) => (c: Connection) =>
@ -45,8 +45,8 @@ const getTargetEqualityPredicate =
return e.target === c.target && e.targetHandle === c.targetHandle; return e.target === c.target && e.targetHandle === c.targetHandle;
}; };
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => { export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
if (c.source === c.target) { if (c.source === c.target) {