feat(nodes,ui): add detection of custom nodes

Custom nodes have a new attribute `node_pack` indicating the node pack they came from.

- This is displayed in the UI in the icon icon tooltip.
- If a workflow is loaded and a node is unavailable, its node pack will be displayed (if it is known).
- If a workflow is migrated from v1 to v2, and the node is unknown, it falls back to "Unknown". If the missing node pack is installed and the node is updated, the node pack will be updated as expected.
This commit is contained in:
psychedelicious 2023-11-27 12:51:19 +11:00
parent 282a7f32d3
commit 4af4486dd9
13 changed files with 196 additions and 79 deletions

View File

@ -1,4 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI team
from __future__ import annotations
@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
import semver
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
@ -26,6 +26,8 @@ if TYPE_CHECKING:
logger = InvokeAILogger.get_logger()
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
class InvalidVersionError(ValueError):
pass
@ -432,10 +434,10 @@ class UIConfigBase(BaseModel):
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: Optional[str] = Field(
default=None,
version: str = Field(
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
model_config = ConfigDict(
validate_assignment=True,
@ -591,14 +593,16 @@ class BaseInvocation(ABC, BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
if uiconfig and hasattr(uiconfig, "category"):
schema["category"] = uiconfig.category
if uiconfig and hasattr(uiconfig, "version"):
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
if uiconfig.title is not None:
schema["title"] = uiconfig.title
if uiconfig.tags is not None:
schema["tags"] = uiconfig.tags
if uiconfig.category is not None:
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
@ -796,15 +800,20 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
if title is not None:
cls.UIConfig.title = title
if tags is not None:
cls.UIConfig.tags = tags
if category is not None:
cls.UIConfig.category = category
uiconfig_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
# Grab the node pack's name from the module name, if it's a custom node
module_name = cls.__module__.split(".")[0]
if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
else:
cls.UIConfig.node_pack = None
if version is not None:
try:
semver.Version.parse(version)
@ -814,6 +823,7 @@ def invocation(
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
cls.UIConfig.version = "1.0.0"
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache

View File

@ -6,6 +6,7 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@ -32,8 +33,8 @@ for d in Path(__file__).parent.iterdir():
if module_name in globals():
continue
# we have a legit module to import
spec = spec_from_file_location(module_name, init.absolute())
# load the module, appending adding a suffix to identify it as a custom node pack
spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
if spec is None or spec.loader is None:
logger.warn(f"Could not load {init}")

View File

@ -160,6 +160,7 @@
"trainingDesc2": "InvokeAI already supports training custom embeddourings using Textual Inversion using the main script.",
"txt2img": "Text To Image",
"unifiedCanvas": "Unified Canvas",
"unknown": "Unknown",
"upload": "Upload"
},
"controlnet": {
@ -802,6 +803,7 @@
"cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"nodePack": "Node pack",
"clipField": "Clip",
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
"collection": "Collection",
@ -966,6 +968,7 @@
"unableToParseNode": "Unable to parse node",
"unableToUpdateNode": "Unable to update node",
"unableToValidateWorkflow": "Unable to Validate Workflow",
"unableToMigrateWorkflow": "Unable to Migrate Workflow",
"unknownErrorValidatingWorkflow": "Unknown error validating workflow",
"inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})",
@ -979,9 +982,9 @@
"unhandledInputProperty": "Unhandled input property",
"unhandledOutputProperty": "Unhandled output property",
"unknownField": "Unknown field",
"unknownFieldType": "$t(nodes.unknownField) type",
"unknownFieldType": "$t(nodes.unknownField) type: {{type}}",
"unknownNode": "Unknown Node",
"unknownNodeType": "$t(nodes.unknownNode) type",
"unknownNodeType": "Unknown node type",
"unknownTemplate": "Unknown Template",
"unknownInput": "Unknown input: {{name}}",
"unkownInvocation": "Unknown Invocation type",

View File

@ -3,7 +3,10 @@ import { parseify } from 'common/util/serialize';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { WorkflowVersionError } from 'features/nodes/types/error';
import {
WorkflowMigrationError,
WorkflowVersionError,
} from 'features/nodes/types/error';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
@ -67,6 +70,18 @@ export const addWorkflowLoadRequestedListener = () => {
})
)
);
} else if (e instanceof WorkflowMigrationError) {
// There was a problem migrating the workflow to the latest version
log.error({ error: parseify(e) }, e.message);
dispatch(
addToast(
makeToast({
title: t('nodes.unableToValidateWorkflow'),
status: 'error',
description: e.message,
})
)
);
} else if (e instanceof z.ZodError) {
// There was a problem validating the workflow itself
const { message } = fromZodError(e, {

View File

@ -24,6 +24,7 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
<Icon
as={FaInfoCircle}
sx={{
display: 'block',
boxSize: 4,
w: 8,
color: needsUpdate ? 'error.400' : 'base.400',
@ -109,6 +110,11 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
<Text as="span" sx={{ fontWeight: 600 }}>
{title}
</Text>
{nodeTemplate?.nodePack && (
<Text opacity={0.7}>
{t('nodes.nodePack')}: {nodeTemplate.nodePack}
</Text>
)}
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{nodeTemplate?.description}
</Text>

View File

@ -1,9 +1,10 @@
import { Box, Flex, Text } from '@chakra-ui/react';
import { Flex, Text } from '@chakra-ui/react';
import { useNodePack } from 'features/nodes/hooks/useNodePack';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import NodeCollapseButton from '../common/NodeCollapseButton';
import NodeWrapper from '../common/NodeWrapper';
import { useTranslation } from 'react-i18next';
type Props = {
nodeId: string;
@ -21,6 +22,7 @@ const InvocationNodeUnknownFallback = ({
selected,
}: Props) => {
const { t } = useTranslation();
const nodePack = useNodePack(nodeId);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<Flex
@ -62,12 +64,22 @@ const InvocationNodeUnknownFallback = ({
fontSize: 'sm',
}}
>
<Box>
<Text as="span">{t('nodes.unknownNodeType')}: </Text>
<Text as="span" fontWeight={600}>
{type}
<Flex gap={2} flexDir="column">
<Text as="span">
{t('nodes.unknownNodeType')}:{' '}
<Text as="span" fontWeight={600}>
{type}
</Text>
</Text>
</Box>
{nodePack && (
<Text as="span">
{t('nodes.nodePack')}:{' '}
<Text as="span" fontWeight={600}>
{nodePack}
</Text>
</Text>
)}
</Flex>
</Flex>
)}
</NodeWrapper>

View File

@ -298,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
_dark: { color: 'error.300' },
}}
>
{t('nodes.unknownFieldType')}: {fieldInstance?.type.name}
{t('nodes.unknownFieldType', { type: fieldInstance?.type.name })}
</Text>
</Box>
);

View File

@ -0,0 +1,27 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/invocation';
export const useNodePack = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.nodePack;
},
defaultSelectorOptions
),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};

