docs(nodes): update all docstrings for public nodes API

This commit is contained in:
psychedelicious 2024-02-29 23:03:17 +11:00
parent 2f26768d19
commit 753919c6d7

View File

@ -65,75 +65,86 @@ class InvocationContextInterface:
class BoardsInterface(InvocationContextInterface): class BoardsInterface(InvocationContextInterface):
def create(self, board_name: str) -> BoardDTO: def create(self, board_name: str) -> BoardDTO:
""" """Creates a board.
Creates a board.
:param board_name: The name of the board to create. Args:
board_name: The name of the board to create.
Returns:
The created board DTO.
""" """
return self._services.boards.create(board_name) return self._services.boards.create(board_name)
def get_dto(self, board_id: str) -> BoardDTO: def get_dto(self, board_id: str) -> BoardDTO:
""" """Gets a board DTO.
Gets a board DTO.
:param board_id: The ID of the board to get. Args:
board_id: The ID of the board to get.
Returns:
The board DTO.
""" """
return self._services.boards.get_dto(board_id) return self._services.boards.get_dto(board_id)
def get_all(self) -> list[BoardDTO]: def get_all(self) -> list[BoardDTO]:
""" """Gets all boards.
Gets all boards.
Returns:
A list of all boards.
""" """
return self._services.boards.get_all() return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None: def add_image_to_board(self, board_id: str, image_name: str) -> None:
""" """Adds an image to a board.
Adds an image to a board.
:param board_id: The ID of the board to add the image to. Args:
:param image_name: The name of the image to add to the board. board_id: The ID of the board to add the image to.
image_name: The name of the image to add to the board.
""" """
return self._services.board_images.add_image_to_board(board_id, image_name) return self._services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(self, board_id: str) -> list[str]: def get_all_image_names_for_board(self, board_id: str) -> list[str]:
""" """Gets all image names for a board.
Gets all image names for a board.
:param board_id: The ID of the board to get the image names for. Args:
board_id: The ID of the board to get the image names for.
Returns:
A list of all image names for the board.
""" """
return self._services.board_images.get_all_board_image_names_for_board(board_id) return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface): class LoggerInterface(InvocationContextInterface):
def debug(self, message: str) -> None: def debug(self, message: str) -> None:
""" """Logs a debug message.
Logs a debug message.
:param message: The message to log. Args:
message: The message to log.
""" """
self._services.logger.debug(message) self._services.logger.debug(message)
def info(self, message: str) -> None: def info(self, message: str) -> None:
""" """Logs an info message.
Logs an info message.
:param message: The message to log. Args:
message: The message to log.
""" """
self._services.logger.info(message) self._services.logger.info(message)
def warning(self, message: str) -> None: def warning(self, message: str) -> None:
""" """Logs a warning message.
Logs a warning message.
:param message: The message to log. Args:
message: The message to log.
""" """
self._services.logger.warning(message) self._services.logger.warning(message)
def error(self, message: str) -> None: def error(self, message: str) -> None:
""" """Logs an error message.
Logs an error message.
:param message: The message to log. Args:
message: The message to log.
""" """
self._services.logger.error(message) self._services.logger.error(message)
@ -146,20 +157,23 @@ class ImagesInterface(InvocationContextInterface):
image_category: ImageCategory = ImageCategory.GENERAL, image_category: ImageCategory = ImageCategory.GENERAL,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
) -> ImageDTO: ) -> ImageDTO:
""" """Saves an image, returning its DTO.
Saves an image, returning its DTO.
If the current queue item has a workflow or metadata, it is automatically saved with the image. If the current queue item has a workflow or metadata, it is automatically saved with the image.
:param image: The image to save, as a PIL image. Args:
:param board_id: The board ID to add the image to, if it should be added. It the invocation \ image: The image to save, as a PIL image.
board_id: The board ID to add the image to, if it should be added. It the invocation \
inherits from `WithBoard`, that board will be used automatically. **Use this only if \ inherits from `WithBoard`, that board will be used automatically. **Use this only if \
you want to override or provide a board manually!** you want to override or provide a board manually!**
:param image_category: The category of the image. Only the GENERAL category is added \ image_category: The category of the image. Only the GENERAL category is added \
to the gallery. to the gallery.
:param metadata: The metadata to save with the image, if it should have any. If the \ metadata: The metadata to save with the image, if it should have any. If the \
invocation inherits from `WithMetadata`, that metadata will be used automatically. \ invocation inherits from `WithMetadata`, that metadata will be used automatically. \
**Use this only if you want to override or provide metadata manually!** **Use this only if you want to override or provide metadata manually!**
Returns:
The saved image DTO.
""" """
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
@ -189,11 +203,14 @@ class ImagesInterface(InvocationContextInterface):
) )
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image: def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
""" """Gets an image as a PIL Image object.
Gets an image as a PIL Image object.
:param image_name: The name of the image to get. Args:
:param mode: The color mode to convert the image to. If None, the original mode is used. image_name: The name of the image to get.
mode: The color mode to convert the image to. If None, the original mode is used.
Returns:
The image as a PIL Image object.
""" """
image = self._services.images.get_pil_image(image_name) image = self._services.images.get_pil_image(image_name)
if mode and mode != image.mode: if mode and mode != image.mode:
@ -206,58 +223,76 @@ class ImagesInterface(InvocationContextInterface):
return image return image
def get_metadata(self, image_name: str) -> Optional[MetadataField]: def get_metadata(self, image_name: str) -> Optional[MetadataField]:
""" """Gets an image's metadata, if it has any.
Gets an image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for. Args:
image_name: The name of the image to get the metadata for.
Returns:
The image's metadata, if it has any.
""" """
return self._services.images.get_metadata(image_name) return self._services.images.get_metadata(image_name)
def get_dto(self, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
""" """Gets an image as an ImageDTO object.
Gets an image as an ImageDTO object.
:param image_name: The name of the image to get. Args:
image_name: The name of the image to get.
Returns:
The image as an ImageDTO object.
""" """
return self._services.images.get_dto(image_name) return self._services.images.get_dto(image_name)
class TensorsInterface(InvocationContextInterface): class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str: def save(self, tensor: Tensor) -> str:
""" """Saves a tensor, returning its name.
Saves a tensor, returning its name.
:param tensor: The tensor to save. Args:
tensor: The tensor to save.
Returns:
The name of the saved tensor.
""" """
name = self._services.tensors.save(obj=tensor) name = self._services.tensors.save(obj=tensor)
return name return name
def load(self, name: str) -> Tensor: def load(self, name: str) -> Tensor:
""" """Loads a tensor by name.
Loads a tensor by name.
:param name: The name of the tensor to load. Args:
name: The name of the tensor to load.
Returns:
The loaded tensor.
""" """
return self._services.tensors.load(name) return self._services.tensors.load(name)
class ConditioningInterface(InvocationContextInterface): class ConditioningInterface(InvocationContextInterface):
def save(self, conditioning_data: ConditioningFieldData) -> str: def save(self, conditioning_data: ConditioningFieldData) -> str:
""" """Saves a conditioning data object, returning its name.
Saves a conditioning data object, returning its name.
:param conditioning_data: The conditioning data to save. Args:
conditioning_data: The conditioning data to save.
Returns:
The name of the saved conditioning data.
""" """
name = self._services.conditioning.save(obj=conditioning_data) name = self._services.conditioning.save(obj=conditioning_data)
return name return name
def load(self, name: str) -> ConditioningFieldData: def load(self, name: str) -> ConditioningFieldData:
""" """Loads conditioning data by name.
Loads conditioning data by name.
:param name: The name of the conditioning data to load. Args:
name: The name of the conditioning data to load.
Returns:
The loaded conditioning data.
""" """
return self._services.conditioning.load(name) return self._services.conditioning.load(name)
@ -265,20 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface):
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
""" """Checks if a model exists.
Checks if a model exists.
:param key: The key of the model. Args:
key: The key of the model.
Returns:
True if the model exists, False if not.
""" """
return self._services.model_manager.store.exists(key) return self._services.model_manager.store.exists(key)
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """Loads a model.
Loads a model.
:param key: The key of the model. Args:
:param submodel_type: The submodel of the model to get. key: The key of the model.
:returns: An object representing the loaded model. submodel_type: The submodel of the model to get.
Returns:
An object representing the loaded model.
""" """
# The model manager emits events as it loads the model. It needs the context data to build # The model manager emits events as it loads the model. It needs the context data to build
@ -291,13 +331,17 @@ class ModelsInterface(InvocationContextInterface):
def load_by_attrs( def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel: ) -> LoadedModel:
""" """Loads a model by its attributes.
Loads a model by its attributes.
:param model_name: Name of to be fetched. Args:
:param base_model: Base model name: Name of the model.
:param model_type: Type of the model base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
:param submodel: For main (pipeline models), the submodel to fetch type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
models have submodels.
Returns:
An object representing the loaded model.
""" """
return self._services.model_manager.load_model_by_attr( return self._services.model_manager.load_model_by_attr(
model_name=name, model_name=name,
@ -308,26 +352,35 @@ class ModelsInterface(InvocationContextInterface):
) )
def get_config(self, key: str) -> AnyModelConfig: def get_config(self, key: str) -> AnyModelConfig:
""" """Gets a model's config.
Gets a model's info, an dict-like object.
:param key: The key of the model. Args:
key: The key of the model.
Returns:
The model's config.
""" """
return self._services.model_manager.store.get_model(key=key) return self._services.model_manager.store.get_model(key=key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
""" """Gets a model's metadata, if it has any.
Gets a model's metadata, if it has any.
:param key: The key of the model. Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
""" """
return self._services.model_manager.store.get_metadata(key=key) return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]: def search_by_path(self, path: Path) -> list[AnyModelConfig]:
""" """Searches for models by path.
Searches for models by path.
:param path: The path to search for. Args:
path: The path to search for.
Returns:
A list of models that match the path.
""" """
return self._services.model_manager.store.search_by_path(path) return self._services.model_manager.store.search_by_path(path)
@ -338,13 +391,16 @@ class ModelsInterface(InvocationContextInterface):
type: Optional[ModelType] = None, type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None, format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]: ) -> list[AnyModelConfig]:
""" """Searches for models by attributes.
Searches for models by attributes.
:param model_name: Name of to be fetched. Args:
:param base_model: Base model name: The name to search for (exact match).
:param model_type: Type of the model base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
:param submodel: For main (pipeline models), the submodel to fetch type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
Returns:
A list of models that match the attributes.
""" """
return self._services.model_manager.store.search_by_attr( return self._services.model_manager.store.search_by_attr(
@ -357,7 +413,11 @@ class ModelsInterface(InvocationContextInterface):
class ConfigInterface(InvocationContextInterface): class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig: def get(self) -> InvokeAIAppConfig:
"""Gets the app's config.""" """Gets the app's config.
Returns:
The app's config.
"""
return self._services.configuration.get_config() return self._services.configuration.get_config()
@ -370,7 +430,11 @@ class UtilInterface(InvocationContextInterface):
self._cancel_event = cancel_event self._cancel_event = cancel_event
def is_canceled(self) -> bool: def is_canceled(self) -> bool:
"""Checks if the current invocation has been canceled.""" """Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._cancel_event.is_set() return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
@ -380,8 +444,9 @@ class UtilInterface(InvocationContextInterface):
This should be called after each denoising step. This should be called after each denoising step.
:param intermediate_state: The intermediate state of the diffusion pipeline. Args:
:param base_model: The base model for the current denoising step. intermediate_state: The intermediate state of the diffusion pipeline.
base_model: The base model for the current denoising step.
""" """
stable_diffusion_step_callback( stable_diffusion_step_callback(
@ -394,8 +459,17 @@ class UtilInterface(InvocationContextInterface):
class InvocationContext: class InvocationContext:
""" """Provides access to various services and data for the current invocation.
The `InvocationContext` provides access to various services and data for the current invocation.
Attributes:
images (ImagesInterface): Methods to save, get and update images and their metadata.
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
conditioning (ConditioningInterface): Methods to save and get conditioning data.
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
logger (LoggerInterface): The app logger.
config (ConfigInterface): The app config.
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
boards (BoardsInterface): Methods to interact with boards.
""" """
def __init__( def __init__(
@ -438,11 +512,14 @@ def build_invocation_context(
data: InvocationContextData, data: InvocationContextData,
cancel_event: threading.Event, cancel_event: threading.Event,
) -> InvocationContext: ) -> InvocationContext:
""" """Builds the invocation context for a specific invocation execution.
Builds the invocation context for a specific invocation execution.
:param services: The invocation services to wrap. Args:
:param data: The invocation context data. services: The invocation services to wrap.
data: The invocation context data.
Returns:
The invocation context.
""" """
logger = LoggerInterface(services=services, data=data) logger = LoggerInterface(services=services, data=data)