diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cbbe3216cc..7d378e22e3 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -65,75 +65,86 @@ class InvocationContextInterface: class BoardsInterface(InvocationContextInterface): def create(self, board_name: str) -> BoardDTO: - """ - Creates a board. + """Creates a board. - :param board_name: The name of the board to create. + Args: + board_name: The name of the board to create. + + Returns: + The created board DTO. """ return self._services.boards.create(board_name) def get_dto(self, board_id: str) -> BoardDTO: - """ - Gets a board DTO. + """Gets a board DTO. - :param board_id: The ID of the board to get. + Args: + board_id: The ID of the board to get. + + Returns: + The board DTO. """ return self._services.boards.get_dto(board_id) def get_all(self) -> list[BoardDTO]: - """ - Gets all boards. + """Gets all boards. + + Returns: + A list of all boards. """ return self._services.boards.get_all() def add_image_to_board(self, board_id: str, image_name: str) -> None: - """ - Adds an image to a board. + """Adds an image to a board. - :param board_id: The ID of the board to add the image to. - :param image_name: The name of the image to add to the board. + Args: + board_id: The ID of the board to add the image to. + image_name: The name of the image to add to the board. """ return self._services.board_images.add_image_to_board(board_id, image_name) def get_all_image_names_for_board(self, board_id: str) -> list[str]: - """ - Gets all image names for a board. + """Gets all image names for a board. - :param board_id: The ID of the board to get the image names for. + Args: + board_id: The ID of the board to get the image names for. + + Returns: + A list of all image names for the board. """ return self._services.board_images.get_all_board_image_names_for_board(board_id) class LoggerInterface(InvocationContextInterface): def debug(self, message: str) -> None: - """ - Logs a debug message. + """Logs a debug message. - :param message: The message to log. + Args: + message: The message to log. """ self._services.logger.debug(message) def info(self, message: str) -> None: - """ - Logs an info message. + """Logs an info message. - :param message: The message to log. + Args: + message: The message to log. """ self._services.logger.info(message) def warning(self, message: str) -> None: - """ - Logs a warning message. + """Logs a warning message. - :param message: The message to log. + Args: + message: The message to log. """ self._services.logger.warning(message) def error(self, message: str) -> None: - """ - Logs an error message. + """Logs an error message. - :param message: The message to log. + Args: + message: The message to log. """ self._services.logger.error(message) @@ -146,20 +157,23 @@ class ImagesInterface(InvocationContextInterface): image_category: ImageCategory = ImageCategory.GENERAL, metadata: Optional[MetadataField] = None, ) -> ImageDTO: - """ - Saves an image, returning its DTO. + """Saves an image, returning its DTO. If the current queue item has a workflow or metadata, it is automatically saved with the image. - :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. It the invocation \ + Args: + image: The image to save, as a PIL image. + board_id: The board ID to add the image to, if it should be added. It the invocation \ inherits from `WithBoard`, that board will be used automatically. **Use this only if \ you want to override or provide a board manually!** - :param image_category: The category of the image. Only the GENERAL category is added \ + image_category: The category of the image. Only the GENERAL category is added \ to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the \ + metadata: The metadata to save with the image, if it should have any. If the \ invocation inherits from `WithMetadata`, that metadata will be used automatically. \ **Use this only if you want to override or provide metadata manually!** + + Returns: + The saved image DTO. """ # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. @@ -189,11 +203,14 @@ class ImagesInterface(InvocationContextInterface): ) def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image: - """ - Gets an image as a PIL Image object. + """Gets an image as a PIL Image object. - :param image_name: The name of the image to get. - :param mode: The color mode to convert the image to. If None, the original mode is used. + Args: + image_name: The name of the image to get. + mode: The color mode to convert the image to. If None, the original mode is used. + + Returns: + The image as a PIL Image object. """ image = self._services.images.get_pil_image(image_name) if mode and mode != image.mode: @@ -206,58 +223,76 @@ class ImagesInterface(InvocationContextInterface): return image def get_metadata(self, image_name: str) -> Optional[MetadataField]: - """ - Gets an image's metadata, if it has any. + """Gets an image's metadata, if it has any. - :param image_name: The name of the image to get the metadata for. + Args: + image_name: The name of the image to get the metadata for. + + Returns: + The image's metadata, if it has any. """ return self._services.images.get_metadata(image_name) def get_dto(self, image_name: str) -> ImageDTO: - """ - Gets an image as an ImageDTO object. + """Gets an image as an ImageDTO object. - :param image_name: The name of the image to get. + Args: + image_name: The name of the image to get. + + Returns: + The image as an ImageDTO object. """ return self._services.images.get_dto(image_name) class TensorsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: - """ - Saves a tensor, returning its name. + """Saves a tensor, returning its name. - :param tensor: The tensor to save. + Args: + tensor: The tensor to save. + + Returns: + The name of the saved tensor. """ name = self._services.tensors.save(obj=tensor) return name def load(self, name: str) -> Tensor: - """ - Loads a tensor by name. + """Loads a tensor by name. - :param name: The name of the tensor to load. + Args: + name: The name of the tensor to load. + + Returns: + The loaded tensor. """ return self._services.tensors.load(name) class ConditioningInterface(InvocationContextInterface): def save(self, conditioning_data: ConditioningFieldData) -> str: - """ - Saves a conditioning data object, returning its name. + """Saves a conditioning data object, returning its name. - :param conditioning_data: The conditioning data to save. + Args: + conditioning_data: The conditioning data to save. + + Returns: + The name of the saved conditioning data. """ name = self._services.conditioning.save(obj=conditioning_data) return name def load(self, name: str) -> ConditioningFieldData: - """ - Loads conditioning data by name. + """Loads conditioning data by name. - :param name: The name of the conditioning data to load. + Args: + name: The name of the conditioning data to load. + + Returns: + The loaded conditioning data. """ return self._services.conditioning.load(name) @@ -265,20 +300,25 @@ class ConditioningInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface): def exists(self, key: str) -> bool: - """ - Checks if a model exists. + """Checks if a model exists. - :param key: The key of the model. + Args: + key: The key of the model. + + Returns: + True if the model exists, False if not. """ return self._services.model_manager.store.exists(key) def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """ - Loads a model. + """Loads a model. - :param key: The key of the model. - :param submodel_type: The submodel of the model to get. - :returns: An object representing the loaded model. + Args: + key: The key of the model. + submodel_type: The submodel of the model to get. + + Returns: + An object representing the loaded model. """ # The model manager emits events as it loads the model. It needs the context data to build @@ -291,13 +331,17 @@ class ModelsInterface(InvocationContextInterface): def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None ) -> LoadedModel: - """ - Loads a model by its attributes. + """Loads a model by its attributes. - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch + Args: + name: Name of the model. + base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc. + type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc. + submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main + models have submodels. + + Returns: + An object representing the loaded model. """ return self._services.model_manager.load_model_by_attr( model_name=name, @@ -308,26 +352,35 @@ class ModelsInterface(InvocationContextInterface): ) def get_config(self, key: str) -> AnyModelConfig: - """ - Gets a model's info, an dict-like object. + """Gets a model's config. - :param key: The key of the model. + Args: + key: The key of the model. + + Returns: + The model's config. """ return self._services.model_manager.store.get_model(key=key) def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: - """ - Gets a model's metadata, if it has any. + """Gets a model's metadata, if it has any. - :param key: The key of the model. + Args: + key: The key of the model. + + Returns: + The model's metadata, if it has any. """ return self._services.model_manager.store.get_metadata(key=key) def search_by_path(self, path: Path) -> list[AnyModelConfig]: - """ - Searches for models by path. + """Searches for models by path. - :param path: The path to search for. + Args: + path: The path to search for. + + Returns: + A list of models that match the path. """ return self._services.model_manager.store.search_by_path(path) @@ -338,13 +391,16 @@ class ModelsInterface(InvocationContextInterface): type: Optional[ModelType] = None, format: Optional[ModelFormat] = None, ) -> list[AnyModelConfig]: - """ - Searches for models by attributes. + """Searches for models by attributes. - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch + Args: + name: The name to search for (exact match). + base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc. + type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc. + format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc. + + Returns: + A list of models that match the attributes. """ return self._services.model_manager.store.search_by_attr( @@ -357,7 +413,11 @@ class ModelsInterface(InvocationContextInterface): class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: - """Gets the app's config.""" + """Gets the app's config. + + Returns: + The app's config. + """ return self._services.configuration.get_config() @@ -370,7 +430,11 @@ class UtilInterface(InvocationContextInterface): self._cancel_event = cancel_event def is_canceled(self) -> bool: - """Checks if the current invocation has been canceled.""" + """Checks if the current session has been canceled. + + Returns: + True if the current session has been canceled, False if not. + """ return self._cancel_event.is_set() def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: @@ -380,8 +444,9 @@ class UtilInterface(InvocationContextInterface): This should be called after each denoising step. - :param intermediate_state: The intermediate state of the diffusion pipeline. - :param base_model: The base model for the current denoising step. + Args: + intermediate_state: The intermediate state of the diffusion pipeline. + base_model: The base model for the current denoising step. """ stable_diffusion_step_callback( @@ -394,8 +459,17 @@ class UtilInterface(InvocationContextInterface): class InvocationContext: - """ - The `InvocationContext` provides access to various services and data for the current invocation. + """Provides access to various services and data for the current invocation. + + Attributes: + images (ImagesInterface): Methods to save, get and update images and their metadata. + tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images. + conditioning (ConditioningInterface): Methods to save and get conditioning data. + models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info. + logger (LoggerInterface): The app logger. + config (ConfigInterface): The app config. + util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks. + boards (BoardsInterface): Methods to interact with boards. """ def __init__( @@ -438,11 +512,14 @@ def build_invocation_context( data: InvocationContextData, cancel_event: threading.Event, ) -> InvocationContext: - """ - Builds the invocation context for a specific invocation execution. + """Builds the invocation context for a specific invocation execution. - :param services: The invocation services to wrap. - :param data: The invocation context data. + Args: + services: The invocation services to wrap. + data: The invocation context data. + + Returns: + The invocation context. """ logger = LoggerInterface(services=services, data=data)