View File

@ -12,6 +12,20 @@ export class WorkflowVersionError extends Error {
this.name = this.constructor.name;
}
}
/**
* Workflow Migration Error
* Raised when a workflow migration fails.
*/
export class WorkflowMigrationError extends Error {
/**
* Create WorkflowMigrationError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Unable to Update Node Error

View File

@ -21,6 +21,7 @@ export const zInvocationTemplate = z.object({
withWorkflow: z.boolean(),
version: zSemVer,
useCache: z.boolean(),
nodePack: z.string().min(1).nullish(),
});
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
// #endregion
@ -36,6 +37,7 @@ export const zInvocationNodeData = z.object({
isIntermediate: z.boolean(),
useCache: z.boolean(),
version: zSemVer,
nodePack: z.string().min(1).nullish(),
inputs: z.record(zFieldInputInstance),
outputs: z.record(zFieldOutputInstance),
});

View File

@ -82,6 +82,7 @@ export const parseSchema = (
const tags = schema.tags ?? [];
const description = schema.description ?? '';
const version = schema.version;
const nodePack = schema.node_pack;
let withWorkflow = false;
const inputs = reduce(
@ -257,6 +258,7 @@ export const parseSchema = (
outputs,
useCache,
withWorkflow,
nodePack,
};
Object.assign(invocationsAccumulator, { [type]: invocation });

View File

@ -1,7 +1,14 @@
import { $store } from 'app/store/nanostores/store';
import { RootState } from 'app/store/store';
import { FieldType } from 'features/nodes/types/field';
import { InvocationNodeData } from 'features/nodes/types/invocation';
import { t } from 'i18next';
import { forEach, isString } from 'lodash-es';
import { forEach } from 'lodash-es';
import { z } from 'zod';
import { WorkflowVersionError } from '../../types/error';
import {
WorkflowMigrationError,
WorkflowVersionError,
} from '../../types/error';
import { zSemVer } from '../../types/semver';
import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from '../../types/v1/fieldTypeMap';
import { WorkflowV1, zWorkflowV1 } from '../../types/v1/workflowV1';
@ -20,22 +27,39 @@ const zWorkflowMetaVersion = z.object({
* Migrates a workflow from V1 to V2.
*/
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
const invocationTemplates = ($store.get()?.getState() as RootState).nodes
.nodeTemplates;
workflowToMigrate.nodes.forEach((node) => {
if (node.type === 'invocation') {
// Migrate field types
forEach(node.data.inputs, (input) => {
if (!isString(input.type)) {
return;
const newFieldType = FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type];
if (!newFieldType) {
throw new WorkflowMigrationError(
t('nodes.unknownFieldType', { type: input.type })
);
}
(input.type as unknown) =
FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type];
// Cast as the V2 type
(input.type as unknown as FieldType) = newFieldType;
});
forEach(node.data.outputs, (output) => {
if (!isString(output.type)) {
return;
}
(output.type as unknown) =
const newFieldType =
FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type];
if (!newFieldType) {
throw new WorkflowMigrationError(
t('nodes.unknownFieldType', { type: output.type })
);
}
// Cast as the V2 type
(output.type as unknown as FieldType) = newFieldType;
});
// Migrate nodePack
const invocationTemplate = invocationTemplates[node.data.type];
const nodePack = invocationTemplate
? invocationTemplate.nodePack
: t('common.unknown');
// Cast as the V2 type
(node.data as unknown as InvocationNodeData).nodePack = nodePack;
}
});
(workflowToMigrate.meta.version as WorkflowV2['meta']['version']) = '2.0.0';
@ -49,6 +73,7 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
if (!workflowVersionResult.success) {
console.log(data);
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
}

File diff suppressed because one or more lines are too long