feat(nodes): add invocation classifications

Invocations now have a classification:
- Stable: LTS
- Beta: LTS planned, API may change
- Prototype: No LTS planned, API may change, may be removed entirely

The `@invocation` decorator has a new arg `classification`, and an enum `Classification` is added to `baseinvocation.py`.

The default is Stable; this is a non-breaking change.

The classification is presented in the node header as a hammer icon (Beta) or flask icon (prototype).

The icon has a tooltip briefly describing the classification.
This commit is contained in:
psychedelicious 2023-12-12 15:04:13 +11:00
parent 22ccaa4e9a
commit 43f2837117
9 changed files with 656 additions and 36 deletions

View File

@ -39,6 +39,19 @@ class InvalidFieldError(TypeError):
pass
class Classification(str, Enum, metaclass=MetaEnum):
"""
The feature classification of an Invocation.
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
class Input(str, Enum, metaclass=MetaEnum):
"""
The type of input a field accepts.
@ -439,6 +452,7 @@ class UIConfigBase(BaseModel):
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
validate_assignment=True,
@ -607,6 +621,7 @@ class BaseInvocation(ABC, BaseModel):
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
@ -782,6 +797,7 @@ def invocation(
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@ -792,6 +808,7 @@ def invocation(
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param FeatureClassification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@ -812,6 +829,7 @@ def invocation(
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"

View File

@ -1032,7 +1032,9 @@
"workflowValidation": "Workflow Validation Error",
"workflowVersion": "Version",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out"
"zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
},
"parameters": {
"aspectRatio": "Aspect Ratio",

View File

@ -0,0 +1,68 @@
import { Icon, Tooltip } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaFlask } from 'react-icons/fa';
import { useNodeClassification } from 'features/nodes/hooks/useNodeClassification';
import { Classification } from 'features/nodes/types/common';
import { FaHammer } from 'react-icons/fa6';
interface Props {
nodeId: string;
}
const InvocationNodeClassificationIcon = ({ nodeId }: Props) => {
const classification = useNodeClassification(nodeId);
if (!classification || classification === 'stable') {
return null;
}
return (
<Tooltip
label={<ClassificationTooltipContent classification={classification} />}
placement="top"
shouldWrapChildren
>
<Icon
as={getIcon(classification)}
sx={{
display: 'block',
boxSize: 4,
color: 'base.400',
}}
/>
</Tooltip>
);
};
export default memo(InvocationNodeClassificationIcon);
const ClassificationTooltipContent = memo(
({ classification }: { classification: Classification }) => {
const { t } = useTranslation();
if (classification === 'beta') {
return t('nodes.betaDesc');
}
if (classification === 'prototype') {
return t('nodes.prototypeDesc');
}
return null;
}
);
ClassificationTooltipContent.displayName = 'ClassificationTooltipContent';
const getIcon = (classification: Classification) => {
if (classification === 'beta') {
return FaHammer;
}
if (classification === 'prototype') {
return FaFlask;
}
return undefined;
};

View File

@ -5,6 +5,7 @@ import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
import InvocationNodeCollapsedHandles from './InvocationNodeCollapsedHandles';
import InvocationNodeInfoIcon from './InvocationNodeInfoIcon';
import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator';
import InvocationNodeClassificationIcon from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeClassificationIcon';
type Props = {
nodeId: string;
@ -31,6 +32,7 @@ const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
}}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<InvocationNodeClassificationIcon nodeId={nodeId} />
<NodeTitle nodeId={nodeId} />
<Flex alignItems="center">
<InvocationNodeStatusIndicator nodeId={nodeId} />

View File

@ -0,0 +1,23 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useNodeClassification = (nodeId: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(stateSelector, ({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate?.classification;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};

View File

@ -19,6 +19,9 @@ export const zColorField = z.object({
});
export type ColorField = z.infer<typeof zColorField>;
export const zClassification = z.enum(['stable', 'beta', 'prototype']);
export type Classification = z.infer<typeof zClassification>;
export const zSchedulerField = z.enum([
'euler',
'deis',

View File

@ -1,6 +1,6 @@
import { Edge, Node } from 'reactflow';
import { z } from 'zod';
import { zProgressImage } from './common';
import { zClassification, zProgressImage } from './common';
import {
zFieldInputInstance,
zFieldInputTemplate,
@ -21,6 +21,7 @@ export const zInvocationTemplate = z.object({
version: zSemVer,
useCache: z.boolean(),
nodePack: z.string().min(1).nullish(),
classification: zClassification,
});
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
// #endregion

View File

@ -83,6 +83,7 @@ export const parseSchema = (
const description = schema.description ?? '';
const version = schema.version;
const nodePack = schema.node_pack;
const classification = schema.classification;
const inputs = reduce(
schema.properties,
@ -245,6 +246,7 @@ export const parseSchema = (
outputs,
useCache,
nodePack,
classification,
};
Object.assign(invocationsAccumulator, { [type]: invocation });

File diff suppressed because one or more lines are too long