From 95dd5aad16287cc92b8503e6c321fc178361833e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:40:49 +1100 Subject: [PATCH] feat(nodes): add boards interface to invocation context --- .../app/services/shared/invocation_context.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index a849d6b17a..cbcaa6a548 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -7,6 +7,7 @@ from pydantic import ConfigDict from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -63,6 +64,54 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" +class BoardsInterface: + def __init__(self, services: InvocationServices) -> None: + def create(board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return services.boards.create(board_name) + + def get_dto(board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return services.boards.get_dto(board_id) + + def get_all() -> list[BoardDTO]: + """ + Gets all boards. + """ + return services.boards.get_all() + + def add_image_to_board(board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return services.board_images.get_all_board_image_names_for_board(board_id) + + self.create = create + self.get_dto = get_dto + self.get_all = get_all + self.add_image_to_board = add_image_to_board + self.get_all_image_names_for_board = get_all_image_names_for_board + + class LoggerInterface: def __init__(self, services: InvocationServices) -> None: def debug(message: str) -> None: @@ -427,6 +476,7 @@ class InvocationContext: logger: LoggerInterface, config: ConfigInterface, util: UtilInterface, + boards: BoardsInterface, data: InvocationContextData, services: InvocationServices, ) -> None: @@ -444,6 +494,8 @@ class InvocationContext: """Provides access to the app's config.""" self.util = util """Provides utility methods.""" + self.boards = boards + """Provides methods to interact with boards.""" self.data = data """Provides data about the current queue item and invocation.""" self.__services = services @@ -554,6 +606,7 @@ def build_invocation_context( config = ConfigInterface(services=services) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) + boards = BoardsInterface(services=services) ctx = InvocationContext( images=images, @@ -565,6 +618,7 @@ def build_invocation_context( util=util, conditioning=conditioning, services=services, + boards=boards, ) return ctx