mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add field name validation
Protect against using reserved field names
This commit is contained in:
parent
bbae4045c9
commit
7b6e2bc37f
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -11,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Op
|
|||||||
|
|
||||||
import semver
|
import semver
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||||
from pydantic.fields import _Unset
|
from pydantic.fields import FieldInfo, _Unset
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
@ -25,6 +26,10 @@ class InvalidVersionError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFieldError(TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FieldDescriptions:
|
class FieldDescriptions:
|
||||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
@ -302,6 +307,7 @@ def InputField(
|
|||||||
ui_order=ui_order,
|
ui_order=ui_order,
|
||||||
item_default=item_default,
|
item_default=item_default,
|
||||||
ui_choice_labels=ui_choice_labels,
|
ui_choice_labels=ui_choice_labels,
|
||||||
|
_field_kind="input",
|
||||||
)
|
)
|
||||||
|
|
||||||
field_args = dict(
|
field_args = dict(
|
||||||
@ -444,6 +450,7 @@ def OutputField(
|
|||||||
ui_type=ui_type,
|
ui_type=ui_type,
|
||||||
ui_hidden=ui_hidden,
|
ui_hidden=ui_hidden,
|
||||||
ui_order=ui_order,
|
ui_order=ui_order,
|
||||||
|
_field_kind="output",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -527,6 +534,7 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
schema["required"].extend(["type"])
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
|
protected_namespaces=(),
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
json_schema_serialization_defaults_required=True,
|
json_schema_serialization_defaults_required=True,
|
||||||
json_schema_extra=json_schema_extra,
|
json_schema_extra=json_schema_extra,
|
||||||
@ -549,9 +557,6 @@ class MissingInputException(Exception):
|
|||||||
|
|
||||||
class BaseInvocation(ABC, BaseModel):
|
class BaseInvocation(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
A node to process inputs and produce outputs.
|
|
||||||
May use dependency injection in __init__ to receive providers.
|
|
||||||
|
|
||||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -667,17 +672,21 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
id: str = Field(
|
id: str = Field(
|
||||||
default_factory=uuid_string,
|
default_factory=uuid_string,
|
||||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||||
|
json_schema_extra=dict(_field_kind="internal"),
|
||||||
)
|
)
|
||||||
is_intermediate: Optional[bool] = Field(
|
is_intermediate: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether or not this is an intermediate invocation.",
|
description="Whether or not this is an intermediate invocation.",
|
||||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
|
||||||
|
)
|
||||||
|
use_cache: bool = Field(
|
||||||
|
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
|
||||||
)
|
)
|
||||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
|
protected_namespaces=(),
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
json_schema_extra=json_schema_extra,
|
json_schema_extra=json_schema_extra,
|
||||||
json_schema_serialization_defaults_required=True,
|
json_schema_serialization_defaults_required=True,
|
||||||
@ -688,6 +697,70 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||||
|
|
||||||
|
|
||||||
|
RESERVED_INPUT_FIELD_NAMES = {
|
||||||
|
"id",
|
||||||
|
"is_intermediate",
|
||||||
|
"use_cache",
|
||||||
|
"type",
|
||||||
|
"workflow",
|
||||||
|
"metadata",
|
||||||
|
}
|
||||||
|
|
||||||
|
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||||
|
|
||||||
|
|
||||||
|
class _Model(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Get all pydantic model attrs, methods, etc
|
||||||
|
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
||||||
|
|
||||||
|
print(RESERVED_PYDANTIC_FIELD_NAMES)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||||
|
"""
|
||||||
|
Validates the fields of an invocation or invocation output:
|
||||||
|
- must not override any pydantic reserved fields
|
||||||
|
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
||||||
|
"""
|
||||||
|
for name, field in model_fields.items():
|
||||||
|
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||||
|
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||||
|
|
||||||
|
field_kind = (
|
||||||
|
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||||
|
field.json_schema_extra.get("_field_kind", None)
|
||||||
|
if field.json_schema_extra
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# must have a field_kind
|
||||||
|
if field_kind is None or field_kind not in {"input", "output", "internal"}:
|
||||||
|
raise InvalidFieldError(
|
||||||
|
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||||
|
)
|
||||||
|
|
||||||
|
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
|
||||||
|
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||||
|
|
||||||
|
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||||
|
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||||
|
|
||||||
|
# internal fields *must* be in the reserved list
|
||||||
|
if (
|
||||||
|
field_kind == "internal"
|
||||||
|
and name not in RESERVED_INPUT_FIELD_NAMES
|
||||||
|
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||||
|
):
|
||||||
|
raise InvalidFieldError(
|
||||||
|
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def invocation(
|
def invocation(
|
||||||
invocation_type: str,
|
invocation_type: str,
|
||||||
title: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
@ -697,7 +770,7 @@ def invocation(
|
|||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||||
"""
|
"""
|
||||||
Adds metadata to an invocation.
|
Registers an invocation.
|
||||||
|
|
||||||
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
||||||
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||||
@ -716,6 +789,8 @@ def invocation(
|
|||||||
if invocation_type in BaseInvocation.get_invocation_types():
|
if invocation_type in BaseInvocation.get_invocation_types():
|
||||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||||
|
|
||||||
|
validate_fields(cls.model_fields, invocation_type)
|
||||||
|
|
||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
@ -746,8 +821,7 @@ def invocation(
|
|||||||
|
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
invocation_type_field = Field(
|
invocation_type_field = Field(
|
||||||
title="type",
|
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
||||||
default=invocation_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
@ -788,13 +862,12 @@ def invocation_output(
|
|||||||
if output_type in BaseInvocationOutput.get_output_types():
|
if output_type in BaseInvocationOutput.get_output_types():
|
||||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||||
|
|
||||||
|
validate_fields(cls.model_fields, output_type)
|
||||||
|
|
||||||
# Add the output type to the model.
|
# Add the output type to the model.
|
||||||
|
|
||||||
output_type_annotation = Literal[output_type] # type: ignore
|
output_type_annotation = Literal[output_type] # type: ignore
|
||||||
output_type_field = Field(
|
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
||||||
title="type",
|
|
||||||
default=output_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
cls = create_model(
|
cls = create_model(
|
||||||
@ -825,7 +898,9 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
|||||||
|
|
||||||
|
|
||||||
class WithWorkflow(BaseModel):
|
class WithWorkflow(BaseModel):
|
||||||
workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow)
|
workflow: Optional[WorkflowField] = Field(
|
||||||
|
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel):
|
class MetadataField(RootModel):
|
||||||
@ -841,4 +916,6 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
|
|||||||
|
|
||||||
|
|
||||||
class WithMetadata(BaseModel):
|
class WithMetadata(BaseModel):
|
||||||
metadata: Optional[MetadataField] = InputField(default=None, description=FieldDescriptions.metadata)
|
metadata: Optional[MetadataField] = Field(
|
||||||
|
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user