mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(app): node_pack not added to openapi schema correctly
This commit is contained in:
parent
4a1a6639f6
commit
a413b261f0
@ -20,7 +20,6 @@ from typing import (
|
|||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
|
|||||||
version: str = Field(
|
version: str = Field(
|
||||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||||
)
|
)
|
||||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
|
||||||
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
if title := model_class.UIConfig.title:
|
||||||
if uiconfig is not None:
|
schema["title"] = title
|
||||||
if uiconfig.title is not None:
|
if tags := model_class.UIConfig.tags:
|
||||||
schema["title"] = uiconfig.title
|
schema["tags"] = tags
|
||||||
if uiconfig.tags is not None:
|
if category := model_class.UIConfig.category:
|
||||||
schema["tags"] = uiconfig.tags
|
schema["category"] = category
|
||||||
if uiconfig.category is not None:
|
if node_pack := model_class.UIConfig.node_pack:
|
||||||
schema["category"] = uiconfig.category
|
schema["node_pack"] = node_pack
|
||||||
if uiconfig.node_pack is not None:
|
schema["classification"] = model_class.UIConfig.classification
|
||||||
schema["node_pack"] = uiconfig.node_pack
|
schema["version"] = model_class.UIConfig.version
|
||||||
schema["classification"] = uiconfig.classification
|
|
||||||
schema["version"] = uiconfig.version
|
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = []
|
schema["required"] = []
|
||||||
schema["class"] = "invocation"
|
schema["class"] = "invocation"
|
||||||
@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||||
)
|
)
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[UIConfigBase]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
protected_namespaces=(),
|
protected_namespaces=(),
|
||||||
@ -441,30 +438,25 @@ def invocation(
|
|||||||
validate_fields(cls.model_fields, invocation_type)
|
validate_fields(cls.model_fields, invocation_type)
|
||||||
|
|
||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconfig_name = cls.__qualname__ + ".UIConfig"
|
uiconfig: dict[str, Any] = {}
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
|
uiconfig["title"] = title
|
||||||
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
|
uiconfig["tags"] = tags
|
||||||
cls.UIConfig.title = title
|
uiconfig["category"] = category
|
||||||
cls.UIConfig.tags = tags
|
uiconfig["classification"] = classification
|
||||||
cls.UIConfig.category = category
|
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||||
cls.UIConfig.classification = classification
|
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||||
|
|
||||||
# Grab the node pack's name from the module name, if it's a custom node
|
|
||||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
|
||||||
if is_custom_node:
|
|
||||||
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
|
|
||||||
else:
|
|
||||||
cls.UIConfig.node_pack = None
|
|
||||||
|
|
||||||
if version is not None:
|
if version is not None:
|
||||||
try:
|
try:
|
||||||
semver.Version.parse(version)
|
semver.Version.parse(version)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||||
cls.UIConfig.version = version
|
uiconfig["version"] = version
|
||||||
else:
|
else:
|
||||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||||
cls.UIConfig.version = "1.0.0"
|
uiconfig["version"] = "1.0.0"
|
||||||
|
|
||||||
|
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||||
|
|
||||||
if use_cache is not None:
|
if use_cache is not None:
|
||||||
cls.model_fields["use_cache"].default = use_cache
|
cls.model_fields["use_cache"].default = use_cache
|
||||||
|
Loading…
x
Reference in New Issue
Block a user