From 7b6e2bc37fd1381a944768d19bb50b0cb27ee1ad Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 17 Oct 2023 23:23:17 +1100 Subject: [PATCH] feat(nodes): add field name validation Protect against using reserved field names --- invokeai/app/invocations/baseinvocation.py | 109 ++++++++++++++++++--- 1 file changed, 93 insertions(+), 16 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 5f1ff0395f..25589510a6 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -2,6 +2,7 @@ from __future__ import annotations +import inspect import re from abc import ABC, abstractmethod from enum import Enum @@ -11,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Op import semver from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model -from pydantic.fields import _Unset +from pydantic.fields import FieldInfo, _Unset from pydantic_core import PydanticUndefined from invokeai.app.services.config.config_default import InvokeAIAppConfig @@ -25,6 +26,10 @@ class InvalidVersionError(ValueError): pass +class InvalidFieldError(TypeError): + pass + + class FieldDescriptions: denoising_start = "When to start denoising, expressed a percentage of total steps" denoising_end = "When to stop denoising, expressed a percentage of total steps" @@ -302,6 +307,7 @@ def InputField( ui_order=ui_order, item_default=item_default, ui_choice_labels=ui_choice_labels, + _field_kind="input", ) field_args = dict( @@ -444,6 +450,7 @@ def OutputField( ui_type=ui_type, ui_hidden=ui_hidden, ui_order=ui_order, + _field_kind="output", ), ) @@ -527,6 +534,7 @@ class BaseInvocationOutput(BaseModel): schema["required"].extend(["type"]) model_config = ConfigDict( + protected_namespaces=(), validate_assignment=True, json_schema_serialization_defaults_required=True, json_schema_extra=json_schema_extra, @@ -549,9 +557,6 @@ class MissingInputException(Exception): class BaseInvocation(ABC, BaseModel): """ - A node to process inputs and produce outputs. - May use dependency injection in __init__ to receive providers. - All invocations must use the `@invocation` decorator to provide their unique type. """ @@ -667,17 +672,21 @@ class BaseInvocation(ABC, BaseModel): id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", + json_schema_extra=dict(_field_kind="internal"), ) - is_intermediate: Optional[bool] = Field( + is_intermediate: bool = Field( default=False, description="Whether or not this is an intermediate invocation.", - json_schema_extra=dict(ui_type=UIType.IsIntermediate), + json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"), + ) + use_cache: bool = Field( + default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal") ) - use_cache: bool = InputField(default=True, description="Whether or not to use the cache") UIConfig: ClassVar[Type[UIConfigBase]] model_config = ConfigDict( + protected_namespaces=(), validate_assignment=True, json_schema_extra=json_schema_extra, json_schema_serialization_defaults_required=True, @@ -688,6 +697,70 @@ class BaseInvocation(ABC, BaseModel): TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) +RESERVED_INPUT_FIELD_NAMES = { + "id", + "is_intermediate", + "use_cache", + "type", + "workflow", + "metadata", +} + +RESERVED_OUTPUT_FIELD_NAMES = {"type"} + + +class _Model(BaseModel): + pass + + +# Get all pydantic model attrs, methods, etc +RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model()))) + +print(RESERVED_PYDANTIC_FIELD_NAMES) + + +def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: + """ + Validates the fields of an invocation or invocation output: + - must not override any pydantic reserved fields + - must be created via `InputField`, `OutputField`, or be an internal field defined in this file + """ + for name, field in model_fields.items(): + if name in RESERVED_PYDANTIC_FIELD_NAMES: + raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)') + + field_kind = ( + # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file + field.json_schema_extra.get("_field_kind", None) + if field.json_schema_extra + else None + ) + + # must have a field_kind + if field_kind is None or field_kind not in {"input", "output", "internal"}: + raise InvalidFieldError( + f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)' + ) + + if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES: + raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)') + + if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES: + raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)') + + # internal fields *must* be in the reserved list + if ( + field_kind == "internal" + and name not in RESERVED_INPUT_FIELD_NAMES + and name not in RESERVED_OUTPUT_FIELD_NAMES + ): + raise InvalidFieldError( + f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)' + ) + + return None + + def invocation( invocation_type: str, title: Optional[str] = None, @@ -697,7 +770,7 @@ def invocation( use_cache: Optional[bool] = True, ) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]: """ - Adds metadata to an invocation. + Registers an invocation. :param str invocation_type: The type of the invocation. Must be unique among all invocations. :param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None. @@ -716,6 +789,8 @@ def invocation( if invocation_type in BaseInvocation.get_invocation_types(): raise ValueError(f'Invocation type "{invocation_type}" already exists') + validate_fields(cls.model_fields, invocation_type) + # Add OpenAPI schema extras uiconf_name = cls.__qualname__ + ".UIConfig" if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: @@ -746,8 +821,7 @@ def invocation( invocation_type_annotation = Literal[invocation_type] # type: ignore invocation_type_field = Field( - title="type", - default=invocation_type, + title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal") ) docstring = cls.__doc__ @@ -788,13 +862,12 @@ def invocation_output( if output_type in BaseInvocationOutput.get_output_types(): raise ValueError(f'Invocation type "{output_type}" already exists') + validate_fields(cls.model_fields, output_type) + # Add the output type to the model. output_type_annotation = Literal[output_type] # type: ignore - output_type_field = Field( - title="type", - default=output_type, - ) + output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal")) docstring = cls.__doc__ cls = create_model( @@ -825,7 +898,9 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField) class WithWorkflow(BaseModel): - workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow) + workflow: Optional[WorkflowField] = Field( + default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal") + ) class MetadataField(RootModel): @@ -841,4 +916,6 @@ MetadataFieldValidator = TypeAdapter(MetadataField) class WithMetadata(BaseModel): - metadata: Optional[MetadataField] = InputField(default=None, description=FieldDescriptions.metadata) + metadata: Optional[MetadataField] = Field( + default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal") + )