fix(app): node_pack not added to openapi schema correctly

This commit is contained in:
psychedelicious 2024-08-28 21:06:33 +10:00
parent 4a1a6639f6
commit a413b261f0

View File

@ -20,7 +20,6 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast,
) )
import semver import semver
@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
version: str = Field( version: str = Field(
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".', description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
) )
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node") node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
classification: Classification = Field(default=Classification.Stable, description="The node's classification") classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict( model_config = ConfigDict(
@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) if title := model_class.UIConfig.title:
if uiconfig is not None: schema["title"] = title
if uiconfig.title is not None: if tags := model_class.UIConfig.tags:
schema["title"] = uiconfig.title schema["tags"] = tags
if uiconfig.tags is not None: if category := model_class.UIConfig.category:
schema["tags"] = uiconfig.tags schema["category"] = category
if uiconfig.category is not None: if node_pack := model_class.UIConfig.node_pack:
schema["category"] = uiconfig.category schema["node_pack"] = node_pack
if uiconfig.node_pack is not None: schema["classification"] = model_class.UIConfig.classification
schema["node_pack"] = uiconfig.node_pack schema["version"] = model_class.UIConfig.version
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "invocation" schema["class"] = "invocation"
@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute}, json_schema_extra={"field_kind": FieldKind.NodeAttribute},
) )
UIConfig: ClassVar[Type[UIConfigBase]] UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict( model_config = ConfigDict(
protected_namespaces=(), protected_namespaces=(),
@ -441,30 +438,25 @@ def invocation(
validate_fields(cls.model_fields, invocation_type) validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras # Add OpenAPI schema extras
uiconfig_name = cls.__qualname__ + ".UIConfig" uiconfig: dict[str, Any] = {}
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name: uiconfig["title"] = title
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {}) uiconfig["tags"] = tags
cls.UIConfig.title = title uiconfig["category"] = category
cls.UIConfig.tags = tags uiconfig["classification"] = classification
cls.UIConfig.category = category # The node pack is the module name - will be "invokeai" for built-in nodes
cls.UIConfig.classification = classification uiconfig["node_pack"] = cls.__module__.split(".")[0]
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
if version is not None: if version is not None:
try: try:
semver.Version.parse(version) semver.Version.parse(version)
except ValueError as e: except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version uiconfig["version"] = version
else: else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"') logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
cls.UIConfig.version = "1.0.0" uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
if use_cache is not None: if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache cls.model_fields["use_cache"].default = use_cache