mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes,ui): add detection of custom nodes
Custom nodes have a new attribute `node_pack` indicating the node pack they came from. - This is displayed in the UI in the icon icon tooltip. - If a workflow is loaded and a node is unavailable, its node pack will be displayed (if it is known). - If a workflow is migrated from v1 to v2, and the node is unknown, it falls back to "Unknown". If the missing node pack is installed and the node is updated, the node pack will be updated as expected.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
@ -26,6 +26,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
|
||||
|
||||
|
||||
class InvalidVersionError(ValueError):
|
||||
pass
|
||||
@ -432,10 +434,10 @@ class UIConfigBase(BaseModel):
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: Optional[str] = Field(
|
||||
default=None,
|
||||
version: str = Field(
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
@ -591,14 +593,16 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
if uiconfig.title is not None:
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig.tags is not None:
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig.category is not None:
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig.node_pack is not None:
|
||||
schema["node_pack"] = uiconfig.node_pack
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
@ -796,15 +800,20 @@ def invocation(
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
|
||||
if title is not None:
|
||||
cls.UIConfig.title = title
|
||||
if tags is not None:
|
||||
cls.UIConfig.tags = tags
|
||||
if category is not None:
|
||||
cls.UIConfig.category = category
|
||||
uiconfig_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
|
||||
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
|
||||
cls.UIConfig.title = title
|
||||
cls.UIConfig.tags = tags
|
||||
cls.UIConfig.category = category
|
||||
|
||||
# Grab the node pack's name from the module name, if it's a custom node
|
||||
module_name = cls.__module__.split(".")[0]
|
||||
if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
|
||||
cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
|
||||
else:
|
||||
cls.UIConfig.node_pack = None
|
||||
|
||||
if version is not None:
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
@ -814,6 +823,7 @@ def invocation(
|
||||
else:
|
||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
cls.UIConfig.version = "1.0.0"
|
||||
|
||||
if use_cache is not None:
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
|
@ -6,6 +6,7 @@ import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
@ -32,8 +33,8 @@ for d in Path(__file__).parent.iterdir():
|
||||
if module_name in globals():
|
||||
continue
|
||||
|
||||
# we have a legit module to import
|
||||
spec = spec_from_file_location(module_name, init.absolute())
|
||||
# load the module, appending adding a suffix to identify it as a custom node pack
|
||||
spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
|
||||
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warn(f"Could not load {init}")
|
||||
|
Reference in New Issue
Block a user