From 858bcdd3ff780e46d157402d404e373d29f8d3a8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 25 Nov 2023 21:39:27 +1100 Subject: [PATCH] feat(nodes): improve docstrings in baseinvocation, disambiguate method names --- invokeai/app/invocations/baseinvocation.py | 70 ++++++++++++++-------- invokeai/app/services/shared/graph.py | 4 +- 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index cddbd071de..59978c13c1 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -93,6 +93,10 @@ class UIType(str, Enum, metaclass=MetaEnum): Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. """ # region Model Field Types @@ -173,10 +177,8 @@ class UIComponent(str, Enum, metaclass=MetaEnum): class InputFieldJSONSchemaExtra(BaseModel): """ - *DO NOT USE* - This helper class is used to tell the client about our custom field attributes via OpenAPI - schema generation, and Typescript type generation from that schema. It serves no functional - purpose in the backend. + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. """ input: Input @@ -198,10 +200,8 @@ class InputFieldJSONSchemaExtra(BaseModel): class OutputFieldJSONSchemaExtra(BaseModel): """ - *DO NOT USE* - This helper class is used to tell the client about our custom field attributes via OpenAPI - schema generation, and Typescript type generation from that schema. It serves no functional - purpose in the backend. + Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor + during schema parsing and UI rendering. """ field_kind: FieldKind @@ -215,11 +215,6 @@ class OutputFieldJSONSchemaExtra(BaseModel): ) -def get_type(klass: BaseModel) -> str: - """Helper function to get an invocation or invocation output's type. This is the default value of the `type` field.""" - return klass.model_fields["type"].default - - def InputField( # copied from pydantic's Field # TODO: Can we support default_factory? @@ -483,29 +478,39 @@ class BaseInvocationOutput(BaseModel): @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: + """Registers an invocation output.""" cls._output_classes.add(output) @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: + """Gets all invocation outputs.""" return cls._output_classes @classmethod def get_outputs_union(cls) -> UnionType: + """Gets a union of all invocation outputs.""" outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] return outputs_union # type: ignore [return-value] @classmethod def get_output_types(cls) -> Iterable[str]: - return (get_type(i) for i in BaseInvocationOutput.get_outputs()) + """Gets all invocation output types.""" + return (i.get_type() for i in BaseInvocationOutput.get_outputs()) @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] schema["required"].extend(["type"]) + @classmethod + def get_type(cls) -> str: + """Gets the invocation output's type, as provided by the `@invocation_output` decorator.""" + return cls.model_fields["type"].default + model_config = ConfigDict( protected_namespaces=(), validate_assignment=True, @@ -535,21 +540,29 @@ class BaseInvocation(ABC, BaseModel): _invocation_classes: ClassVar[set[BaseInvocation]] = set() + @classmethod + def get_type(cls) -> str: + """Gets the invocation's type, as provided by the `@invocation` decorator.""" + return cls.model_fields["type"].default + @classmethod def register_invocation(cls, invocation: BaseInvocation) -> None: + """Registers an invocation.""" cls._invocation_classes.add(invocation) @classmethod def get_invocations_union(cls) -> UnionType: + """Gets a union of all invocation types.""" invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] return invocations_union # type: ignore [return-value] @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: + """Gets all invocations, respecting the allowlist and denylist.""" app_config = InvokeAIAppConfig.get_config() allowed_invocations: set[BaseInvocation] = set() for sc in cls._invocation_classes: - invocation_type = get_type(sc) + invocation_type = sc.get_type() is_in_allowlist = ( invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True ) @@ -562,20 +575,22 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls) -> dict[str, BaseInvocation]: - # Get the type strings out of the literals and into a dictionary - return {get_type(i): i for i in BaseInvocation.get_invocations()} + """Gets a map of all invocation types to their invocation classes.""" + return {i.get_type(): i for i in BaseInvocation.get_invocations()} @classmethod def get_invocation_types(cls) -> Iterable[str]: - return (get_type(i) for i in BaseInvocation.get_invocations()) + """Gets all invocation types.""" + return (i.get_type() for i in BaseInvocation.get_invocations()) @classmethod - def get_output_type(cls) -> BaseInvocationOutput: + def get_output_annotation(cls) -> BaseInvocationOutput: + """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method).""" return signature(cls.invoke).return_annotation @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: - # Add the various UI-facing attributes to the schema. These are used to build the invocation templates. + """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" uiconfig = getattr(model_class, "UIConfig", None) if uiconfig and hasattr(uiconfig, "title"): schema["title"] = uiconfig.title @@ -595,6 +610,10 @@ class BaseInvocation(ABC, BaseModel): pass def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + """ + Internal invoke method, calls `invoke()` after some prep. + Handles optional fields that are required to call `invoke()` and invocation cache. + """ for field_name, field in self.model_fields.items(): if not field.json_schema_extra or callable(field.json_schema_extra): # something has gone terribly awry, we should always have this and it should be a dict @@ -634,9 +653,6 @@ class BaseInvocation(ABC, BaseModel): context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) - def get_type(self) -> str: - return self.model_fields["type"].default - id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", @@ -693,9 +709,11 @@ RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())} def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: """ Validates the fields of an invocation or invocation output: - - must not override any pydantic reserved fields - - must not end with "Collection" or "Polymorphic" as these are reserved for internal use - - must be created via `InputField`, `OutputField`, or be an internal field defined in this file + - Must not override any pydantic reserved fields + - Must have a type annotation + - Must have a json_schema_extra dict + - Must have field_kind in json_schema_extra + - Field name must not be reserved, according to its field_kind """ for name, field in model_fields.items(): if name in RESERVED_PYDANTIC_FIELD_NAMES: diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index ee86ef17c6..0d97c0b9a1 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -49,7 +49,7 @@ class Edge(BaseModel): def get_output_field(node: BaseInvocation, field: str) -> Any: node_type = type(node) - node_outputs = get_type_hints(node_type.get_output_type()) + node_outputs = get_type_hints(node_type.get_output_annotation()) node_output_field = node_outputs.get(field) or None return node_output_field @@ -379,7 +379,7 @@ class Graph(BaseModel): raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") # output fields are not on the node object directly, they are on the output type - if edge.source.field not in source_node.get_output_type().model_fields: + if edge.source.field not in source_node.get_output_annotation().model_fields: raise NodeFieldNotFoundError( f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" )