feat(nodes): do not hide services in invocation context interfaces

This commit is contained in:
psychedelicious 2024-02-07 14:24:05 +11:00
parent cc8d713c57
commit dcafbb9988

View File

@ -64,97 +64,90 @@ class InvocationContextData:
"""The workflow associated with this queue item, if any.""" """The workflow associated with this queue item, if any."""
class BoardsInterface: class InvocationContextInterface:
def __init__(self, services: InvocationServices) -> None: def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
def create(board_name: str) -> BoardDTO: self._services = services
self._context_data = context_data
class BoardsInterface(InvocationContextInterface):
def create(self, board_name: str) -> BoardDTO:
""" """
Creates a board. Creates a board.
:param board_name: The name of the board to create. :param board_name: The name of the board to create.
""" """
return services.boards.create(board_name) return self._services.boards.create(board_name)
def get_dto(board_id: str) -> BoardDTO: def get_dto(self, board_id: str) -> BoardDTO:
""" """
Gets a board DTO. Gets a board DTO.
:param board_id: The ID of the board to get. :param board_id: The ID of the board to get.
""" """
return services.boards.get_dto(board_id) return self._services.boards.get_dto(board_id)
def get_all() -> list[BoardDTO]: def get_all(self) -> list[BoardDTO]:
""" """
Gets all boards. Gets all boards.
""" """
return services.boards.get_all() return self._services.boards.get_all()
def add_image_to_board(board_id: str, image_name: str) -> None: def add_image_to_board(self, board_id: str, image_name: str) -> None:
""" """
Adds an image to a board. Adds an image to a board.
:param board_id: The ID of the board to add the image to. :param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board. :param image_name: The name of the image to add to the board.
""" """
services.board_images.add_image_to_board(board_id, image_name) return self._services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(board_id: str) -> list[str]: def get_all_image_names_for_board(self, board_id: str) -> list[str]:
""" """
Gets all image names for a board. Gets all image names for a board.
:param board_id: The ID of the board to get the image names for. :param board_id: The ID of the board to get the image names for.
""" """
return services.board_images.get_all_board_image_names_for_board(board_id) return self._services.board_images.get_all_board_image_names_for_board(board_id)
self.create = create
self.get_dto = get_dto
self.get_all = get_all
self.add_image_to_board = add_image_to_board
self.get_all_image_names_for_board = get_all_image_names_for_board
class LoggerInterface: class LoggerInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices) -> None: def debug(self, message: str) -> None:
def debug(message: str) -> None:
""" """
Logs a debug message. Logs a debug message.
:param message: The message to log. :param message: The message to log.
""" """
services.logger.debug(message) self._services.logger.debug(message)
def info(message: str) -> None: def info(self, message: str) -> None:
""" """
Logs an info message. Logs an info message.
:param message: The message to log. :param message: The message to log.
""" """
services.logger.info(message) self._services.logger.info(message)
def warning(message: str) -> None: def warning(self, message: str) -> None:
""" """
Logs a warning message. Logs a warning message.
:param message: The message to log. :param message: The message to log.
""" """
services.logger.warning(message) self._services.logger.warning(message)
def error(message: str) -> None: def error(self, message: str) -> None:
""" """
Logs an error message. Logs an error message.
:param message: The message to log. :param message: The message to log.
""" """
services.logger.error(message) self._services.logger.error(message)
self.debug = debug
self.info = info
self.warning = warning
self.error = error
class ImagesInterface: class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
def save( def save(
self,
image: Image, image: Image,
board_id: Optional[str] = None, board_id: Optional[str] = None,
image_category: ImageCategory = ImageCategory.GENERAL, image_category: ImageCategory = ImageCategory.GENERAL,
@ -176,46 +169,49 @@ class ImagesInterface:
# If the invocation inherits metadata, use that. Else, use the metadata passed in. # If the invocation inherits metadata, use that. Else, use the metadata passed in.
metadata_ = ( metadata_ = (
context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata self._context_data.invocation.metadata
if isinstance(self._context_data.invocation, WithMetadata)
else metadata
) )
return services.images.create( return self._services.images.create(
image=image, image=image,
is_intermediate=context_data.invocation.is_intermediate, is_intermediate=self._context_data.invocation.is_intermediate,
image_category=image_category, image_category=image_category,
board_id=board_id, board_id=board_id,
metadata=metadata_, metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
workflow=context_data.workflow, workflow=self._context_data.workflow,
session_id=context_data.session_id, session_id=self._context_data.session_id,
node_id=context_data.invocation.id, node_id=self._context_data.invocation.id,
) )
def get_pil(image_name: str) -> Image: def get_pil(self, image_name: str) -> Image:
""" """
Gets an image as a PIL Image object. Gets an image as a PIL Image object.
:param image_name: The name of the image to get. :param image_name: The name of the image to get.
""" """
return services.images.get_pil_image(image_name) return self._services.images.get_pil_image(image_name)
def get_metadata(image_name: str) -> Optional[MetadataField]: def get_metadata(self, image_name: str) -> Optional[MetadataField]:
""" """
Gets an image's metadata, if it has any. Gets an image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for. :param image_name: The name of the image to get the metadata for.
""" """
return services.images.get_metadata(image_name) return self._services.images.get_metadata(image_name)
def get_dto(image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
""" """
Gets an image as an ImageDTO object. Gets an image as an ImageDTO object.
:param image_name: The name of the image to get. :param image_name: The name of the image to get.
""" """
return services.images.get_dto(image_name) return self._services.images.get_dto(image_name)
def update( def update(
self,
image_name: str, image_name: str,
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
@ -233,27 +229,16 @@ class ImagesInterface:
:param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
""" """
if is_intermediate is not None: if is_intermediate is not None:
services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
if board_id is None: if board_id is None:
services.board_images.remove_image_from_board(image_name) self._services.board_images.remove_image_from_board(image_name)
else: else:
services.board_images.add_image_to_board(image_name, board_id) self._services.board_images.add_image_to_board(image_name, board_id)
return services.images.get_dto(image_name) return self._services.images.get_dto(image_name)
self.save = save
self.get_pil = get_pil
self.get_metadata = get_metadata
self.get_dto = get_dto
self.update = update
class LatentsInterface: class LatentsInterface(InvocationContextInterface):
def __init__( def save(self, tensor: Tensor) -> str:
self,
services: InvocationServices,
context_data: InvocationContextData,
) -> None:
def save(tensor: Tensor) -> str:
""" """
Saves a latents tensor, returning its name. Saves a latents tensor, returning its name.
@ -275,39 +260,30 @@ class LatentsInterface:
# Because the name of the latents file will includes the session and invocation IDs, # Because the name of the latents file will includes the session and invocation IDs,
# we don't need to worry about collisions. A truncated UUIDv4 is fine. # we don't need to worry about collisions. A truncated UUIDv4 is fine.
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
services.latents.save( self._services.latents.save(
name=name, name=name,
data=tensor, data=tensor,
) )
return name return name
def get(latents_name: str) -> Tensor: def get(self, latents_name: str) -> Tensor:
""" """
Gets a latents tensor by name. Gets a latents tensor by name.
:param latents_name: The name of the latents tensor to get. :param latents_name: The name of the latents tensor to get.
""" """
return services.latents.get(latents_name) return self._services.latents.get(latents_name)
self.save = save
self.get = get
class ConditioningInterface: class ConditioningInterface(InvocationContextInterface):
def __init__(
self,
services: InvocationServices,
context_data: InvocationContextData,
) -> None:
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
# service, but it is typed to work with Tensors only. We have to fudge the types here. # service, but it is typed to work with Tensors only. We have to fudge the types here.
def save(self, conditioning_data: ConditioningFieldData) -> str:
def save(conditioning_data: ConditioningFieldData) -> str:
""" """
Saves a conditioning data object, returning its name. Saves a conditioning data object, returning its name.
:param conditioning_data: The conditioning data to save. :param conditioning_context_data: The conditioning data to save.
""" """
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this. # Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
@ -315,29 +291,25 @@ class ConditioningInterface:
# See comment for `LatentsInterface.save` for more info about this method (it's very # See comment for `LatentsInterface.save` for more info about this method (it's very
# similar). # similar).
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
services.latents.save( self._services.latents.save(
name=name, name=name,
data=conditioning_data, # type: ignore [arg-type] data=conditioning_data, # type: ignore [arg-type]
) )
return name return name
def get(conditioning_name: str) -> ConditioningFieldData: def get(self, conditioning_name: str) -> ConditioningFieldData:
""" """
Gets conditioning data by name. Gets conditioning data by name.
:param conditioning_name: The name of the conditioning data to get. :param conditioning_name: The name of the conditioning data to get.
""" """
return services.latents.get(conditioning_name) # type: ignore [return-value] return self._services.latents.get(conditioning_name) # type: ignore [return-value]
self.save = save
self.get = get
class ModelsInterface: class ModelsInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
""" """
Checks if a model exists. Checks if a model exists.
@ -345,10 +317,10 @@ class ModelsInterface:
:param base_model: The base model of the model to check. :param base_model: The base model of the model to check.
:param model_type: The type of the model to check. :param model_type: The type of the model to check.
""" """
return services.model_manager.model_exists(model_name, base_model, model_type) return self._services.model_manager.model_exists(model_name, base_model, model_type)
def load( def load(
model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> ModelInfo: ) -> ModelInfo:
""" """
Loads a model, returning its `ModelInfo` object. Loads a model, returning its `ModelInfo` object.
@ -369,11 +341,11 @@ class ModelsInterface:
# While this is available within the node, it's tedious to need to pass it in on every # While this is available within the node, it's tedious to need to pass it in on every
# call. We can avoid that by wrapping the method here. # call. We can avoid that by wrapping the method here.
return services.model_manager.get_model( return self._services.model_manager.get_model(
model_name, base_model, model_type, submodel, context_data=context_data model_name, base_model, model_type, submodel, context_data=self._context_data
) )
def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Gets a model's info, an dict-like object. Gets a model's info, an dict-like object.
@ -381,16 +353,11 @@ class ModelsInterface:
:param base_model: The base model of the model to get. :param base_model: The base model of the model to get.
:param model_type: The type of the model to get. :param model_type: The type of the model to get.
""" """
return services.model_manager.model_info(model_name, base_model, model_type) return self._services.model_manager.model_info(model_name, base_model, model_type)
self.exists = exists
self.load = load
self.get_info = get_info
class ConfigInterface: class ConfigInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices) -> None: def get(self) -> InvokeAIAppConfig:
def get() -> InvokeAIAppConfig:
""" """
Gets the app's config. The config is read-only; attempts to mutate it will raise an error. Gets the app's config. The config is read-only; attempts to mutate it will raise an error.
""" """
@ -399,20 +366,14 @@ class ConfigInterface:
# #
# We don't want nodes doing this, so we make a frozen copy. # We don't want nodes doing this, so we make a frozen copy.
config = services.configuration.get_config() config = self._services.configuration.get_config()
# TODO(psyche): If config cannot be changed at runtime, should we cache this? # TODO(psyche): If config cannot be changed at runtime, should we cache this?
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
return frozen_config return frozen_config
self.get = get
class UtilInterface(InvocationContextInterface):
class UtilInterface: def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
def sd_step_callback(
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
) -> None:
""" """
The step callback emits a progress event with the current step, the total number of The step callback emits a progress event with the current step, the total number of
steps, a preview image, and some other internal metadata. steps, a preview image, and some other internal metadata.
@ -429,15 +390,13 @@ class UtilInterface:
# We wrap the step callback so that nodes do not have direct access to these services. # We wrap the step callback so that nodes do not have direct access to these services.
stable_diffusion_step_callback( stable_diffusion_step_callback(
context_data=context_data, context_data=self._context_data,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
base_model=base_model, base_model=base_model,
invocation_queue=services.queue, invocation_queue=self._services.queue,
events=services.events, events=self._services.events,
) )
self.sd_step_callback = sd_step_callback
deprecation_version = "3.7.0" deprecation_version = "3.7.0"
removed_version = "3.8.0" removed_version = "3.8.0"
@ -600,14 +559,14 @@ def build_invocation_context(
:param invocation_context_data: The invocation context data. :param invocation_context_data: The invocation context data.
""" """
logger = LoggerInterface(services=services) logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data)
latents = LatentsInterface(services=services, context_data=context_data) latents = LatentsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services) config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data)
conditioning = ConditioningInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services) boards = BoardsInterface(services=services, context_data=context_data)
ctx = InvocationContext( ctx = InvocationContext(
images=images, images=images,