mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
docs(nodes): update all docstrings for public nodes API
This commit is contained in:
parent
84dc5c5c7b
commit
68344ecac9
@ -65,75 +65,86 @@ class InvocationContextInterface:
|
||||
|
||||
class BoardsInterface(InvocationContextInterface):
|
||||
def create(self, board_name: str) -> BoardDTO:
|
||||
"""
|
||||
Creates a board.
|
||||
"""Creates a board.
|
||||
|
||||
:param board_name: The name of the board to create.
|
||||
Args:
|
||||
board_name: The name of the board to create.
|
||||
|
||||
Returns:
|
||||
The created board DTO.
|
||||
"""
|
||||
return self._services.boards.create(board_name)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
"""
|
||||
Gets a board DTO.
|
||||
"""Gets a board DTO.
|
||||
|
||||
:param board_id: The ID of the board to get.
|
||||
Args:
|
||||
board_id: The ID of the board to get.
|
||||
|
||||
Returns:
|
||||
The board DTO.
|
||||
"""
|
||||
return self._services.boards.get_dto(board_id)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
"""
|
||||
Gets all boards.
|
||||
"""Gets all boards.
|
||||
|
||||
Returns:
|
||||
A list of all boards.
|
||||
"""
|
||||
return self._services.boards.get_all()
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
"""
|
||||
Adds an image to a board.
|
||||
"""Adds an image to a board.
|
||||
|
||||
:param board_id: The ID of the board to add the image to.
|
||||
:param image_name: The name of the image to add to the board.
|
||||
Args:
|
||||
board_id: The ID of the board to add the image to.
|
||||
image_name: The name of the image to add to the board.
|
||||
"""
|
||||
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||
|
||||
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
"""
|
||||
Gets all image names for a board.
|
||||
"""Gets all image names for a board.
|
||||
|
||||
:param board_id: The ID of the board to get the image names for.
|
||||
Args:
|
||||
board_id: The ID of the board to get the image names for.
|
||||
|
||||
Returns:
|
||||
A list of all image names for the board.
|
||||
"""
|
||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
|
||||
class LoggerInterface(InvocationContextInterface):
|
||||
def debug(self, message: str) -> None:
|
||||
"""
|
||||
Logs a debug message.
|
||||
"""Logs a debug message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""
|
||||
Logs an info message.
|
||||
"""Logs an info message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""
|
||||
Logs a warning message.
|
||||
"""Logs a warning message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.warning(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""
|
||||
Logs an error message.
|
||||
"""Logs an error message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.error(message)
|
||||
|
||||
@ -146,20 +157,23 @@ class ImagesInterface(InvocationContextInterface):
|
||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> ImageDTO:
|
||||
"""
|
||||
Saves an image, returning its DTO.
|
||||
"""Saves an image, returning its DTO.
|
||||
|
||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||
|
||||
:param image: The image to save, as a PIL image.
|
||||
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
Args:
|
||||
image: The image to save, as a PIL image.
|
||||
board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||
you want to override or provide a board manually!**
|
||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||
image_category: The category of the image. Only the GENERAL category is added \
|
||||
to the gallery.
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||
metadata: The metadata to save with the image, if it should have any. If the \
|
||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||
**Use this only if you want to override or provide metadata manually!**
|
||||
|
||||
Returns:
|
||||
The saved image DTO.
|
||||
"""
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
@ -189,11 +203,14 @@ class ImagesInterface(InvocationContextInterface):
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||
"""
|
||||
Gets an image as a PIL Image object.
|
||||
"""Gets an image as a PIL Image object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
:param mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
|
||||
Returns:
|
||||
The image as a PIL Image object.
|
||||
"""
|
||||
image = self._services.images.get_pil_image(image_name)
|
||||
if mode and mode != image.mode:
|
||||
@ -206,58 +223,76 @@ class ImagesInterface(InvocationContextInterface):
|
||||
return image
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""
|
||||
Gets an image's metadata, if it has any.
|
||||
"""Gets an image's metadata, if it has any.
|
||||
|
||||
:param image_name: The name of the image to get the metadata for.
|
||||
Args:
|
||||
image_name: The name of the image to get the metadata for.
|
||||
|
||||
Returns:
|
||||
The image's metadata, if it has any.
|
||||
"""
|
||||
return self._services.images.get_metadata(image_name)
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""
|
||||
Gets an image as an ImageDTO object.
|
||||
"""Gets an image as an ImageDTO object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
|
||||
Returns:
|
||||
The image as an ImageDTO object.
|
||||
"""
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""
|
||||
Saves a tensor, returning its name.
|
||||
"""Saves a tensor, returning its name.
|
||||
|
||||
:param tensor: The tensor to save.
|
||||
Args:
|
||||
tensor: The tensor to save.
|
||||
|
||||
Returns:
|
||||
The name of the saved tensor.
|
||||
"""
|
||||
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""
|
||||
Loads a tensor by name.
|
||||
"""Loads a tensor by name.
|
||||
|
||||
:param name: The name of the tensor to load.
|
||||
Args:
|
||||
name: The name of the tensor to load.
|
||||
|
||||
Returns:
|
||||
The loaded tensor.
|
||||
"""
|
||||
return self._services.tensors.load(name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
"""Saves a conditioning data object, returning its name.
|
||||
|
||||
:param conditioning_data: The conditioning data to save.
|
||||
Args:
|
||||
conditioning_data: The conditioning data to save.
|
||||
|
||||
Returns:
|
||||
The name of the saved conditioning data.
|
||||
"""
|
||||
|
||||
name = self._services.conditioning.save(obj=conditioning_data)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""
|
||||
Loads conditioning data by name.
|
||||
"""Loads conditioning data by name.
|
||||
|
||||
:param name: The name of the conditioning data to load.
|
||||
Args:
|
||||
name: The name of the conditioning data to load.
|
||||
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
@ -265,20 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Checks if a model exists.
|
||||
"""Checks if a model exists.
|
||||
|
||||
:param key: The key of the model.
|
||||
Args:
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
True if the model exists, False if not.
|
||||
"""
|
||||
return self._services.model_manager.store.exists(key)
|
||||
|
||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Loads a model.
|
||||
"""Loads a model.
|
||||
|
||||
:param key: The key of the model.
|
||||
:param submodel_type: The submodel of the model to get.
|
||||
:returns: An object representing the loaded model.
|
||||
Args:
|
||||
key: The key of the model.
|
||||
submodel_type: The submodel of the model to get.
|
||||
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
@ -291,13 +331,17 @@ class ModelsInterface(InvocationContextInterface):
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Loads a model by its attributes.
|
||||
"""Loads a model by its attributes.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
Args:
|
||||
name: Name of the model.
|
||||
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
|
||||
models have submodels.
|
||||
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
return self._services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
@ -308,26 +352,35 @@ class ModelsInterface(InvocationContextInterface):
|
||||
)
|
||||
|
||||
def get_config(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Gets a model's info, an dict-like object.
|
||||
"""Gets a model's config.
|
||||
|
||||
:param key: The key of the model.
|
||||
Args:
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
The model's config.
|
||||
"""
|
||||
return self._services.model_manager.store.get_model(key=key)
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Gets a model's metadata, if it has any.
|
||||
"""Gets a model's metadata, if it has any.
|
||||
|
||||
:param key: The key of the model.
|
||||
Args:
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
The model's metadata, if it has any.
|
||||
"""
|
||||
return self._services.model_manager.store.get_metadata(key=key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""
|
||||
Searches for models by path.
|
||||
"""Searches for models by path.
|
||||
|
||||
:param path: The path to search for.
|
||||
Args:
|
||||
path: The path to search for.
|
||||
|
||||
Returns:
|
||||
A list of models that match the path.
|
||||
"""
|
||||
return self._services.model_manager.store.search_by_path(path)
|
||||
|
||||
@ -338,13 +391,16 @@ class ModelsInterface(InvocationContextInterface):
|
||||
type: Optional[ModelType] = None,
|
||||
format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""
|
||||
Searches for models by attributes.
|
||||
"""Searches for models by attributes.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
Args:
|
||||
name: The name to search for (exact match).
|
||||
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
|
||||
|
||||
Returns:
|
||||
A list of models that match the attributes.
|
||||
"""
|
||||
|
||||
return self._services.model_manager.store.search_by_attr(
|
||||
@ -357,7 +413,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
"""Gets the app's config."""
|
||||
"""Gets the app's config.
|
||||
|
||||
Returns:
|
||||
The app's config.
|
||||
"""
|
||||
|
||||
return self._services.configuration.get_config()
|
||||
|
||||
@ -370,7 +430,11 @@ class UtilInterface(InvocationContextInterface):
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current invocation has been canceled."""
|
||||
"""Checks if the current session has been canceled.
|
||||
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
@ -380,8 +444,9 @@ class UtilInterface(InvocationContextInterface):
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
:param base_model: The base model for the current denoising step.
|
||||
Args:
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
base_model: The base model for the current denoising step.
|
||||
"""
|
||||
|
||||
stable_diffusion_step_callback(
|
||||
@ -394,8 +459,17 @@ class UtilInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""
|
||||
The `InvocationContext` provides access to various services and data for the current invocation.
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
Attributes:
|
||||
images (ImagesInterface): Methods to save, get and update images and their metadata.
|
||||
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
|
||||
conditioning (ConditioningInterface): Methods to save and get conditioning data.
|
||||
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
|
||||
logger (LoggerInterface): The app logger.
|
||||
config (ConfigInterface): The app config.
|
||||
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
|
||||
boards (BoardsInterface): Methods to interact with boards.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -438,11 +512,14 @@ def build_invocation_context(
|
||||
data: InvocationContextData,
|
||||
cancel_event: threading.Event,
|
||||
) -> InvocationContext:
|
||||
"""
|
||||
Builds the invocation context for a specific invocation execution.
|
||||
"""Builds the invocation context for a specific invocation execution.
|
||||
|
||||
:param services: The invocation services to wrap.
|
||||
:param data: The invocation context data.
|
||||
Args:
|
||||
services: The invocation services to wrap.
|
||||
data: The invocation context data.
|
||||
|
||||
Returns:
|
||||
The invocation context.
|
||||
"""
|
||||
|
||||
logger = LoggerInterface(services=services, data=data)
|
||||
|
Loading…
Reference in New Issue
Block a user