mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): tidy invocation_context.py
, improve comments
This commit is contained in:
parent
8dc1207790
commit
c58951dfcc
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
@ -37,6 +36,9 @@ Wrapping these services provides a simpler and safer interface for nodes to use.
|
||||
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
|
||||
with each other.
|
||||
|
||||
Many of the wrappers have the same signature as the methods they wrap. This allows us to write
|
||||
user-facing docstrings and not need to go and update the internal services to match.
|
||||
|
||||
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
|
||||
"""
|
||||
|
||||
@ -44,12 +46,19 @@ Note: The docstrings are in weird places, but that's where they must be to get I
|
||||
@dataclass(frozen=True)
|
||||
class InvocationContextData:
|
||||
invocation: "BaseInvocation"
|
||||
"""The invocation that is being executed."""
|
||||
session_id: str
|
||||
"""The session that is being executed."""
|
||||
queue_id: str
|
||||
"""The queue in which the session is being executed."""
|
||||
source_node_id: str
|
||||
"""The ID of the node from which the currently executing invocation was prepared."""
|
||||
queue_item_id: int
|
||||
"""The ID of the queue item that is being executed."""
|
||||
batch_id: str
|
||||
"""The ID of the batch that is being executed."""
|
||||
workflow: Optional[WorkflowWithoutID] = None
|
||||
"""The workflow associated with this queue item, if any."""
|
||||
|
||||
|
||||
class LoggerInterface:
|
||||
@ -103,14 +112,15 @@ class ImagesInterface:
|
||||
"""
|
||||
Saves an image, returning its DTO.
|
||||
|
||||
If the current queue item has a workflow, it is automatically saved with the image.
|
||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||
|
||||
:param image: The image to save, as a PIL image.
|
||||
:param board_id: The board ID to add the image to, if it should be added.
|
||||
:param image_category: The category of the image. Only the GENERAL category is added to the gallery.
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \
|
||||
from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \
|
||||
override or provide metadata manually.
|
||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||
to the gallery.
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||
**Use this only if you want to override or provide metadata manually!**
|
||||
"""
|
||||
|
||||
# If the invocation inherits metadata, use that. Else, use the metadata passed in.
|
||||
@ -186,14 +196,6 @@ class ImagesInterface:
|
||||
self.update = update
|
||||
|
||||
|
||||
class LatentsKind(str, Enum):
|
||||
IMAGE = "image"
|
||||
NOISE = "noise"
|
||||
MASK = "mask"
|
||||
MASKED_IMAGE = "masked_image"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class LatentsInterface:
|
||||
def __init__(
|
||||
self,
|
||||
@ -206,6 +208,22 @@ class LatentsInterface:
|
||||
|
||||
:param tensor: The latents tensor to save.
|
||||
"""
|
||||
|
||||
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
||||
# "mask", "noise", "masked_latents", etc.
|
||||
#
|
||||
# Retaining that capability in this wrapper would require either many different methods
|
||||
# to save latents, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all latents.
|
||||
#
|
||||
# This has a very minor impact as we don't use them after a session completes.
|
||||
|
||||
# Previously, invocations chose the name for their latents. This is a bit risky, so we
|
||||
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
||||
#
|
||||
# Because the name of the latents file will includes the session and invocation IDs,
|
||||
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
||||
|
||||
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
services.latents.save(
|
||||
name=name,
|
||||
@ -231,12 +249,21 @@ class ConditioningInterface:
|
||||
services: InvocationServices,
|
||||
context_data: InvocationContextData,
|
||||
) -> None:
|
||||
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
|
||||
# service, but it is typed to work with Tensors only. We have to fudge the types here.
|
||||
|
||||
def save(conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
|
||||
:param conditioning_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
|
||||
#
|
||||
# See comment for `LatentsInterface.save` for more info about this method (it's very
|
||||
# similar).
|
||||
|
||||
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
|
||||
services.latents.save(
|
||||
name=name,
|
||||
@ -250,9 +277,8 @@ class ConditioningInterface:
|
||||
|
||||
:param conditioning_name: The name of the conditioning data to get.
|
||||
"""
|
||||
# TODO(sm): We are (ab)using the latents storage service as a general pickle storage
|
||||
# service, but it is typed as returning tensors, so we need to ignore the type here.
|
||||
return services.latents.get(conditioning_name) # type: ignore [return-value]
|
||||
|
||||
return services.latents.get(conditioning_name) # type: ignore [return-value]
|
||||
|
||||
self.save = save
|
||||
self.get = get
|
||||
@ -281,6 +307,17 @@ class ModelsInterface:
|
||||
:param model_type: The type of the model to get.
|
||||
:param submodel: The submodel of the model to get.
|
||||
"""
|
||||
|
||||
# During this call, the model manager emits events with model loading status. The model
|
||||
# manager itself has access to the events services, but does not have access to the
|
||||
# required metadata for the events.
|
||||
#
|
||||
# For example, it needs access to the node's ID so that the events can be associated
|
||||
# with the execution of a specific node.
|
||||
#
|
||||
# While this is available within the node, it's tedious to need to pass it in on every
|
||||
# call. We can avoid that by wrapping the method here.
|
||||
|
||||
return services.model_manager.get_model(
|
||||
model_name, base_model, model_type, submodel, context_data=context_data
|
||||
)
|
||||
@ -306,8 +343,11 @@ class ConfigInterface:
|
||||
"""
|
||||
Gets the app's config.
|
||||
"""
|
||||
# The config can be changed at runtime. We don't want nodes doing this, so we make a
|
||||
# frozen copy..
|
||||
|
||||
# The config can be changed at runtime.
|
||||
#
|
||||
# We don't want nodes doing this, so we make a frozen copy.
|
||||
|
||||
config = services.configuration.get_config()
|
||||
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
|
||||
return frozen_config
|
||||
@ -330,6 +370,12 @@ class UtilInterface:
|
||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
:param base_model: The base model for the current denoising step.
|
||||
"""
|
||||
|
||||
# The step callback needs access to the events and the invocation queue services, but this
|
||||
# represents a dangerous level of access.
|
||||
#
|
||||
# We wrap the step callback so that nodes do not have direct access to these services.
|
||||
|
||||
stable_diffusion_step_callback(
|
||||
context_data=context_data,
|
||||
intermediate_state=intermediate_state,
|
||||
@ -343,36 +389,36 @@ class UtilInterface:
|
||||
|
||||
class InvocationContext:
|
||||
"""
|
||||
The invocation context provides access to various services and data about the current invocation.
|
||||
The `InvocationContext` provides access to various services and data for the current invocation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
images: ImagesInterface,
|
||||
latents: LatentsInterface,
|
||||
models: ModelsInterface,
|
||||
config: ConfigInterface,
|
||||
logger: LoggerInterface,
|
||||
data: InvocationContextData,
|
||||
util: UtilInterface,
|
||||
conditioning: ConditioningInterface,
|
||||
models: ModelsInterface,
|
||||
logger: LoggerInterface,
|
||||
config: ConfigInterface,
|
||||
util: UtilInterface,
|
||||
data: InvocationContextData,
|
||||
) -> None:
|
||||
self.images = images
|
||||
"Provides methods to save, get and update images and their metadata."
|
||||
self.logger = logger
|
||||
"Provides access to the app logger."
|
||||
"""Provides methods to save, get and update images and their metadata."""
|
||||
self.latents = latents
|
||||
"Provides methods to save and get latents tensors, including image, noise, masks, and masked images."
|
||||
"""Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
|
||||
self.conditioning = conditioning
|
||||
"Provides methods to save and get conditioning data."
|
||||
"""Provides methods to save and get conditioning data."""
|
||||
self.models = models
|
||||
"Provides methods to check if a model exists, get a model, and get a model's info."
|
||||
"""Provides methods to check if a model exists, get a model, and get a model's info."""
|
||||
self.logger = logger
|
||||
"""Provides access to the app logger."""
|
||||
self.config = config
|
||||
"Provides access to the app's config."
|
||||
self.data = data
|
||||
"Provides data about the current queue item and invocation."
|
||||
"""Provides access to the app's config."""
|
||||
self.util = util
|
||||
"Provides utility methods."
|
||||
"""Provides utility methods."""
|
||||
self.data = data
|
||||
"""Provides data about the current queue item and invocation."""
|
||||
|
||||
|
||||
def build_invocation_context(
|
||||
@ -380,8 +426,7 @@ def build_invocation_context(
|
||||
context_data: InvocationContextData,
|
||||
) -> InvocationContext:
|
||||
"""
|
||||
Builds the invocation context. This is a wrapper around the invocation services that provides
|
||||
a more convenient (and less dangerous) interface for nodes to use.
|
||||
Builds the invocation context for a specific invocation execution.
|
||||
|
||||
:param invocation_services: The invocation services to wrap.
|
||||
:param invocation_context_data: The invocation context data.
|
||||
|
Loading…
Reference in New Issue
Block a user