feat(nodes): tidy invocation_context.py, improve comments

This commit is contained in:
psychedelicious 2024-01-14 00:05:15 +11:00
parent ef27283569
commit 1616974b48

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from PIL.Image import Image from PIL.Image import Image
@ -37,6 +36,9 @@ Wrapping these services provides a simpler and safer interface for nodes to use.
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
with each other. with each other.
Many of the wrappers have the same signature as the methods they wrap. This allows us to write
user-facing docstrings and not need to go and update the internal services to match.
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
""" """
@ -44,12 +46,19 @@ Note: The docstrings are in weird places, but that's where they must be to get I
@dataclass(frozen=True) @dataclass(frozen=True)
class InvocationContextData: class InvocationContextData:
invocation: "BaseInvocation" invocation: "BaseInvocation"
"""The invocation that is being executed."""
session_id: str session_id: str
"""The session that is being executed."""
queue_id: str queue_id: str
"""The queue in which the session is being executed."""
source_node_id: str source_node_id: str
"""The ID of the node from which the currently executing invocation was prepared."""
queue_item_id: int queue_item_id: int
"""The ID of the queue item that is being executed."""
batch_id: str batch_id: str
"""The ID of the batch that is being executed."""
workflow: Optional[WorkflowWithoutID] = None workflow: Optional[WorkflowWithoutID] = None
"""The workflow associated with this queue item, if any."""
class LoggerInterface: class LoggerInterface:
@ -103,14 +112,15 @@ class ImagesInterface:
""" """
Saves an image, returning its DTO. Saves an image, returning its DTO.
If the current queue item has a workflow, it is automatically saved with the image. If the current queue item has a workflow or metadata, it is automatically saved with the image.
:param image: The image to save, as a PIL image. :param image: The image to save, as a PIL image.
:param board_id: The board ID to add the image to, if it should be added. :param board_id: The board ID to add the image to, if it should be added.
:param image_category: The category of the image. Only the GENERAL category is added to the gallery. :param image_category: The category of the image. Only the GENERAL category is added \
:param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ to the gallery.
from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ :param metadata: The metadata to save with the image, if it should have any. If the \
override or provide metadata manually. invocation inherits from `WithMetadata`, that metadata will be used automatically. \
**Use this only if you want to override or provide metadata manually!**
""" """
# If the invocation inherits metadata, use that. Else, use the metadata passed in. # If the invocation inherits metadata, use that. Else, use the metadata passed in.
@ -186,14 +196,6 @@ class ImagesInterface:
self.update = update self.update = update
class LatentsKind(str, Enum):
IMAGE = "image"
NOISE = "noise"
MASK = "mask"
MASKED_IMAGE = "masked_image"
OTHER = "other"
class LatentsInterface: class LatentsInterface:
def __init__( def __init__(
self, self,
@ -206,6 +208,22 @@ class LatentsInterface:
:param tensor: The latents tensor to save. :param tensor: The latents tensor to save.
""" """
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
# "mask", "noise", "masked_latents", etc.
#
# Retaining that capability in this wrapper would require either many different methods
# to save latents, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all latents.
#
# This has a very minor impact as we don't use them after a session completes.
# Previously, invocations chose the name for their latents. This is a bit risky, so we
# will generate a name for them instead. We use a uuid to ensure the name is unique.
#
# Because the name of the latents file will includes the session and invocation IDs,
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}"
services.latents.save( services.latents.save(
name=name, name=name,
@ -231,12 +249,21 @@ class ConditioningInterface:
services: InvocationServices, services: InvocationServices,
context_data: InvocationContextData, context_data: InvocationContextData,
) -> None: ) -> None:
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
# service, but it is typed to work with Tensors only. We have to fudge the types here.
def save(conditioning_data: ConditioningFieldData) -> str: def save(conditioning_data: ConditioningFieldData) -> str:
""" """
Saves a conditioning data object, returning its name. Saves a conditioning data object, returning its name.
:param conditioning_data: The conditioning data to save. :param conditioning_data: The conditioning data to save.
""" """
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
#
# See comment for `LatentsInterface.save` for more info about this method (it's very
# similar).
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
services.latents.save( services.latents.save(
name=name, name=name,
@ -250,8 +277,7 @@ class ConditioningInterface:
:param conditioning_name: The name of the conditioning data to get. :param conditioning_name: The name of the conditioning data to get.
""" """
# TODO(sm): We are (ab)using the latents storage service as a general pickle storage
# service, but it is typed as returning tensors, so we need to ignore the type here.
return services.latents.get(conditioning_name) # type: ignore [return-value] return services.latents.get(conditioning_name) # type: ignore [return-value]
self.save = save self.save = save
@ -281,6 +307,17 @@ class ModelsInterface:
:param model_type: The type of the model to get. :param model_type: The type of the model to get.
:param submodel: The submodel of the model to get. :param submodel: The submodel of the model to get.
""" """
# During this call, the model manager emits events with model loading status. The model
# manager itself has access to the events services, but does not have access to the
# required metadata for the events.
#
# For example, it needs access to the node's ID so that the events can be associated
# with the execution of a specific node.
#
# While this is available within the node, it's tedious to need to pass it in on every
# call. We can avoid that by wrapping the method here.
return services.model_manager.get_model( return services.model_manager.get_model(
model_name, base_model, model_type, submodel, context_data=context_data model_name, base_model, model_type, submodel, context_data=context_data
) )
@ -306,8 +343,11 @@ class ConfigInterface:
""" """
Gets the app's config. Gets the app's config.
""" """
# The config can be changed at runtime. We don't want nodes doing this, so we make a
# frozen copy.. # The config can be changed at runtime.
#
# We don't want nodes doing this, so we make a frozen copy.
config = services.configuration.get_config() config = services.configuration.get_config()
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
return frozen_config return frozen_config
@ -330,6 +370,12 @@ class UtilInterface:
:param intermediate_state: The intermediate state of the diffusion pipeline. :param intermediate_state: The intermediate state of the diffusion pipeline.
:param base_model: The base model for the current denoising step. :param base_model: The base model for the current denoising step.
""" """
# The step callback needs access to the events and the invocation queue services, but this
# represents a dangerous level of access.
#
# We wrap the step callback so that nodes do not have direct access to these services.
stable_diffusion_step_callback( stable_diffusion_step_callback(
context_data=context_data, context_data=context_data,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
@ -343,36 +389,36 @@ class UtilInterface:
class InvocationContext: class InvocationContext:
""" """
The invocation context provides access to various services and data about the current invocation. The `InvocationContext` provides access to various services and data for the current invocation.
""" """
def __init__( def __init__(
self, self,
images: ImagesInterface, images: ImagesInterface,
latents: LatentsInterface, latents: LatentsInterface,
models: ModelsInterface,
config: ConfigInterface,
logger: LoggerInterface,
data: InvocationContextData,
util: UtilInterface,
conditioning: ConditioningInterface, conditioning: ConditioningInterface,
models: ModelsInterface,
logger: LoggerInterface,
config: ConfigInterface,
util: UtilInterface,
data: InvocationContextData,
) -> None: ) -> None:
self.images = images self.images = images
"Provides methods to save, get and update images and their metadata." """Provides methods to save, get and update images and their metadata."""
self.logger = logger
"Provides access to the app logger."
self.latents = latents self.latents = latents
"Provides methods to save and get latents tensors, including image, noise, masks, and masked images." """Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
self.conditioning = conditioning self.conditioning = conditioning
"Provides methods to save and get conditioning data." """Provides methods to save and get conditioning data."""
self.models = models self.models = models
"Provides methods to check if a model exists, get a model, and get a model's info." """Provides methods to check if a model exists, get a model, and get a model's info."""
self.logger = logger
"""Provides access to the app logger."""
self.config = config self.config = config
"Provides access to the app's config." """Provides access to the app's config."""
self.data = data
"Provides data about the current queue item and invocation."
self.util = util self.util = util
"Provides utility methods." """Provides utility methods."""
self.data = data
"""Provides data about the current queue item and invocation."""
def build_invocation_context( def build_invocation_context(
@ -380,8 +426,7 @@ def build_invocation_context(
context_data: InvocationContextData, context_data: InvocationContextData,
) -> InvocationContext: ) -> InvocationContext:
""" """
Builds the invocation context. This is a wrapper around the invocation services that provides Builds the invocation context for a specific invocation execution.
a more convenient (and less dangerous) interface for nodes to use.
:param invocation_services: The invocation services to wrap. :param invocation_services: The invocation services to wrap.
:param invocation_context_data: The invocation context data. :param invocation_context_data: The invocation context data.