fix(app): node_pack not added to openapi schema correctly

This commit is contained in:
psychedelicious 2024-08-28 21:06:33 +10:00
parent 219d7c9611
commit 1b54e58726

View File

@ -20,7 +20,6 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
import semver
@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
version: str = Field(
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
if uiconfig.title is not None:
schema["title"] = uiconfig.title
if uiconfig.tags is not None:
schema["tags"] = uiconfig.tags
if uiconfig.category is not None:
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if title := model_class.UIConfig.title:
schema["title"] = title
if tags := model_class.UIConfig.tags:
schema["tags"] = tags
if category := model_class.UIConfig.category:
schema["category"] = category
if node_pack := model_class.UIConfig.node_pack:
schema["node_pack"] = node_pack
schema["classification"] = model_class.UIConfig.classification
schema["version"] = model_class.UIConfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
UIConfig: ClassVar[Type[UIConfigBase]]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
protected_namespaces=(),
@ -441,30 +438,25 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconfig_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
uiconfig["tags"] = tags
uiconfig["category"] = category
uiconfig["classification"] = classification
# The node pack is the module name - will be "invokeai" for built-in nodes
uiconfig["node_pack"] = cls.__module__.split(".")[0]
if version is not None:
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
uiconfig["version"] = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
cls.UIConfig.version = "1.0.0"
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache