feat(nodes): add field name validation

Protect against using reserved field names
This commit is contained in:
psychedelicious 2023-10-17 23:23:17 +11:00
parent bbae4045c9
commit 7b6e2bc37f

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
import re
from abc import ABC, abstractmethod
from enum import Enum
@ -11,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Op
import semver
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
from pydantic.fields import _Unset
from pydantic.fields import FieldInfo, _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
@ -25,6 +26,10 @@ class InvalidVersionError(ValueError):
pass
class InvalidFieldError(TypeError):
pass
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
@ -302,6 +307,7 @@ def InputField(
ui_order=ui_order,
item_default=item_default,
ui_choice_labels=ui_choice_labels,
_field_kind="input",
)
field_args = dict(
@ -444,6 +450,7 @@ def OutputField(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
_field_kind="output",
),
)
@ -527,6 +534,7 @@ class BaseInvocationOutput(BaseModel):
schema["required"].extend(["type"])
model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
json_schema_serialization_defaults_required=True,
json_schema_extra=json_schema_extra,
@ -549,9 +557,6 @@ class MissingInputException(Exception):
class BaseInvocation(ABC, BaseModel):
"""
A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
All invocations must use the `@invocation` decorator to provide their unique type.
"""
@ -667,17 +672,21 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(
default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
json_schema_extra=dict(_field_kind="internal"),
)
is_intermediate: Optional[bool] = Field(
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
)
use_cache: bool = Field(
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
)
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
UIConfig: ClassVar[Type[UIConfigBase]]
model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True,
@ -688,6 +697,70 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
RESERVED_INPUT_FIELD_NAMES = {
"id",
"is_intermediate",
"use_cache",
"type",
"workflow",
"metadata",
}
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
class _Model(BaseModel):
pass
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
print(RESERVED_PYDANTIC_FIELD_NAMES)
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
"""
Validates the fields of an invocation or invocation output:
- must not override any pydantic reserved fields
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None)
if field.json_schema_extra
else None
)
# must have a field_kind
if field_kind is None or field_kind not in {"input", "output", "internal"}:
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
)
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
# internal fields *must* be in the reserved list
if (
field_kind == "internal"
and name not in RESERVED_INPUT_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
)
return None
def invocation(
invocation_type: str,
title: Optional[str] = None,
@ -697,7 +770,7 @@ def invocation(
use_cache: Optional[bool] = True,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Adds metadata to an invocation.
Registers an invocation.
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
@ -716,6 +789,8 @@ def invocation(
if invocation_type in BaseInvocation.get_invocation_types():
raise ValueError(f'Invocation type "{invocation_type}" already exists')
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
@ -746,8 +821,7 @@ def invocation(
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type",
default=invocation_type,
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
)
docstring = cls.__doc__
@ -788,13 +862,12 @@ def invocation_output(
if output_type in BaseInvocationOutput.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type",
default=output_type,
)
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
docstring = cls.__doc__
cls = create_model(
@ -825,7 +898,9 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow)
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
)
class MetadataField(RootModel):
@ -841,4 +916,6 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = InputField(default=None, description=FieldDescriptions.metadata)
metadata: Optional[MetadataField] = Field(
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
)