Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
feat(nodes,ui): add detection of custom nodes
Custom nodes have a new attribute `node_pack` indicating the node pack they came from.
- This is displayed in the UI in the node's info icon tooltip.
- If a workflow is loaded and a node is unavailable, its node pack will be displayed (if it is known).
- If a workflow is migrated from v1 to v2 and the node is unknown, the node pack falls back to "Unknown". If the missing node pack is installed and the node is updated, the node pack will be updated as expected.
parent 282a7f32d3
commit 4af4486dd9
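
For illustration only (not part of this commit): a minimal Python sketch of the naming convention the changes below rely on. A custom node pack's module is imported under an aliased name ending in the suffix, and the pack name is recovered by stripping that suffix. The helper name `node_pack_from_module` and the pack name `my_pack` are hypothetical.

# Illustrative sketch only -- not InvokeAI code; the helper name is hypothetical.
from typing import Optional

CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"

def node_pack_from_module(module: str) -> Optional[str]:
    """Return the custom node pack name, or None for built-in nodes."""
    top_level = module.split(".")[0]
    if top_level.endswith(CUSTOM_NODE_PACK_SUFFIX):
        return top_level.split(CUSTOM_NODE_PACK_SUFFIX)[0]
    return None

print(node_pack_from_module("my_pack__invokeai-custom-node.nodes"))  # -> "my_pack"
print(node_pack_from_module("invokeai.app.invocations.primitives"))  # -> None
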
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI team

 from __future__ import annotations

@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import signature
 from types import UnionType
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast

 import semver
 from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
@@ -26,6 +26,8 @@ if TYPE_CHECKING:

 logger = InvokeAILogger.get_logger()

+CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
+

 class InvalidVersionError(ValueError):
     pass
@@ -432,10 +434,10 @@ class UIConfigBase(BaseModel):
     tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
     title: Optional[str] = Field(default=None, description="The node's display name")
     category: Optional[str] = Field(default=None, description="The node's category")
-    version: Optional[str] = Field(
-        default=None,
+    version: str = Field(
         description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
     )
+    node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")

     model_config = ConfigDict(
         validate_assignment=True,
@@ -591,14 +593,16 @@ class BaseInvocation(ABC, BaseModel):
     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
         """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
-        uiconfig = getattr(model_class, "UIConfig", None)
-        if uiconfig and hasattr(uiconfig, "title"):
-            schema["title"] = uiconfig.title
-        if uiconfig and hasattr(uiconfig, "tags"):
-            schema["tags"] = uiconfig.tags
-        if uiconfig and hasattr(uiconfig, "category"):
-            schema["category"] = uiconfig.category
-        if uiconfig and hasattr(uiconfig, "version"):
+        uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
+        if uiconfig is not None:
+            if uiconfig.title is not None:
+                schema["title"] = uiconfig.title
+            if uiconfig.tags is not None:
+                schema["tags"] = uiconfig.tags
+            if uiconfig.category is not None:
+                schema["category"] = uiconfig.category
+            if uiconfig.node_pack is not None:
+                schema["node_pack"] = uiconfig.node_pack
             schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
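
For context, a standalone sketch of the conditional logic in the hunk above, using a hypothetical `UIConfig` dataclass and `apply_ui_extras` helper rather than InvokeAI's classes: only UI attributes that are actually set are copied into the OpenAPI schema, so `node_pack` appears for custom nodes and is simply absent for built-ins.

# Standalone sketch, not the InvokeAI implementation; names here are hypothetical.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class UIConfig:
    title: Optional[str] = None
    tags: Optional[list[str]] = None
    category: Optional[str] = None
    version: str = "1.0.0"
    node_pack: Optional[str] = None

def apply_ui_extras(schema: dict[str, Any], uiconfig: Optional[UIConfig]) -> None:
    """Copy only the UI attributes that are set, then always set the version."""
    if uiconfig is None:
        return
    for key in ("title", "tags", "category", "node_pack"):
        value = getattr(uiconfig, key)
        if value is not None:
            schema[key] = value
    schema["version"] = uiconfig.version

schema: dict[str, Any] = {}
apply_ui_extras(schema, UIConfig(title="Resize", node_pack="my_pack", version="1.2.0"))
print(schema)  # {'title': 'Resize', 'node_pack': 'my_pack', 'version': '1.2.0'}
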
@@ -796,15 +800,20 @@ def invocation(
         validate_fields(cls.model_fields, invocation_type)

         # Add OpenAPI schema extras
-        uiconf_name = cls.__qualname__ + ".UIConfig"
-        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
-            cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
-        if title is not None:
-            cls.UIConfig.title = title
-        if tags is not None:
-            cls.UIConfig.tags = tags
-        if category is not None:
-            cls.UIConfig.category = category
+        uiconfig_name = cls.__qualname__ + ".UIConfig"
+        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
+            cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
+        cls.UIConfig.title = title
+        cls.UIConfig.tags = tags
+        cls.UIConfig.category = category
+
+        # Grab the node pack's name from the module name, if it's a custom node
+        module_name = cls.__module__.split(".")[0]
+        if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
+            cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
+        else:
+            cls.UIConfig.node_pack = None
+
         if version is not None:
             try:
                 semver.Version.parse(version)
@@ -814,6 +823,7 @@ def invocation(
         else:
             logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
             cls.UIConfig.version = "1.0.0"
+
         if use_cache is not None:
             cls.model_fields["use_cache"].default = use_cache

@@ -6,6 +6,7 @@ import sys
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path

+from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
 from invokeai.backend.util.logging import InvokeAILogger

 logger = InvokeAILogger.get_logger()
@@ -32,8 +33,8 @@ for d in Path(__file__).parent.iterdir():
     if module_name in globals():
         continue

-    # we have a legit module to import
-    spec = spec_from_file_location(module_name, init.absolute())
+    # load the module, appending a suffix to identify it as a custom node pack
+    spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())

     if spec is None or spec.loader is None:
         logger.warn(f"Could not load {init}")
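
For context, a standalone sketch (not InvokeAI code) of why aliasing the module name at load time is enough: classes defined in the loaded module pick up the aliased name in their `__module__`, which is exactly what the `@invocation` decorator inspects in the hunk further above. The pack name `my_pack` and the temp-file setup are illustrative only.

# Standalone sketch; "my_pack" and the temporary directory are illustrative only.
import tempfile
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"

with tempfile.TemporaryDirectory() as tmp:
    init = Path(tmp) / "__init__.py"
    init.write_text("class MyNode:\n    pass\n")

    module_name = "my_pack"  # would normally be the custom node pack's directory name
    spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
    assert spec is not None and spec.loader is not None
    module = module_from_spec(spec)
    spec.loader.exec_module(module)

    # Classes defined in the module carry the aliased name in __module__ ...
    assert module.MyNode.__module__ == "my_pack__invokeai-custom-node"
    # ... so the pack name can be recovered later by stripping the suffix.
    print(module.MyNode.__module__.split(CUSTOM_NODE_PACK_SUFFIX)[0])  # -> "my_pack"
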
@@ -160,6 +160,7 @@
     "trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
     "txt2img": "Text To Image",
     "unifiedCanvas": "Unified Canvas",
+    "unknown": "Unknown",
     "upload": "Upload"
   },
   "controlnet": {
@@ -802,6 +803,7 @@
     "cannotConnectOutputToOutput": "Cannot connect output to output",
     "cannotConnectToSelf": "Cannot connect to self",
     "cannotDuplicateConnection": "Cannot create duplicate connections",
+    "nodePack": "Node pack",
     "clipField": "Clip",
     "clipFieldDescription": "Tokenizer and text_encoder submodels.",
     "collection": "Collection",
@@ -966,6 +968,7 @@
     "unableToParseNode": "Unable to parse node",
     "unableToUpdateNode": "Unable to update node",
     "unableToValidateWorkflow": "Unable to Validate Workflow",
+    "unableToMigrateWorkflow": "Unable to Migrate Workflow",
     "unknownErrorValidatingWorkflow": "Unknown error validating workflow",
     "inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})",
     "outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})",
@@ -979,9 +982,9 @@
     "unhandledInputProperty": "Unhandled input property",
     "unhandledOutputProperty": "Unhandled output property",
     "unknownField": "Unknown field",
-    "unknownFieldType": "$t(nodes.unknownField) type",
+    "unknownFieldType": "$t(nodes.unknownField) type: {{type}}",
     "unknownNode": "Unknown Node",
-    "unknownNodeType": "$t(nodes.unknownNode) type",
+    "unknownNodeType": "Unknown node type",
     "unknownTemplate": "Unknown Template",
     "unknownInput": "Unknown input: {{name}}",
     "unkownInvocation": "Unknown Invocation type",
@@ -3,7 +3,10 @@ import { parseify } from 'common/util/serialize';
 import { workflowLoadRequested } from 'features/nodes/store/actions';
 import { workflowLoaded } from 'features/nodes/store/nodesSlice';
 import { $flow } from 'features/nodes/store/reactFlowInstance';
-import { WorkflowVersionError } from 'features/nodes/types/error';
+import {
+  WorkflowMigrationError,
+  WorkflowVersionError,
+} from 'features/nodes/types/error';
 import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
 import { addToast } from 'features/system/store/systemSlice';
 import { makeToast } from 'features/system/util/makeToast';
@@ -67,6 +70,18 @@ export const addWorkflowLoadRequestedListener = () => {
               })
             )
           );
+        } else if (e instanceof WorkflowMigrationError) {
+          // There was a problem migrating the workflow to the latest version
+          log.error({ error: parseify(e) }, e.message);
+          dispatch(
+            addToast(
+              makeToast({
+                title: t('nodes.unableToValidateWorkflow'),
+                status: 'error',
+                description: e.message,
+              })
+            )
+          );
         } else if (e instanceof z.ZodError) {
           // There was a problem validating the workflow itself
           const { message } = fromZodError(e, {
@@ -24,6 +24,7 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
       <Icon
         as={FaInfoCircle}
         sx={{
+          display: 'block',
           boxSize: 4,
           w: 8,
           color: needsUpdate ? 'error.400' : 'base.400',
@@ -109,6 +110,11 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
       <Text as="span" sx={{ fontWeight: 600 }}>
         {title}
       </Text>
+      {nodeTemplate?.nodePack && (
+        <Text opacity={0.7}>
+          {t('nodes.nodePack')}: {nodeTemplate.nodePack}
+        </Text>
+      )}
       <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
         {nodeTemplate?.description}
       </Text>
@@ -1,9 +1,10 @@
-import { Box, Flex, Text } from '@chakra-ui/react';
+import { Flex, Text } from '@chakra-ui/react';
+import { useNodePack } from 'features/nodes/hooks/useNodePack';
 import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
 import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
 import NodeCollapseButton from '../common/NodeCollapseButton';
 import NodeWrapper from '../common/NodeWrapper';
-import { useTranslation } from 'react-i18next';

 type Props = {
   nodeId: string;
@@ -21,6 +22,7 @@ const InvocationNodeUnknownFallback = ({
   selected,
 }: Props) => {
   const { t } = useTranslation();
+  const nodePack = useNodePack(nodeId);
   return (
     <NodeWrapper nodeId={nodeId} selected={selected}>
       <Flex
@@ -62,12 +64,22 @@ const InvocationNodeUnknownFallback = ({
             fontSize: 'sm',
           }}
         >
-          <Box>
-            <Text as="span">{t('nodes.unknownNodeType')}: </Text>
-            <Text as="span" fontWeight={600}>
-              {type}
+          <Flex gap={2} flexDir="column">
+            <Text as="span">
+              {t('nodes.unknownNodeType')}:{' '}
+              <Text as="span" fontWeight={600}>
+                {type}
+              </Text>
             </Text>
-          </Box>
+            {nodePack && (
+              <Text as="span">
+                {t('nodes.nodePack')}:{' '}
+                <Text as="span" fontWeight={600}>
+                  {nodePack}
+                </Text>
+              </Text>
+            )}
+          </Flex>
         </Flex>
       )}
     </NodeWrapper>
@@ -298,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
             _dark: { color: 'error.300' },
           }}
         >
-          {t('nodes.unknownFieldType')}: {fieldInstance?.type.name}
+          {t('nodes.unknownFieldType', { type: fieldInstance?.type.name })}
         </Text>
       </Box>
     );
@@ -0,0 +1,27 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import { useMemo } from 'react';
+import { isInvocationNode } from '../types/invocation';
+
+export const useNodePack = (nodeId: string) => {
+  const selector = useMemo(
+    () =>
+      createSelector(
+        stateSelector,
+        ({ nodes }) => {
+          const node = nodes.nodes.find((node) => node.id === nodeId);
+          if (!isInvocationNode(node)) {
+            return false;
+          }
+          return node.data.nodePack;
+        },
+        defaultSelectorOptions
+      ),
+    [nodeId]
+  );
+
+  const title = useAppSelector(selector);
+  return title;
+};
@@ -12,6 +12,20 @@ export class WorkflowVersionError extends Error {
     this.name = this.constructor.name;
   }
 }
+/**
+ * Workflow Migration Error
+ * Raised when a workflow migration fails.
+ */
+export class WorkflowMigrationError extends Error {
+  /**
+   * Create WorkflowMigrationError
+   * @param {String} message
+   */
+  constructor(message: string) {
+    super(message);
+    this.name = this.constructor.name;
+  }
+}

 /**
  * Unable to Update Node Error
@@ -21,6 +21,7 @@ export const zInvocationTemplate = z.object({
   withWorkflow: z.boolean(),
   version: zSemVer,
   useCache: z.boolean(),
+  nodePack: z.string().min(1).nullish(),
 });
 export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
 // #endregion
@@ -36,6 +37,7 @@ export const zInvocationNodeData = z.object({
   isIntermediate: z.boolean(),
   useCache: z.boolean(),
   version: zSemVer,
+  nodePack: z.string().min(1).nullish(),
   inputs: z.record(zFieldInputInstance),
   outputs: z.record(zFieldOutputInstance),
 });
@@ -82,6 +82,7 @@ export const parseSchema = (
     const tags = schema.tags ?? [];
     const description = schema.description ?? '';
     const version = schema.version;
+    const nodePack = schema.node_pack;
     let withWorkflow = false;

     const inputs = reduce(
@@ -257,6 +258,7 @@ export const parseSchema = (
       outputs,
       useCache,
       withWorkflow,
+      nodePack,
     };

     Object.assign(invocationsAccumulator, { [type]: invocation });
@@ -1,7 +1,14 @@
+import { $store } from 'app/store/nanostores/store';
+import { RootState } from 'app/store/store';
+import { FieldType } from 'features/nodes/types/field';
+import { InvocationNodeData } from 'features/nodes/types/invocation';
 import { t } from 'i18next';
-import { forEach, isString } from 'lodash-es';
+import { forEach } from 'lodash-es';
 import { z } from 'zod';
-import { WorkflowVersionError } from '../../types/error';
+import {
+  WorkflowMigrationError,
+  WorkflowVersionError,
+} from '../../types/error';
 import { zSemVer } from '../../types/semver';
 import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from '../../types/v1/fieldTypeMap';
 import { WorkflowV1, zWorkflowV1 } from '../../types/v1/workflowV1';
@@ -20,22 +27,39 @@ const zWorkflowMetaVersion = z.object({
  * Migrates a workflow from V1 to V2.
  */
 const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
+  const invocationTemplates = ($store.get()?.getState() as RootState).nodes
+    .nodeTemplates;
   workflowToMigrate.nodes.forEach((node) => {
     if (node.type === 'invocation') {
+      // Migrate field types
       forEach(node.data.inputs, (input) => {
-        if (!isString(input.type)) {
-          return;
+        const newFieldType = FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type];
+        if (!newFieldType) {
+          throw new WorkflowMigrationError(
+            t('nodes.unknownFieldType', { type: input.type })
+          );
         }
-        (input.type as unknown) =
-          FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type];
+        // Cast as the V2 type
+        (input.type as unknown as FieldType) = newFieldType;
       });
       forEach(node.data.outputs, (output) => {
-        if (!isString(output.type)) {
-          return;
-        }
-        (output.type as unknown) =
+        const newFieldType =
           FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type];
+        if (!newFieldType) {
+          throw new WorkflowMigrationError(
+            t('nodes.unknownFieldType', { type: output.type })
+          );
+        }
+        // Cast as the V2 type
+        (output.type as unknown as FieldType) = newFieldType;
       });
+      // Migrate nodePack
+      const invocationTemplate = invocationTemplates[node.data.type];
+      const nodePack = invocationTemplate
+        ? invocationTemplate.nodePack
+        : t('common.unknown');
+      // Cast as the V2 type
+      (node.data as unknown as InvocationNodeData).nodePack = nodePack;
     }
   });
   (workflowToMigrate.meta.version as WorkflowV2['meta']['version']) = '2.0.0';
@@ -49,6 +73,7 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
   const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);

   if (!workflowVersionResult.success) {
+    console.log(data);
     throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
   }

File diff suppressed because one or more lines are too long