feat(nodes): do not hide services in invocation context interfaces

This commit is contained in:
psychedelicious 2024-02-07 14:24:05 +11:00 committed by Brandon Rising
parent 0ff466ebc4
commit 0fde0d1ff7

View File

@ -64,379 +64,338 @@ class InvocationContextData:
"""The workflow associated with this queue item, if any.""" """The workflow associated with this queue item, if any."""
class BoardsInterface: class InvocationContextInterface:
def __init__(self, services: InvocationServices) -> None:
def create(board_name: str) -> BoardDTO:
"""
Creates a board.
:param board_name: The name of the board to create.
"""
return services.boards.create(board_name)
def get_dto(board_id: str) -> BoardDTO:
"""
Gets a board DTO.
:param board_id: The ID of the board to get.
"""
return services.boards.get_dto(board_id)
def get_all() -> list[BoardDTO]:
"""
Gets all boards.
"""
return services.boards.get_all()
def add_image_to_board(board_id: str, image_name: str) -> None:
"""
Adds an image to a board.
:param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board.
"""
services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(board_id: str) -> list[str]:
"""
Gets all image names for a board.
:param board_id: The ID of the board to get the image names for.
"""
return services.board_images.get_all_board_image_names_for_board(board_id)
self.create = create
self.get_dto = get_dto
self.get_all = get_all
self.add_image_to_board = add_image_to_board
self.get_all_image_names_for_board = get_all_image_names_for_board
class LoggerInterface:
def __init__(self, services: InvocationServices) -> None:
def debug(message: str) -> None:
"""
Logs a debug message.
:param message: The message to log.
"""
services.logger.debug(message)
def info(message: str) -> None:
"""
Logs an info message.
:param message: The message to log.
"""
services.logger.info(message)
def warning(message: str) -> None:
"""
Logs a warning message.
:param message: The message to log.
"""
services.logger.warning(message)
def error(message: str) -> None:
"""
Logs an error message.
:param message: The message to log.
"""
services.logger.error(message)
self.debug = debug
self.info = info
self.warning = warning
self.error = error
class ImagesInterface:
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
def save( self._services = services
image: Image, self._context_data = context_data
board_id: Optional[str] = None,
image_category: ImageCategory = ImageCategory.GENERAL,
metadata: Optional[MetadataField] = None,
) -> ImageDTO:
"""
Saves an image, returning its DTO.
If the current queue item has a workflow or metadata, it is automatically saved with the image.
:param image: The image to save, as a PIL image.
:param board_id: The board ID to add the image to, if it should be added.
:param image_category: The category of the image. Only the GENERAL category is added \
to the gallery.
:param metadata: The metadata to save with the image, if it should have any. If the \
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
**Use this only if you want to override or provide metadata manually!**
"""
# If the invocation inherits metadata, use that. Else, use the metadata passed in.
metadata_ = (
context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata
)
return services.images.create(
image=image,
is_intermediate=context_data.invocation.is_intermediate,
image_category=image_category,
board_id=board_id,
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=context_data.workflow,
session_id=context_data.session_id,
node_id=context_data.invocation.id,
)
def get_pil(image_name: str) -> Image:
"""
Gets an image as a PIL Image object.
:param image_name: The name of the image to get.
"""
return services.images.get_pil_image(image_name)
def get_metadata(image_name: str) -> Optional[MetadataField]:
"""
Gets an image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for.
"""
return services.images.get_metadata(image_name)
def get_dto(image_name: str) -> ImageDTO:
"""
Gets an image as an ImageDTO object.
:param image_name: The name of the image to get.
"""
return services.images.get_dto(image_name)
def update(
image_name: str,
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
) -> ImageDTO:
"""
Updates an image, returning its updated DTO.
It is not suggested to update images saved by earlier nodes, as this can cause confusion for users.
If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to
get the updated image.
:param image_name: The name of the image to update.
:param board_id: The board ID to add the image to, if it should be added.
:param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
"""
if is_intermediate is not None:
services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
if board_id is None:
services.board_images.remove_image_from_board(image_name)
else:
services.board_images.add_image_to_board(image_name, board_id)
return services.images.get_dto(image_name)
self.save = save
self.get_pil = get_pil
self.get_metadata = get_metadata
self.get_dto = get_dto
self.update = update
class LatentsInterface: class BoardsInterface(InvocationContextInterface):
def __init__( def create(self, board_name: str) -> BoardDTO:
"""
Creates a board.
:param board_name: The name of the board to create.
"""
return self._services.boards.create(board_name)
def get_dto(self, board_id: str) -> BoardDTO:
"""
Gets a board DTO.
:param board_id: The ID of the board to get.
"""
return self._services.boards.get_dto(board_id)
def get_all(self) -> list[BoardDTO]:
"""
Gets all boards.
"""
return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None:
"""
Adds an image to a board.
:param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board.
"""
return self._services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
"""
Gets all image names for a board.
:param board_id: The ID of the board to get the image names for.
"""
return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
def debug(self, message: str) -> None:
"""
Logs a debug message.
:param message: The message to log.
"""
self._services.logger.debug(message)
def info(self, message: str) -> None:
"""
Logs an info message.
:param message: The message to log.
"""
self._services.logger.info(message)
def warning(self, message: str) -> None:
"""
Logs a warning message.
:param message: The message to log.
"""
self._services.logger.warning(message)
def error(self, message: str) -> None:
"""
Logs an error message.
:param message: The message to log.
"""
self._services.logger.error(message)
class ImagesInterface(InvocationContextInterface):
def save(
self, self,
services: InvocationServices, image: Image,
context_data: InvocationContextData, board_id: Optional[str] = None,
) -> None: image_category: ImageCategory = ImageCategory.GENERAL,
def save(tensor: Tensor) -> str: metadata: Optional[MetadataField] = None,
""" ) -> ImageDTO:
Saves a latents tensor, returning its name. """
Saves an image, returning its DTO.
:param tensor: The latents tensor to save. If the current queue item has a workflow or metadata, it is automatically saved with the image.
"""
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g. :param image: The image to save, as a PIL image.
# "mask", "noise", "masked_latents", etc. :param board_id: The board ID to add the image to, if it should be added.
# :param image_category: The category of the image. Only the GENERAL category is added \
# Retaining that capability in this wrapper would require either many different methods to the gallery.
# to save latents, or extra args for this method. Instead of complicating the API, we :param metadata: The metadata to save with the image, if it should have any. If the \
# will use the same naming scheme for all latents. invocation inherits from `WithMetadata`, that metadata will be used automatically. \
# **Use this only if you want to override or provide metadata manually!**
# This has a very minor impact as we don't use them after a session completes. """
# Previously, invocations chose the name for their latents. This is a bit risky, so we # If the invocation inherits metadata, use that. Else, use the metadata passed in.
# will generate a name for them instead. We use a uuid to ensure the name is unique. metadata_ = (
# self._context_data.invocation.metadata
# Because the name of the latents file will includes the session and invocation IDs, if isinstance(self._context_data.invocation, WithMetadata)
# we don't need to worry about collisions. A truncated UUIDv4 is fine. else metadata
)
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" return self._services.images.create(
services.latents.save( image=image,
name=name, is_intermediate=self._context_data.invocation.is_intermediate,
data=tensor, image_category=image_category,
) board_id=board_id,
return name metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=self._context_data.workflow,
session_id=self._context_data.session_id,
node_id=self._context_data.invocation.id,
)
def get(latents_name: str) -> Tensor: def get_pil(self, image_name: str) -> Image:
""" """
Gets a latents tensor by name. Gets an image as a PIL Image object.
:param latents_name: The name of the latents tensor to get. :param image_name: The name of the image to get.
""" """
return services.latents.get(latents_name) return self._services.images.get_pil_image(image_name)
self.save = save def get_metadata(self, image_name: str) -> Optional[MetadataField]:
self.get = get """
Gets an image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for.
"""
return self._services.images.get_metadata(image_name)
class ConditioningInterface: def get_dto(self, image_name: str) -> ImageDTO:
def __init__( """
Gets an image as an ImageDTO object.
:param image_name: The name of the image to get.
"""
return self._services.images.get_dto(image_name)
def update(
self, self,
services: InvocationServices, image_name: str,
context_data: InvocationContextData, board_id: Optional[str] = None,
) -> None: is_intermediate: Optional[bool] = False,
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage ) -> ImageDTO:
# service, but it is typed to work with Tensors only. We have to fudge the types here. """
Updates an image, returning its updated DTO.
def save(conditioning_data: ConditioningFieldData) -> str: It is not suggested to update images saved by earlier nodes, as this can cause confusion for users.
"""
Saves a conditioning data object, returning its name.
:param conditioning_data: The conditioning data to save. If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to
""" get the updated image.
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this. :param image_name: The name of the image to update.
# :param board_id: The board ID to add the image to, if it should be added.
# See comment for `LatentsInterface.save` for more info about this method (it's very :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
# similar). """
if is_intermediate is not None:
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
services.latents.save( if board_id is None:
name=name, self._services.board_images.remove_image_from_board(image_name)
data=conditioning_data, # type: ignore [arg-type] else:
) self._services.board_images.add_image_to_board(image_name, board_id)
return name return self._services.images.get_dto(image_name)
def get(conditioning_name: str) -> ConditioningFieldData:
"""
Gets conditioning data by name.
:param conditioning_name: The name of the conditioning data to get.
"""
return services.latents.get(conditioning_name) # type: ignore [return-value]
self.save = save
self.get = get
class ModelsInterface: class LatentsInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: def save(self, tensor: Tensor) -> str:
def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: """
""" Saves a latents tensor, returning its name.
Checks if a model exists.
:param model_name: The name of the model to check. :param tensor: The latents tensor to save.
:param base_model: The base model of the model to check. """
:param model_type: The type of the model to check.
"""
return services.model_manager.model_exists(model_name, base_model, model_type)
def load( # Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None # "mask", "noise", "masked_latents", etc.
) -> ModelInfo: #
""" # Retaining that capability in this wrapper would require either many different methods
Loads a model, returning its `ModelInfo` object. # to save latents, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all latents.
#
# This has a very minor impact as we don't use them after a session completes.
:param model_name: The name of the model to get. # Previously, invocations chose the name for their latents. This is a bit risky, so we
:param base_model: The base model of the model to get. # will generate a name for them instead. We use a uuid to ensure the name is unique.
:param model_type: The type of the model to get. #
:param submodel: The submodel of the model to get. # Because the name of the latents file will includes the session and invocation IDs,
""" # we don't need to worry about collisions. A truncated UUIDv4 is fine.
# During this call, the model manager emits events with model loading status. The model name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
# manager itself has access to the events services, but does not have access to the self._services.latents.save(
# required metadata for the events. name=name,
# data=tensor,
# For example, it needs access to the node's ID so that the events can be associated )
# with the execution of a specific node. return name
#
# While this is available within the node, it's tedious to need to pass it in on every
# call. We can avoid that by wrapping the method here.
return services.model_manager.get_model( def get(self, latents_name: str) -> Tensor:
model_name, base_model, model_type, submodel, context_data=context_data """
) Gets a latents tensor by name.
def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: :param latents_name: The name of the latents tensor to get.
""" """
Gets a model's info, an dict-like object. return self._services.latents.get(latents_name)
:param model_name: The name of the model to get.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
"""
return services.model_manager.model_info(model_name, base_model, model_type)
self.exists = exists
self.load = load
self.get_info = get_info
class ConfigInterface: class ConditioningInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices) -> None: # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
def get() -> InvokeAIAppConfig: # service, but it is typed to work with Tensors only. We have to fudge the types here.
""" def save(self, conditioning_data: ConditioningFieldData) -> str:
Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """
""" Saves a conditioning data object, returning its name.
# The config can be changed at runtime. :param conditioning_context_data: The conditioning data to save.
# """
# We don't want nodes doing this, so we make a frozen copy.
config = services.configuration.get_config() # Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
# TODO(psyche): If config cannot be changed at runtime, should we cache this? #
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) # See comment for `LatentsInterface.save` for more info about this method (it's very
return frozen_config # similar).
self.get = get name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
self._services.latents.save(
name=name,
data=conditioning_data, # type: ignore [arg-type]
)
return name
def get(self, conditioning_name: str) -> ConditioningFieldData:
"""
Gets conditioning data by name.
:param conditioning_name: The name of the conditioning data to get.
"""
return self._services.latents.get(conditioning_name) # type: ignore [return-value]
class UtilInterface: class ModelsInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
def sd_step_callback( """
intermediate_state: PipelineIntermediateState, Checks if a model exists.
base_model: BaseModelType,
) -> None:
"""
The step callback emits a progress event with the current step, the total number of
steps, a preview image, and some other internal metadata.
This should be called after each denoising step. :param model_name: The name of the model to check.
:param base_model: The base model of the model to check.
:param model_type: The type of the model to check.
"""
return self._services.model_manager.model_exists(model_name, base_model, model_type)
:param intermediate_state: The intermediate state of the diffusion pipeline. def load(
:param base_model: The base model for the current denoising step. self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
""" ) -> ModelInfo:
"""
Loads a model, returning its `ModelInfo` object.
# The step callback needs access to the events and the invocation queue services, but this :param model_name: The name of the model to get.
# represents a dangerous level of access. :param base_model: The base model of the model to get.
# :param model_type: The type of the model to get.
# We wrap the step callback so that nodes do not have direct access to these services. :param submodel: The submodel of the model to get.
"""
stable_diffusion_step_callback( # During this call, the model manager emits events with model loading status. The model
context_data=context_data, # manager itself has access to the events services, but does not have access to the
intermediate_state=intermediate_state, # required metadata for the events.
base_model=base_model, #
invocation_queue=services.queue, # For example, it needs access to the node's ID so that the events can be associated
events=services.events, # with the execution of a specific node.
) #
# While this is available within the node, it's tedious to need to pass it in on every
# call. We can avoid that by wrapping the method here.
self.sd_step_callback = sd_step_callback return self._services.model_manager.get_model(
model_name, base_model, model_type, submodel, context_data=self._context_data
)
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Gets a model's info, an dict-like object.
:param model_name: The name of the model to get.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
"""
return self._services.model_manager.model_info(model_name, base_model, model_type)
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
"""
Gets the app's config. The config is read-only; attempts to mutate it will raise an error.
"""
# The config can be changed at runtime.
#
# We don't want nodes doing this, so we make a frozen copy.
config = self._services.configuration.get_config()
# TODO(psyche): If config cannot be changed at runtime, should we cache this?
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
return frozen_config
class UtilInterface(InvocationContextInterface):
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
The step callback emits a progress event with the current step, the total number of
steps, a preview image, and some other internal metadata.
This should be called after each denoising step.
:param intermediate_state: The intermediate state of the diffusion pipeline.
:param base_model: The base model for the current denoising step.
"""
# The step callback needs access to the events and the invocation queue services, but this
# represents a dangerous level of access.
#
# We wrap the step callback so that nodes do not have direct access to these services.
stable_diffusion_step_callback(
context_data=self._context_data,
intermediate_state=intermediate_state,
base_model=base_model,
invocation_queue=self._services.queue,
events=self._services.events,
)
deprecation_version = "3.7.0" deprecation_version = "3.7.0"
@ -600,14 +559,14 @@ def build_invocation_context(
:param invocation_context_data: The invocation context data. :param invocation_context_data: The invocation context data.
""" """
logger = LoggerInterface(services=services) logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data)
latents = LatentsInterface(services=services, context_data=context_data) latents = LatentsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services) config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data)
conditioning = ConditioningInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services) boards = BoardsInterface(services=services, context_data=context_data)
ctx = InvocationContext( ctx = InvocationContext(
images=images, images=images,