feat(nodes): add boards interface to invocation context

This commit is contained in:
psychedelicious 2024-02-05 17:40:49 +11:00
parent 4ce21087d3
commit 95dd5aad16

View File

@ -7,6 +7,7 @@ from pydantic import ConfigDict
from torch import Tensor
from invokeai.app.invocations.fields import MetadataField, WithMetadata
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
@ -63,6 +64,54 @@ class InvocationContextData:
"""The workflow associated with this queue item, if any."""
class BoardsInterface:
def __init__(self, services: InvocationServices) -> None:
def create(board_name: str) -> BoardDTO:
"""
Creates a board.
:param board_name: The name of the board to create.
"""
return services.boards.create(board_name)
def get_dto(board_id: str) -> BoardDTO:
"""
Gets a board DTO.
:param board_id: The ID of the board to get.
"""
return services.boards.get_dto(board_id)
def get_all() -> list[BoardDTO]:
"""
Gets all boards.
"""
return services.boards.get_all()
def add_image_to_board(board_id: str, image_name: str) -> None:
"""
Adds an image to a board.
:param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board.
"""
services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(board_id: str) -> list[str]:
"""
Gets all image names for a board.
:param board_id: The ID of the board to get the image names for.
"""
return services.board_images.get_all_board_image_names_for_board(board_id)
self.create = create
self.get_dto = get_dto
self.get_all = get_all
self.add_image_to_board = add_image_to_board
self.get_all_image_names_for_board = get_all_image_names_for_board
class LoggerInterface:
def __init__(self, services: InvocationServices) -> None:
def debug(message: str) -> None:
@ -427,6 +476,7 @@ class InvocationContext:
logger: LoggerInterface,
config: ConfigInterface,
util: UtilInterface,
boards: BoardsInterface,
data: InvocationContextData,
services: InvocationServices,
) -> None:
@ -444,6 +494,8 @@ class InvocationContext:
"""Provides access to the app's config."""
self.util = util
"""Provides utility methods."""
self.boards = boards
"""Provides methods to interact with boards."""
self.data = data
"""Provides data about the current queue item and invocation."""
self.__services = services
@ -554,6 +606,7 @@ def build_invocation_context(
config = ConfigInterface(services=services)
util = UtilInterface(services=services, context_data=context_data)
conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services)
ctx = InvocationContext(
images=images,
@ -565,6 +618,7 @@ def build_invocation_context(
util=util,
conditioning=conditioning,
services=services,
boards=boards,
)
return ctx