feat(nodes): restricts invocation context power

Creates a low-power `InvocationContext` with simplified methods and data.

See `invocation_context.py` for detailed comments.
psychedelicious 2024-01-13 18:02:58 +11:00
parent 992b02aa65
commit 3d98446d5d
2 changed files with 434 additions and 13 deletions

invokeai/app/services/shared/invocation_context.py (new file)

@@ -0,0 +1,408 @@
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Optional

from PIL.Image import Image
from pydantic import ConfigDict
from torch import Tensor

from invokeai.app.invocations.compel import ConditioningFieldData
from invokeai.app.invocations.fields import MetadataField, WithMetadata
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

if TYPE_CHECKING:
    from invokeai.app.invocations.baseinvocation import BaseInvocation
"""
The InvocationContext provides access to various services and data about the current invocation.
We do not provide the invocation services directly, as their methods are both dangerous and
inconvenient to use.
For example:
- The `images` service allows nodes to delete or unsafely modify existing images.
- The `configuration` service allows nodes to change the app's config at runtime.
- The `events` service allows nodes to emit arbitrary events.
Wrapping these services provides a simpler and safer interface for nodes to use.
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
with each other.
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
"""
@dataclass(frozen=True)
class InvocationContextData:
    invocation: "BaseInvocation"
    session_id: str
    queue_id: str
    source_node_id: str
    queue_item_id: int
    batch_id: str
    workflow: Optional[WorkflowWithoutID] = None
class LoggerInterface:
    def __init__(self, services: InvocationServices) -> None:
        def debug(message: str) -> None:
            """
            Logs a debug message.

            :param message: The message to log.
            """
            services.logger.debug(message)

        def info(message: str) -> None:
            """
            Logs an info message.

            :param message: The message to log.
            """
            services.logger.info(message)

        def warning(message: str) -> None:
            """
            Logs a warning message.

            :param message: The message to log.
            """
            services.logger.warning(message)

        def error(message: str) -> None:
            """
            Logs an error message.

            :param message: The message to log.
            """
            services.logger.error(message)

        self.debug = debug
        self.info = info
        self.warning = warning
        self.error = error
class ImagesInterface:
    def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
        def save(
            image: Image,
            board_id: Optional[str] = None,
            image_category: ImageCategory = ImageCategory.GENERAL,
            metadata: Optional[MetadataField] = None,
        ) -> ImageDTO:
            """
            Saves an image, returning its DTO.

            If the current queue item has a workflow, it is automatically saved with the image.

            :param image: The image to save, as a PIL image.
            :param board_id: The board ID to add the image to, if it should be added.
            :param image_category: The category of the image. Only the GENERAL category is added to the gallery.
            :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \
                from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \
                override or provide metadata manually.
            """
            # If the invocation inherits metadata, use that. Else, use the metadata passed in.
            metadata_ = (
                context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata
            )

            return services.images.create(
                image=image,
                is_intermediate=context_data.invocation.is_intermediate,
                image_category=image_category,
                board_id=board_id,
                metadata=metadata_,
                image_origin=ResourceOrigin.INTERNAL,
                workflow=context_data.workflow,
                session_id=context_data.session_id,
                node_id=context_data.invocation.id,
            )

        def get_pil(image_name: str) -> Image:
            """
            Gets an image as a PIL Image object.

            :param image_name: The name of the image to get.
            """
            return services.images.get_pil_image(image_name)

        def get_metadata(image_name: str) -> Optional[MetadataField]:
            """
            Gets an image's metadata, if it has any.

            :param image_name: The name of the image to get the metadata for.
            """
            return services.images.get_metadata(image_name)

        def get_dto(image_name: str) -> ImageDTO:
            """
            Gets an image as an ImageDTO object.

            :param image_name: The name of the image to get.
            """
            return services.images.get_dto(image_name)

        def update(
            image_name: str,
            board_id: Optional[str] = None,
            is_intermediate: Optional[bool] = False,
        ) -> ImageDTO:
            """
            Updates an image, returning its updated DTO.

            It is not suggested to update images saved by earlier nodes, as this can cause confusion for users.

            If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to
            get the updated image.

            :param image_name: The name of the image to update.
            :param board_id: The board ID to add the image to, if it should be added.
            :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
            """
            if is_intermediate is not None:
                services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))

            if board_id is None:
                services.board_images.remove_image_from_board(image_name)
            else:
                services.board_images.add_image_to_board(image_name, board_id)

            return services.images.get_dto(image_name)

        self.save = save
        self.get_pil = get_pil
        self.get_metadata = get_metadata
        self.get_dto = get_dto
        self.update = update
class LatentsKind(str, Enum):
    IMAGE = "image"
    NOISE = "noise"
    MASK = "mask"
    MASKED_IMAGE = "masked_image"
    OTHER = "other"
class LatentsInterface:
    def __init__(
        self,
        services: InvocationServices,
        context_data: InvocationContextData,
    ) -> None:
        def save(tensor: Tensor) -> str:
            """
            Saves a latents tensor, returning its name.

            :param tensor: The latents tensor to save.
            """
            name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}"
            services.latents.save(
                name=name,
                data=tensor,
            )
            return name

        def get(latents_name: str) -> Tensor:
            """
            Gets a latents tensor by name.

            :param latents_name: The name of the latents tensor to get.
            """
            return services.latents.get(latents_name)

        self.save = save
        self.get = get
class ConditioningInterface:
    def __init__(
        self,
        services: InvocationServices,
        context_data: InvocationContextData,
    ) -> None:
        def save(conditioning_data: ConditioningFieldData) -> str:
            """
            Saves a conditioning data object, returning its name.

            :param conditioning_data: The conditioning data to save.
            """
            name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
            services.latents.save(
                name=name,
                data=conditioning_data,  # type: ignore [arg-type]
            )
            return name

        def get(conditioning_name: str) -> Tensor:
            """
            Gets conditioning data by name.

            :param conditioning_name: The name of the conditioning data to get.
            """
            return services.latents.get(conditioning_name)

        self.save = save
        self.get = get
class ModelsInterface:
    def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
        def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
            """
            Checks if a model exists.

            :param model_name: The name of the model to check.
            :param base_model: The base model of the model to check.
            :param model_type: The type of the model to check.
            """
            return services.model_manager.model_exists(model_name, base_model, model_type)

        def load(
            model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
        ) -> ModelInfo:
            """
            Loads a model, returning its `ModelInfo` object.

            :param model_name: The name of the model to get.
            :param base_model: The base model of the model to get.
            :param model_type: The type of the model to get.
            :param submodel: The submodel of the model to get.
            """
            return services.model_manager.get_model(
                model_name, base_model, model_type, submodel, context_data=context_data
            )

        def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
            """
            Gets a model's info, a dict-like object.

            :param model_name: The name of the model to get.
            :param base_model: The base model of the model to get.
            :param model_type: The type of the model to get.
            """
            return services.model_manager.model_info(model_name, base_model, model_type)

        self.exists = exists
        self.load = load
        self.get_info = get_info
class ConfigInterface:
    def __init__(self, services: InvocationServices) -> None:
        def get() -> InvokeAIAppConfig:
            """
            Gets the app's config.
            """
            # The config can be changed at runtime. We don't want nodes doing this, so we make a
            # frozen copy.
            config = services.configuration.get_config()
            frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
            return frozen_config

        self.get = get
class UtilInterface:
    def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
        def sd_step_callback(
            intermediate_state: PipelineIntermediateState,
            base_model: BaseModelType,
        ) -> None:
            """
            The step callback emits a progress event with the current step, the total number of
            steps, a preview image, and some other internal metadata.

            This should be called after each step of the diffusion process.

            :param intermediate_state: The intermediate state of the diffusion pipeline.
            :param base_model: The base model for the current denoising step.
            """
            stable_diffusion_step_callback(
                context_data=context_data,
                intermediate_state=intermediate_state,
                base_model=base_model,
                invocation_queue=services.queue,
                events=services.events,
            )

        self.sd_step_callback = sd_step_callback
class InvocationContext:
    """
    The invocation context provides access to various services and data about the current invocation.
    """

    def __init__(
        self,
        images: ImagesInterface,
        latents: LatentsInterface,
        models: ModelsInterface,
        config: ConfigInterface,
        logger: LoggerInterface,
        data: InvocationContextData,
        util: UtilInterface,
        conditioning: ConditioningInterface,
    ) -> None:
        self.images = images
        "Provides methods to save, get and update images and their metadata."
        self.logger = logger
        "Provides access to the app logger."
        self.latents = latents
        "Provides methods to save and get latents tensors, including image, noise, masks, and masked images."
        self.conditioning = conditioning
        "Provides methods to save and get conditioning data."
        self.models = models
        "Provides methods to check if a model exists, get a model, and get a model's info."
        self.config = config
        "Provides access to the app's config."
        self.data = data
        "Provides data about the current queue item and invocation."
        self.util = util
        "Provides utility methods."
def build_invocation_context(
    services: InvocationServices,
    context_data: InvocationContextData,
) -> InvocationContext:
    """
    Builds the invocation context. This is a wrapper around the invocation services that provides
    a more convenient (and less dangerous) interface for nodes to use.

    :param services: The invocation services to wrap.
    :param context_data: The invocation context data.
    """

    logger = LoggerInterface(services=services)
    images = ImagesInterface(services=services, context_data=context_data)
    latents = LatentsInterface(services=services, context_data=context_data)
    models = ModelsInterface(services=services, context_data=context_data)
    config = ConfigInterface(services=services)
    util = UtilInterface(services=services, context_data=context_data)
    conditioning = ConditioningInterface(services=services, context_data=context_data)

    ctx = InvocationContext(
        images=images,
        logger=logger,
        config=config,
        latents=latents,
        models=models,
        data=context_data,
        util=util,
        conditioning=conditioning,
    )

    return ctx
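
As a usage note: `build_invocation_context` is intended to be called once per node execution, so each node gets its own isolated context. A hedged sketch of what that call site could look like in the session processor follows; the `run_node` helper and the attribute names on `session` and `queue_item` are assumptions for illustration, not code from this commit.

# Hypothetical call site, not part of this commit: build a fresh context for
# every node that runs, then hand it to the node.
from invokeai.app.services.shared.invocation_context import (
    InvocationContextData,
    build_invocation_context,
)


def run_node(invocation, source_node_id, session, queue_item, services):
    """Hypothetical per-node wiring: assemble the context data, build the context, invoke."""
    context_data = InvocationContextData(
        invocation=invocation,
        session_id=session.id,              # attribute names on session/queue_item are assumed
        queue_id=queue_item.queue_id,
        source_node_id=source_node_id,
        queue_item_id=queue_item.item_id,
        batch_id=queue_item.batch_id,
        workflow=queue_item.workflow,
    )
    # A new, isolated context per invocation -- nodes never share wrapper state.
    context = build_invocation_context(services=services, context_data=context_data)
    return invocation.invoke(context)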

invokeai/app/util/step_callback.py

@@ -1,12 +1,25 @@
+from typing import Protocol
+
 import torch
 from PIL import Image
+from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
+from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
-from ..invocations.baseinvocation import InvocationContext
+
+
+class StepCallback(Protocol):
+    def __call__(
+        self,
+        intermediate_state: PipelineIntermediateState,
+        base_model: BaseModelType,
+    ) -> None:
+        ...


 def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
@@ -25,13 +38,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
 def stable_diffusion_step_callback(
-    context: InvocationContext,
+    context_data: InvocationContextData,
     intermediate_state: PipelineIntermediateState,
-    node: dict,
-    source_node_id: str,
     base_model: BaseModelType,
-):
-    if context.services.queue.is_canceled(context.graph_execution_state_id):
+    invocation_queue: InvocationQueueABC,
+    events: EventServiceBase,
+) -> None:
+    if invocation_queue.is_canceled(context_data.session_id):
         raise CanceledException

     # Some schedulers report not only the noisy latents at the current timestep,
@@ -108,13 +121,13 @@ def stable_diffusion_step_callback(
     dataURL = image_to_dataURL(image, image_format="JPEG")

-    context.services.events.emit_generator_progress(
-        queue_id=context.queue_id,
-        queue_item_id=context.queue_item_id,
-        queue_batch_id=context.queue_batch_id,
-        graph_execution_state_id=context.graph_execution_state_id,
-        node=node,
-        source_node_id=source_node_id,
+    events.emit_generator_progress(
+        queue_id=context_data.queue_id,
+        queue_item_id=context_data.queue_item_id,
+        queue_batch_id=context_data.batch_id,
+        graph_execution_state_id=context_data.session_id,
+        node_id=context_data.invocation.id,
+        source_node_id=context_data.source_node_id,
         progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
         step=intermediate_state.step,
         order=intermediate_state.order,
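
The `context.util.sd_step_callback` wrapper and the new `StepCallback` protocol in this diff are meant to be handed to the denoising pipeline. Below is a hypothetical sketch (the `make_step_callback` factory is invented for illustration) of how a denoise node might bind its context to a `StepCallback`-compatible function.

# Hypothetical wiring, not part of this commit: forward the pipeline's per-step
# state to the restricted context's step callback.
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState


def make_step_callback(context: InvocationContext):
    """Hypothetical factory: returns a StepCallback-compatible function bound to this context."""

    def on_step(intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
        # Emits a progress event with a preview image and raises CanceledException
        # if the queue item has been canceled (see stable_diffusion_step_callback).
        context.util.sd_step_callback(intermediate_state, base_model)

    return on_step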