diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b68e521c73..7961c011af 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Optional from PIL.Image import Image @@ -37,6 +36,9 @@ Wrapping these services provides a simpler and safer interface for nodes to use. When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere with each other. +Many of the wrappers have the same signature as the methods they wrap. This allows us to write +user-facing docstrings and not need to go and update the internal services to match. + Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. """ @@ -44,12 +46,19 @@ Note: The docstrings are in weird places, but that's where they must be to get I @dataclass(frozen=True) class InvocationContextData: invocation: "BaseInvocation" + """The invocation that is being executed.""" session_id: str + """The session that is being executed.""" queue_id: str + """The queue in which the session is being executed.""" source_node_id: str + """The ID of the node from which the currently executing invocation was prepared.""" queue_item_id: int + """The ID of the queue item that is being executed.""" batch_id: str + """The ID of the batch that is being executed.""" workflow: Optional[WorkflowWithoutID] = None + """The workflow associated with this queue item, if any.""" class LoggerInterface: @@ -103,14 +112,15 @@ class ImagesInterface: """ Saves an image, returning its DTO. - If the current queue item has a workflow, it is automatically saved with the image. + If the current queue item has a workflow or metadata, it is automatically saved with the image. :param image: The image to save, as a PIL image. :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ - from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ - override or provide metadata manually. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** """ # If the invocation inherits metadata, use that. Else, use the metadata passed in. @@ -186,14 +196,6 @@ class ImagesInterface: self.update = update -class LatentsKind(str, Enum): - IMAGE = "image" - NOISE = "noise" - MASK = "mask" - MASKED_IMAGE = "masked_image" - OTHER = "other" - - class LatentsInterface: def __init__( self, @@ -206,6 +208,22 @@ class LatentsInterface: :param tensor: The latents tensor to save. """ + + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. + + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" services.latents.save( name=name, @@ -231,12 +249,21 @@ class ConditioningInterface: services: InvocationServices, context_data: InvocationContextData, ) -> None: + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. :param conditioning_data: The conditioning data to save. """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" services.latents.save( name=name, @@ -250,9 +277,8 @@ class ConditioningInterface: :param conditioning_name: The name of the conditioning data to get. """ - # TODO(sm): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed as returning tensors, so we need to ignore the type here. - return services.latents.get(conditioning_name) # type: ignore [return-value] + + return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get @@ -281,6 +307,17 @@ class ModelsInterface: :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + return services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=context_data ) @@ -306,8 +343,11 @@ class ConfigInterface: """ Gets the app's config. """ - # The config can be changed at runtime. We don't want nodes doing this, so we make a - # frozen copy.. + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + config = services.configuration.get_config() frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config @@ -330,6 +370,12 @@ class UtilInterface: :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. """ + + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. + stable_diffusion_step_callback( context_data=context_data, intermediate_state=intermediate_state, @@ -343,36 +389,36 @@ class UtilInterface: class InvocationContext: """ - The invocation context provides access to various services and data about the current invocation. + The `InvocationContext` provides access to various services and data for the current invocation. """ def __init__( self, images: ImagesInterface, latents: LatentsInterface, - models: ModelsInterface, - config: ConfigInterface, - logger: LoggerInterface, - data: InvocationContextData, - util: UtilInterface, conditioning: ConditioningInterface, + models: ModelsInterface, + logger: LoggerInterface, + config: ConfigInterface, + util: UtilInterface, + data: InvocationContextData, ) -> None: self.images = images - "Provides methods to save, get and update images and their metadata." - self.logger = logger - "Provides access to the app logger." + """Provides methods to save, get and update images and their metadata.""" self.latents = latents - "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning - "Provides methods to save and get conditioning data." + """Provides methods to save and get conditioning data.""" self.models = models - "Provides methods to check if a model exists, get a model, and get a model's info." + """Provides methods to check if a model exists, get a model, and get a model's info.""" + self.logger = logger + """Provides access to the app logger.""" self.config = config - "Provides access to the app's config." - self.data = data - "Provides data about the current queue item and invocation." + """Provides access to the app's config.""" self.util = util - "Provides utility methods." + """Provides utility methods.""" + self.data = data + """Provides data about the current queue item and invocation.""" def build_invocation_context( @@ -380,8 +426,7 @@ def build_invocation_context( context_data: InvocationContextData, ) -> InvocationContext: """ - Builds the invocation context. This is a wrapper around the invocation services that provides - a more convenient (and less dangerous) interface for nodes to use. + Builds the invocation context for a specific invocation execution. :param invocation_services: The invocation services to wrap. :param invocation_context_data: The invocation context data.