mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
docs(nodes): update all docstrings for public nodes API
This commit is contained in:
parent
84dc5c5c7b
commit
68344ecac9
@ -65,75 +65,86 @@ class InvocationContextInterface:
|
|||||||
|
|
||||||
class BoardsInterface(InvocationContextInterface):
|
class BoardsInterface(InvocationContextInterface):
|
||||||
def create(self, board_name: str) -> BoardDTO:
|
def create(self, board_name: str) -> BoardDTO:
|
||||||
"""
|
"""Creates a board.
|
||||||
Creates a board.
|
|
||||||
|
|
||||||
:param board_name: The name of the board to create.
|
Args:
|
||||||
|
board_name: The name of the board to create.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created board DTO.
|
||||||
"""
|
"""
|
||||||
return self._services.boards.create(board_name)
|
return self._services.boards.create(board_name)
|
||||||
|
|
||||||
def get_dto(self, board_id: str) -> BoardDTO:
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
"""
|
"""Gets a board DTO.
|
||||||
Gets a board DTO.
|
|
||||||
|
|
||||||
:param board_id: The ID of the board to get.
|
Args:
|
||||||
|
board_id: The ID of the board to get.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The board DTO.
|
||||||
"""
|
"""
|
||||||
return self._services.boards.get_dto(board_id)
|
return self._services.boards.get_dto(board_id)
|
||||||
|
|
||||||
def get_all(self) -> list[BoardDTO]:
|
def get_all(self) -> list[BoardDTO]:
|
||||||
"""
|
"""Gets all boards.
|
||||||
Gets all boards.
|
|
||||||
|
Returns:
|
||||||
|
A list of all boards.
|
||||||
"""
|
"""
|
||||||
return self._services.boards.get_all()
|
return self._services.boards.get_all()
|
||||||
|
|
||||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||||
"""
|
"""Adds an image to a board.
|
||||||
Adds an image to a board.
|
|
||||||
|
|
||||||
:param board_id: The ID of the board to add the image to.
|
Args:
|
||||||
:param image_name: The name of the image to add to the board.
|
board_id: The ID of the board to add the image to.
|
||||||
|
image_name: The name of the image to add to the board.
|
||||||
"""
|
"""
|
||||||
return self._services.board_images.add_image_to_board(board_id, image_name)
|
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||||
"""
|
"""Gets all image names for a board.
|
||||||
Gets all image names for a board.
|
|
||||||
|
|
||||||
:param board_id: The ID of the board to get the image names for.
|
Args:
|
||||||
|
board_id: The ID of the board to get the image names for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all image names for the board.
|
||||||
"""
|
"""
|
||||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||||
|
|
||||||
|
|
||||||
class LoggerInterface(InvocationContextInterface):
|
class LoggerInterface(InvocationContextInterface):
|
||||||
def debug(self, message: str) -> None:
|
def debug(self, message: str) -> None:
|
||||||
"""
|
"""Logs a debug message.
|
||||||
Logs a debug message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
Args:
|
||||||
|
message: The message to log.
|
||||||
"""
|
"""
|
||||||
self._services.logger.debug(message)
|
self._services.logger.debug(message)
|
||||||
|
|
||||||
def info(self, message: str) -> None:
|
def info(self, message: str) -> None:
|
||||||
"""
|
"""Logs an info message.
|
||||||
Logs an info message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
Args:
|
||||||
|
message: The message to log.
|
||||||
"""
|
"""
|
||||||
self._services.logger.info(message)
|
self._services.logger.info(message)
|
||||||
|
|
||||||
def warning(self, message: str) -> None:
|
def warning(self, message: str) -> None:
|
||||||
"""
|
"""Logs a warning message.
|
||||||
Logs a warning message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
Args:
|
||||||
|
message: The message to log.
|
||||||
"""
|
"""
|
||||||
self._services.logger.warning(message)
|
self._services.logger.warning(message)
|
||||||
|
|
||||||
def error(self, message: str) -> None:
|
def error(self, message: str) -> None:
|
||||||
"""
|
"""Logs an error message.
|
||||||
Logs an error message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
Args:
|
||||||
|
message: The message to log.
|
||||||
"""
|
"""
|
||||||
self._services.logger.error(message)
|
self._services.logger.error(message)
|
||||||
|
|
||||||
@ -146,20 +157,23 @@ class ImagesInterface(InvocationContextInterface):
|
|||||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||||
metadata: Optional[MetadataField] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""
|
"""Saves an image, returning its DTO.
|
||||||
Saves an image, returning its DTO.
|
|
||||||
|
|
||||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||||
|
|
||||||
:param image: The image to save, as a PIL image.
|
Args:
|
||||||
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
image: The image to save, as a PIL image.
|
||||||
|
board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||||
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||||
you want to override or provide a board manually!**
|
you want to override or provide a board manually!**
|
||||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
image_category: The category of the image. Only the GENERAL category is added \
|
||||||
to the gallery.
|
to the gallery.
|
||||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
metadata: The metadata to save with the image, if it should have any. If the \
|
||||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||||
**Use this only if you want to override or provide metadata manually!**
|
**Use this only if you want to override or provide metadata manually!**
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The saved image DTO.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||||
@ -189,11 +203,14 @@ class ImagesInterface(InvocationContextInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||||
"""
|
"""Gets an image as a PIL Image object.
|
||||||
Gets an image as a PIL Image object.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to get.
|
Args:
|
||||||
:param mode: The color mode to convert the image to. If None, the original mode is used.
|
image_name: The name of the image to get.
|
||||||
|
mode: The color mode to convert the image to. If None, the original mode is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The image as a PIL Image object.
|
||||||
"""
|
"""
|
||||||
image = self._services.images.get_pil_image(image_name)
|
image = self._services.images.get_pil_image(image_name)
|
||||||
if mode and mode != image.mode:
|
if mode and mode != image.mode:
|
||||||
@ -206,58 +223,76 @@ class ImagesInterface(InvocationContextInterface):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
"""
|
"""Gets an image's metadata, if it has any.
|
||||||
Gets an image's metadata, if it has any.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to get the metadata for.
|
Args:
|
||||||
|
image_name: The name of the image to get the metadata for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The image's metadata, if it has any.
|
||||||
"""
|
"""
|
||||||
return self._services.images.get_metadata(image_name)
|
return self._services.images.get_metadata(image_name)
|
||||||
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
"""
|
"""Gets an image as an ImageDTO object.
|
||||||
Gets an image as an ImageDTO object.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to get.
|
Args:
|
||||||
|
image_name: The name of the image to get.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The image as an ImageDTO object.
|
||||||
"""
|
"""
|
||||||
return self._services.images.get_dto(image_name)
|
return self._services.images.get_dto(image_name)
|
||||||
|
|
||||||
|
|
||||||
class TensorsInterface(InvocationContextInterface):
|
class TensorsInterface(InvocationContextInterface):
|
||||||
def save(self, tensor: Tensor) -> str:
|
def save(self, tensor: Tensor) -> str:
|
||||||
"""
|
"""Saves a tensor, returning its name.
|
||||||
Saves a tensor, returning its name.
|
|
||||||
|
|
||||||
:param tensor: The tensor to save.
|
Args:
|
||||||
|
tensor: The tensor to save.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The name of the saved tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = self._services.tensors.save(obj=tensor)
|
name = self._services.tensors.save(obj=tensor)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def load(self, name: str) -> Tensor:
|
def load(self, name: str) -> Tensor:
|
||||||
"""
|
"""Loads a tensor by name.
|
||||||
Loads a tensor by name.
|
|
||||||
|
|
||||||
:param name: The name of the tensor to load.
|
Args:
|
||||||
|
name: The name of the tensor to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loaded tensor.
|
||||||
"""
|
"""
|
||||||
return self._services.tensors.load(name)
|
return self._services.tensors.load(name)
|
||||||
|
|
||||||
|
|
||||||
class ConditioningInterface(InvocationContextInterface):
|
class ConditioningInterface(InvocationContextInterface):
|
||||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||||
"""
|
"""Saves a conditioning data object, returning its name.
|
||||||
Saves a conditioning data object, returning its name.
|
|
||||||
|
|
||||||
:param conditioning_data: The conditioning data to save.
|
Args:
|
||||||
|
conditioning_data: The conditioning data to save.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The name of the saved conditioning data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = self._services.conditioning.save(obj=conditioning_data)
|
name = self._services.conditioning.save(obj=conditioning_data)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def load(self, name: str) -> ConditioningFieldData:
|
def load(self, name: str) -> ConditioningFieldData:
|
||||||
"""
|
"""Loads conditioning data by name.
|
||||||
Loads conditioning data by name.
|
|
||||||
|
|
||||||
:param name: The name of the conditioning data to load.
|
Args:
|
||||||
|
name: The name of the conditioning data to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loaded conditioning data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._services.conditioning.load(name)
|
return self._services.conditioning.load(name)
|
||||||
@ -265,20 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
class ModelsInterface(InvocationContextInterface):
|
class ModelsInterface(InvocationContextInterface):
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
"""
|
"""Checks if a model exists.
|
||||||
Checks if a model exists.
|
|
||||||
|
|
||||||
:param key: The key of the model.
|
Args:
|
||||||
|
key: The key of the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model exists, False if not.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.store.exists(key)
|
return self._services.model_manager.store.exists(key)
|
||||||
|
|
||||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""Loads a model.
|
||||||
Loads a model.
|
|
||||||
|
|
||||||
:param key: The key of the model.
|
Args:
|
||||||
:param submodel_type: The submodel of the model to get.
|
key: The key of the model.
|
||||||
:returns: An object representing the loaded model.
|
submodel_type: The submodel of the model to get.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An object representing the loaded model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# The model manager emits events as it loads the model. It needs the context data to build
|
# The model manager emits events as it loads the model. It needs the context data to build
|
||||||
@ -291,13 +331,17 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
def load_by_attrs(
|
def load_by_attrs(
|
||||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||||
) -> LoadedModel:
|
) -> LoadedModel:
|
||||||
"""
|
"""Loads a model by its attributes.
|
||||||
Loads a model by its attributes.
|
|
||||||
|
|
||||||
:param model_name: Name of to be fetched.
|
Args:
|
||||||
:param base_model: Base model
|
name: Name of the model.
|
||||||
:param model_type: Type of the model
|
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch
|
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||||
|
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
|
||||||
|
models have submodels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An object representing the loaded model.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.load_model_by_attr(
|
return self._services.model_manager.load_model_by_attr(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
@ -308,26 +352,35 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_config(self, key: str) -> AnyModelConfig:
|
def get_config(self, key: str) -> AnyModelConfig:
|
||||||
"""
|
"""Gets a model's config.
|
||||||
Gets a model's info, an dict-like object.
|
|
||||||
|
|
||||||
:param key: The key of the model.
|
Args:
|
||||||
|
key: The key of the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model's config.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.store.get_model(key=key)
|
return self._services.model_manager.store.get_model(key=key)
|
||||||
|
|
||||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||||
"""
|
"""Gets a model's metadata, if it has any.
|
||||||
Gets a model's metadata, if it has any.
|
|
||||||
|
|
||||||
:param key: The key of the model.
|
Args:
|
||||||
|
key: The key of the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model's metadata, if it has any.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.store.get_metadata(key=key)
|
return self._services.model_manager.store.get_metadata(key=key)
|
||||||
|
|
||||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||||
"""
|
"""Searches for models by path.
|
||||||
Searches for models by path.
|
|
||||||
|
|
||||||
:param path: The path to search for.
|
Args:
|
||||||
|
path: The path to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of models that match the path.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.store.search_by_path(path)
|
return self._services.model_manager.store.search_by_path(path)
|
||||||
|
|
||||||
@ -338,13 +391,16 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
type: Optional[ModelType] = None,
|
type: Optional[ModelType] = None,
|
||||||
format: Optional[ModelFormat] = None,
|
format: Optional[ModelFormat] = None,
|
||||||
) -> list[AnyModelConfig]:
|
) -> list[AnyModelConfig]:
|
||||||
"""
|
"""Searches for models by attributes.
|
||||||
Searches for models by attributes.
|
|
||||||
|
|
||||||
:param model_name: Name of to be fetched.
|
Args:
|
||||||
:param base_model: Base model
|
name: The name to search for (exact match).
|
||||||
:param model_type: Type of the model
|
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch
|
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||||
|
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of models that match the attributes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._services.model_manager.store.search_by_attr(
|
return self._services.model_manager.store.search_by_attr(
|
||||||
@ -357,7 +413,11 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
class ConfigInterface(InvocationContextInterface):
|
class ConfigInterface(InvocationContextInterface):
|
||||||
def get(self) -> InvokeAIAppConfig:
|
def get(self) -> InvokeAIAppConfig:
|
||||||
"""Gets the app's config."""
|
"""Gets the app's config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The app's config.
|
||||||
|
"""
|
||||||
|
|
||||||
return self._services.configuration.get_config()
|
return self._services.configuration.get_config()
|
||||||
|
|
||||||
@ -370,7 +430,11 @@ class UtilInterface(InvocationContextInterface):
|
|||||||
self._cancel_event = cancel_event
|
self._cancel_event = cancel_event
|
||||||
|
|
||||||
def is_canceled(self) -> bool:
|
def is_canceled(self) -> bool:
|
||||||
"""Checks if the current invocation has been canceled."""
|
"""Checks if the current session has been canceled.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the current session has been canceled, False if not.
|
||||||
|
"""
|
||||||
return self._cancel_event.is_set()
|
return self._cancel_event.is_set()
|
||||||
|
|
||||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||||
@ -380,8 +444,9 @@ class UtilInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
This should be called after each denoising step.
|
This should be called after each denoising step.
|
||||||
|
|
||||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
Args:
|
||||||
:param base_model: The base model for the current denoising step.
|
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||||
|
base_model: The base model for the current denoising step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
@ -394,8 +459,17 @@ class UtilInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
"""
|
"""Provides access to various services and data for the current invocation.
|
||||||
The `InvocationContext` provides access to various services and data for the current invocation.
|
|
||||||
|
Attributes:
|
||||||
|
images (ImagesInterface): Methods to save, get and update images and their metadata.
|
||||||
|
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
|
||||||
|
conditioning (ConditioningInterface): Methods to save and get conditioning data.
|
||||||
|
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
|
||||||
|
logger (LoggerInterface): The app logger.
|
||||||
|
config (ConfigInterface): The app config.
|
||||||
|
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
|
||||||
|
boards (BoardsInterface): Methods to interact with boards.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -438,11 +512,14 @@ def build_invocation_context(
|
|||||||
data: InvocationContextData,
|
data: InvocationContextData,
|
||||||
cancel_event: threading.Event,
|
cancel_event: threading.Event,
|
||||||
) -> InvocationContext:
|
) -> InvocationContext:
|
||||||
"""
|
"""Builds the invocation context for a specific invocation execution.
|
||||||
Builds the invocation context for a specific invocation execution.
|
|
||||||
|
|
||||||
:param services: The invocation services to wrap.
|
Args:
|
||||||
:param data: The invocation context data.
|
services: The invocation services to wrap.
|
||||||
|
data: The invocation context data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The invocation context.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger = LoggerInterface(services=services, data=data)
|
logger = LoggerInterface(services=services, data=data)
|
||||||
|
Loading…
Reference in New Issue
Block a user