mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use connection validationResults directly in components
This commit is contained in:
parent
26d0d55d97
commit
89b0e9e4de
@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { HandleType } from 'reactflow';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
|
||||
@ -14,11 +16,12 @@ type FieldHandleProps = {
|
||||
handleType: HandleType;
|
||||
isConnectionInProgress: boolean;
|
||||
isConnectionStartField: boolean;
|
||||
connectionError?: string;
|
||||
validationResult: ValidationResult;
|
||||
};
|
||||
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props;
|
||||
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
|
||||
const { t } = useTranslation();
|
||||
const { name } = fieldTemplate;
|
||||
const type = fieldTemplate.type;
|
||||
const fieldTypeName = useFieldTypeName(type);
|
||||
@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
s.insetInlineEnd = '-1rem';
|
||||
}
|
||||
|
||||
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
|
||||
if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
|
||||
s.filter = 'opacity(0.4) grayscale(0.7)';
|
||||
}
|
||||
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
if (isConnectionInProgress && !validationResult.isValid) {
|
||||
if (isConnectionStartField) {
|
||||
s.cursor = 'grab';
|
||||
} else {
|
||||
@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
}
|
||||
|
||||
return s;
|
||||
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]);
|
||||
}, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
return connectionError;
|
||||
if (isConnectionInProgress && validationResult.messageTKey) {
|
||||
return t(validationResult.messageTKey);
|
||||
}
|
||||
return fieldTypeName;
|
||||
}, [connectionError, fieldTypeName, isConnectionInProgress]);
|
||||
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
@ -88,7 +88,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
handleType="target"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
connectionError={connectionError}
|
||||
validationResult={validationResult}
|
||||
/>
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
@ -126,7 +126,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
handleType="target"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
connectionError={connectionError}
|
||||
validationResult={validationResult}
|
||||
/>
|
||||
)}
|
||||
</InputFieldWrapper>
|
||||
|
@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
||||
|
||||
if (!fieldTemplate) {
|
||||
@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
handleType="source"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
connectionError={connectionError}
|
||||
validationResult={validationResult}
|
||||
/>
|
||||
</OutputFieldWrapper>
|
||||
);
|
||||
|
@ -31,7 +31,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
const selectValidationResult = useMemo(
|
||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
|
||||
[templates, nodeId, fieldName, kind]
|
||||
);
|
||||
@ -48,18 +48,18 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||
);
|
||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate));
|
||||
const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
|
||||
|
||||
const shouldDim = useMemo(
|
||||
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
|
||||
[connectionError, isConnectionInProgress, isConnectionStartField]
|
||||
() => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
|
||||
[validationResult, isConnectionInProgress, isConnectionStartField]
|
||||
);
|
||||
|
||||
return {
|
||||
isConnected,
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
connectionError,
|
||||
validationResult,
|
||||
shouldDim,
|
||||
};
|
||||
};
|
||||
|
@ -2,8 +2,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import i18n from 'i18next';
|
||||
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import type { Edge, HandleType } from 'reactflow';
|
||||
|
||||
/**
|
||||
@ -33,14 +32,14 @@ export const makeConnectionErrorSelector = (
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!pendingConnection) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
return buildRejectResult('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
if (handleType === pendingConnection.handleType) {
|
||||
if (handleType === 'source') {
|
||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||
return buildRejectResult('nodes.cannotConnectOutputToOutput');
|
||||
}
|
||||
return i18n.t('nodes.cannotConnectInputToInput');
|
||||
return buildRejectResult('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
@ -62,9 +61,7 @@ export const makeConnectionErrorSelector = (
|
||||
edgePendingUpdate
|
||||
);
|
||||
|
||||
if (!validationResult.isValid) {
|
||||
return i18n.t(validationResult.messageTKey);
|
||||
}
|
||||
return validationResult;
|
||||
}
|
||||
);
|
||||
};
|
||||
|
@ -9,7 +9,7 @@ import type { O } from 'ts-toolbelt';
|
||||
|
||||
type Connection = O.NonNullable<NullableConnection>;
|
||||
|
||||
type ValidateConnectionResult =
|
||||
export type ValidationResult =
|
||||
| {
|
||||
isValid: true;
|
||||
messageTKey?: string;
|
||||
@ -26,7 +26,7 @@ type ValidateConnectionFunc = (
|
||||
templates: Templates,
|
||||
ignoreEdge: Edge | null,
|
||||
strict?: boolean
|
||||
) => ValidateConnectionResult;
|
||||
) => ValidationResult;
|
||||
|
||||
const getEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
@ -45,8 +45,8 @@ const getTargetEqualityPredicate =
|
||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||
};
|
||||
|
||||
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
||||
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
|
||||
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
||||
if (c.source === c.target) {
|
||||
|
Loading…
Reference in New Issue
Block a user