feat(nodes): improve docstrings in baseinvocation, disambiguate method names

This commit is contained in:
psychedelicious 2023-11-25 21:39:27 +11:00
parent ed79980dd4
commit 858bcdd3ff
2 changed files with 46 additions and 28 deletions

View File

@ -93,6 +93,10 @@ class UIType(str, Enum, metaclass=MetaEnum):
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
should not be used by node authors. should not be used by node authors.
- DEPRECATED Fields
These types are deprecated and should not be used by node authors. A warning will be logged if one is
used, and the type will be ignored. They are included here for backwards compatibility.
""" """
# region Model Field Types # region Model Field Types
@ -173,10 +177,8 @@ class UIComponent(str, Enum, metaclass=MetaEnum):
class InputFieldJSONSchemaExtra(BaseModel): class InputFieldJSONSchemaExtra(BaseModel):
""" """
*DO NOT USE* Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
This helper class is used to tell the client about our custom field attributes via OpenAPI and by the workflow editor during schema parsing and UI rendering.
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
""" """
input: Input input: Input
@ -198,10 +200,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
class OutputFieldJSONSchemaExtra(BaseModel): class OutputFieldJSONSchemaExtra(BaseModel):
""" """
*DO NOT USE* Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
This helper class is used to tell the client about our custom field attributes via OpenAPI during schema parsing and UI rendering.
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
""" """
field_kind: FieldKind field_kind: FieldKind
@ -215,11 +215,6 @@ class OutputFieldJSONSchemaExtra(BaseModel):
) )
def get_type(klass: BaseModel) -> str:
"""Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
return klass.model_fields["type"].default
def InputField( def InputField(
# copied from pydantic's Field # copied from pydantic's Field
# TODO: Can we support default_factory? # TODO: Can we support default_factory?
@ -483,29 +478,39 @@ class BaseInvocationOutput(BaseModel):
@classmethod @classmethod
def register_output(cls, output: BaseInvocationOutput) -> None: def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output) cls._output_classes.add(output)
@classmethod @classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]: def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes return cls._output_classes
@classmethod @classmethod
def get_outputs_union(cls) -> UnionType: def get_outputs_union(cls) -> UnionType:
"""Gets a union of all invocation outputs."""
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
return outputs_union # type: ignore [return-value] return outputs_union # type: ignore [return-value]
@classmethod @classmethod
def get_output_types(cls) -> Iterable[str]: def get_output_types(cls) -> Iterable[str]:
return (get_type(i) for i in BaseInvocationOutput.get_outputs()) """Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type, # Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually. # it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["required"].extend(["type"]) schema["required"].extend(["type"])
@classmethod
def get_type(cls) -> str:
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
return cls.model_fields["type"].default
model_config = ConfigDict( model_config = ConfigDict(
protected_namespaces=(), protected_namespaces=(),
validate_assignment=True, validate_assignment=True,
@ -535,21 +540,29 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set() _invocation_classes: ClassVar[set[BaseInvocation]] = set()
@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default
@classmethod @classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None: def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation) cls._invocation_classes.add(invocation)
@classmethod @classmethod
def get_invocations_union(cls) -> UnionType: def get_invocations_union(cls) -> UnionType:
"""Gets a union of all invocation types."""
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
return invocations_union # type: ignore [return-value] return invocations_union # type: ignore [return-value]
@classmethod @classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]: def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
allowed_invocations: set[BaseInvocation] = set() allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes: for sc in cls._invocation_classes:
invocation_type = get_type(sc) invocation_type = sc.get_type()
is_in_allowlist = ( is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
) )
@ -562,20 +575,22 @@ class BaseInvocation(ABC, BaseModel):
@classmethod @classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]: def get_invocations_map(cls) -> dict[str, BaseInvocation]:
# Get the type strings out of the literals and into a dictionary """Gets a map of all invocation types to their invocation classes."""
return {get_type(i): i for i in BaseInvocation.get_invocations()} return {i.get_type(): i for i in BaseInvocation.get_invocations()}
@classmethod @classmethod
def get_invocation_types(cls) -> Iterable[str]: def get_invocation_types(cls) -> Iterable[str]:
return (get_type(i) for i in BaseInvocation.get_invocations()) """Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())
@classmethod @classmethod
def get_output_type(cls) -> BaseInvocationOutput: def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation return signature(cls.invoke).return_annotation
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates. """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = getattr(model_class, "UIConfig", None) uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"): if uiconfig and hasattr(uiconfig, "title"):
schema["title"] = uiconfig.title schema["title"] = uiconfig.title
@ -595,6 +610,10 @@ class BaseInvocation(ABC, BaseModel):
pass pass
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
"""
Internal invoke method, calls `invoke()` after some prep.
Handles optional fields that are required to call `invoke()` and invocation cache.
"""
for field_name, field in self.model_fields.items(): for field_name, field in self.model_fields.items():
if not field.json_schema_extra or callable(field.json_schema_extra): if not field.json_schema_extra or callable(field.json_schema_extra):
# something has gone terribly awry, we should always have this and it should be a dict # something has gone terribly awry, we should always have this and it should be a dict
@ -634,9 +653,6 @@ class BaseInvocation(ABC, BaseModel):
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context) return self.invoke(context)
def get_type(self) -> str:
return self.model_fields["type"].default
id: str = Field( id: str = Field(
default_factory=uuid_string, default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.", description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
@ -693,9 +709,11 @@ RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
""" """
Validates the fields of an invocation or invocation output: Validates the fields of an invocation or invocation output:
- must not override any pydantic reserved fields - Must not override any pydantic reserved fields
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use - Must have a type annotation
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file - Must have a json_schema_extra dict
- Must have field_kind in json_schema_extra
- Field name must not be reserved, according to its field_kind
""" """
for name, field in model_fields.items(): for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES: if name in RESERVED_PYDANTIC_FIELD_NAMES:

View File

@ -49,7 +49,7 @@ class Edge(BaseModel):
def get_output_field(node: BaseInvocation, field: str) -> Any: def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node) node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type()) node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None node_output_field = node_outputs.get(field) or None
return node_output_field return node_output_field
@ -379,7 +379,7 @@ class Graph(BaseModel):
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
# output fields are not on the node object directly, they are on the output type # output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().model_fields: if edge.source.field not in source_node.get_output_annotation().model_fields:
raise NodeFieldNotFoundError( raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
) )