mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): improve docstrings in baseinvocation, disambiguate method names
This commit is contained in:
parent
ed79980dd4
commit
858bcdd3ff
@ -93,6 +93,10 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||||
should not be used by node authors.
|
should not be used by node authors.
|
||||||
|
|
||||||
|
- DEPRECATED Fields
|
||||||
|
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||||
|
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# region Model Field Types
|
# region Model Field Types
|
||||||
@ -173,10 +177,8 @@ class UIComponent(str, Enum, metaclass=MetaEnum):
|
|||||||
|
|
||||||
class InputFieldJSONSchemaExtra(BaseModel):
|
class InputFieldJSONSchemaExtra(BaseModel):
|
||||||
"""
|
"""
|
||||||
*DO NOT USE*
|
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
and by the workflow editor during schema parsing and UI rendering.
|
||||||
schema generation, and Typescript type generation from that schema. It serves no functional
|
|
||||||
purpose in the backend.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input: Input
|
input: Input
|
||||||
@ -198,10 +200,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
|
|||||||
|
|
||||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||||
"""
|
"""
|
||||||
*DO NOT USE*
|
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
||||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
during schema parsing and UI rendering.
|
||||||
schema generation, and Typescript type generation from that schema. It serves no functional
|
|
||||||
purpose in the backend.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
field_kind: FieldKind
|
field_kind: FieldKind
|
||||||
@ -215,11 +215,6 @@ class OutputFieldJSONSchemaExtra(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_type(klass: BaseModel) -> str:
|
|
||||||
"""Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
|
|
||||||
return klass.model_fields["type"].default
|
|
||||||
|
|
||||||
|
|
||||||
def InputField(
|
def InputField(
|
||||||
# copied from pydantic's Field
|
# copied from pydantic's Field
|
||||||
# TODO: Can we support default_factory?
|
# TODO: Can we support default_factory?
|
||||||
@ -483,29 +478,39 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||||
|
"""Registers an invocation output."""
|
||||||
cls._output_classes.add(output)
|
cls._output_classes.add(output)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||||
|
"""Gets all invocation outputs."""
|
||||||
return cls._output_classes
|
return cls._output_classes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_outputs_union(cls) -> UnionType:
|
def get_outputs_union(cls) -> UnionType:
|
||||||
|
"""Gets a union of all invocation outputs."""
|
||||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||||
return outputs_union # type: ignore [return-value]
|
return outputs_union # type: ignore [return-value]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_types(cls) -> Iterable[str]:
|
def get_output_types(cls) -> Iterable[str]:
|
||||||
return (get_type(i) for i in BaseInvocationOutput.get_outputs())
|
"""Gets all invocation output types."""
|
||||||
|
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = []
|
schema["required"] = []
|
||||||
schema["required"].extend(["type"])
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_type(cls) -> str:
|
||||||
|
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
|
||||||
|
return cls.model_fields["type"].default
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
protected_namespaces=(),
|
protected_namespaces=(),
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
@ -535,21 +540,29 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
|
|
||||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_type(cls) -> str:
|
||||||
|
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
|
||||||
|
return cls.model_fields["type"].default
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||||
|
"""Registers an invocation."""
|
||||||
cls._invocation_classes.add(invocation)
|
cls._invocation_classes.add(invocation)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations_union(cls) -> UnionType:
|
def get_invocations_union(cls) -> UnionType:
|
||||||
|
"""Gets a union of all invocation types."""
|
||||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||||
return invocations_union # type: ignore [return-value]
|
return invocations_union # type: ignore [return-value]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||||
|
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
allowed_invocations: set[BaseInvocation] = set()
|
allowed_invocations: set[BaseInvocation] = set()
|
||||||
for sc in cls._invocation_classes:
|
for sc in cls._invocation_classes:
|
||||||
invocation_type = get_type(sc)
|
invocation_type = sc.get_type()
|
||||||
is_in_allowlist = (
|
is_in_allowlist = (
|
||||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||||
)
|
)
|
||||||
@ -562,20 +575,22 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||||
# Get the type strings out of the literals and into a dictionary
|
"""Gets a map of all invocation types to their invocation classes."""
|
||||||
return {get_type(i): i for i in BaseInvocation.get_invocations()}
|
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocation_types(cls) -> Iterable[str]:
|
def get_invocation_types(cls) -> Iterable[str]:
|
||||||
return (get_type(i) for i in BaseInvocation.get_invocations())
|
"""Gets all invocation types."""
|
||||||
|
return (i.get_type() for i in BaseInvocation.get_invocations())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_type(cls) -> BaseInvocationOutput:
|
def get_output_annotation(cls) -> BaseInvocationOutput:
|
||||||
|
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
|
||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||||
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
|
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||||
uiconfig = getattr(model_class, "UIConfig", None)
|
uiconfig = getattr(model_class, "UIConfig", None)
|
||||||
if uiconfig and hasattr(uiconfig, "title"):
|
if uiconfig and hasattr(uiconfig, "title"):
|
||||||
schema["title"] = uiconfig.title
|
schema["title"] = uiconfig.title
|
||||||
@ -595,6 +610,10 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
|
"""
|
||||||
|
Internal invoke method, calls `invoke()` after some prep.
|
||||||
|
Handles optional fields that are required to call `invoke()` and invocation cache.
|
||||||
|
"""
|
||||||
for field_name, field in self.model_fields.items():
|
for field_name, field in self.model_fields.items():
|
||||||
if not field.json_schema_extra or callable(field.json_schema_extra):
|
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||||
# something has gone terribly awry, we should always have this and it should be a dict
|
# something has gone terribly awry, we should always have this and it should be a dict
|
||||||
@ -634,9 +653,6 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||||
return self.invoke(context)
|
return self.invoke(context)
|
||||||
|
|
||||||
def get_type(self) -> str:
|
|
||||||
return self.model_fields["type"].default
|
|
||||||
|
|
||||||
id: str = Field(
|
id: str = Field(
|
||||||
default_factory=uuid_string,
|
default_factory=uuid_string,
|
||||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||||
@ -693,9 +709,11 @@ RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
|||||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||||
"""
|
"""
|
||||||
Validates the fields of an invocation or invocation output:
|
Validates the fields of an invocation or invocation output:
|
||||||
- must not override any pydantic reserved fields
|
- Must not override any pydantic reserved fields
|
||||||
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use
|
- Must have a type annotation
|
||||||
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
- Must have a json_schema_extra dict
|
||||||
|
- Must have field_kind in json_schema_extra
|
||||||
|
- Field name must not be reserved, according to its field_kind
|
||||||
"""
|
"""
|
||||||
for name, field in model_fields.items():
|
for name, field in model_fields.items():
|
||||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||||
|
@ -49,7 +49,7 @@ class Edge(BaseModel):
|
|||||||
|
|
||||||
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
||||||
node_type = type(node)
|
node_type = type(node)
|
||||||
node_outputs = get_type_hints(node_type.get_output_type())
|
node_outputs = get_type_hints(node_type.get_output_annotation())
|
||||||
node_output_field = node_outputs.get(field) or None
|
node_output_field = node_outputs.get(field) or None
|
||||||
return node_output_field
|
return node_output_field
|
||||||
|
|
||||||
@ -379,7 +379,7 @@ class Graph(BaseModel):
|
|||||||
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
||||||
|
|
||||||
# output fields are not on the node object directly, they are on the output type
|
# output fields are not on the node object directly, they are on the output type
|
||||||
if edge.source.field not in source_node.get_output_type().model_fields:
|
if edge.source.field not in source_node.get_output_annotation().model_fields:
|
||||||
raise NodeFieldNotFoundError(
|
raise NodeFieldNotFoundError(
|
||||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user