From 1b54e58726fc3cf1b0d9256bc2c3c72384720b99 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:06:33 +1000 Subject: [PATCH] fix(app): node_pack not added to openapi schema correctly --- invokeai/app/invocations/baseinvocation.py | 54 +++++++++------------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index b527de41bc..ec4bb92355 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -20,7 +20,6 @@ from typing import ( Type, TypeVar, Union, - cast, ) import semver @@ -80,7 +79,7 @@ class UIConfigBase(BaseModel): version: str = Field( description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".', ) - node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node") + node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes") classification: Classification = Field(default=Classification.Stable, description="The node's classification") model_config = ConfigDict( @@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel): @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" - uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) - if uiconfig is not None: - if uiconfig.title is not None: - schema["title"] = uiconfig.title - if uiconfig.tags is not None: - schema["tags"] = uiconfig.tags - if uiconfig.category is not None: - schema["category"] = uiconfig.category - if uiconfig.node_pack is not None: - schema["node_pack"] = uiconfig.node_pack - schema["classification"] = uiconfig.classification - schema["version"] = uiconfig.version + if title := model_class.UIConfig.title: + schema["title"] = title + if tags := model_class.UIConfig.tags: + schema["tags"] = tags + if category := model_class.UIConfig.category: + schema["category"] = category + if node_pack := model_class.UIConfig.node_pack: + schema["node_pack"] = node_pack + schema["classification"] = model_class.UIConfig.classification + schema["version"] = model_class.UIConfig.version if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] schema["class"] = "invocation" @@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel): json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) - UIConfig: ClassVar[Type[UIConfigBase]] + UIConfig: ClassVar[UIConfigBase] model_config = ConfigDict( protected_namespaces=(), @@ -441,30 +438,25 @@ def invocation( validate_fields(cls.model_fields, invocation_type) # Add OpenAPI schema extras - uiconfig_name = cls.__qualname__ + ".UIConfig" - if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name: - cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {}) - cls.UIConfig.title = title - cls.UIConfig.tags = tags - cls.UIConfig.category = category - cls.UIConfig.classification = classification - - # Grab the node pack's name from the module name, if it's a custom node - is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations" - if is_custom_node: - cls.UIConfig.node_pack = cls.__module__.split(".")[0] - else: - cls.UIConfig.node_pack = None + uiconfig: dict[str, Any] = {} + uiconfig["title"] = title + uiconfig["tags"] = tags + uiconfig["category"] = category + uiconfig["classification"] = classification + # The node pack is the module name - will be "invokeai" for built-in nodes + uiconfig["node_pack"] = cls.__module__.split(".")[0] if version is not None: try: semver.Version.parse(version) except ValueError as e: raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e - cls.UIConfig.version = version + uiconfig["version"] = version else: logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"') - cls.UIConfig.version = "1.0.0" + uiconfig["version"] = "1.0.0" + + cls.UIConfig = UIConfigBase(**uiconfig) if use_cache is not None: cls.model_fields["use_cache"].default = use_cache