mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): restricts invocation context power
Creates a low-power `InvocationContext` with simplified methods and data. See `invocation_context.py` for detailed comments.
This commit is contained in:
parent
992b02aa65
commit
3d98446d5d
408
invokeai/app/services/shared/invocation_context.py
Normal file
408
invokeai/app/services/shared/invocation_context.py
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||||
|
from invokeai.app.invocations.fields import MetadataField, WithMetadata
|
||||||
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||||
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||||
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
|
"""
|
||||||
|
The InvocationContext provides access to various services and data about the current invocation.
|
||||||
|
|
||||||
|
We do not provide the invocation services directly, as their methods are both dangerous and
|
||||||
|
inconvenient to use.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
- The `images` service allows nodes to delete or unsafely modify existing images.
|
||||||
|
- The `configuration` service allows nodes to change the app's config at runtime.
|
||||||
|
- The `events` service allows nodes to emit arbitrary events.
|
||||||
|
|
||||||
|
Wrapping these services provides a simpler and safer interface for nodes to use.
|
||||||
|
|
||||||
|
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
|
||||||
|
with each other.
|
||||||
|
|
||||||
|
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InvocationContextData:
|
||||||
|
invocation: "BaseInvocation"
|
||||||
|
session_id: str
|
||||||
|
queue_id: str
|
||||||
|
source_node_id: str
|
||||||
|
queue_item_id: int
|
||||||
|
batch_id: str
|
||||||
|
workflow: Optional[WorkflowWithoutID] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerInterface:
|
||||||
|
def __init__(self, services: InvocationServices) -> None:
|
||||||
|
def debug(message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs a debug message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
services.logger.debug(message)
|
||||||
|
|
||||||
|
def info(message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs an info message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
services.logger.info(message)
|
||||||
|
|
||||||
|
def warning(message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs a warning message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
services.logger.warning(message)
|
||||||
|
|
||||||
|
def error(message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs an error message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
services.logger.error(message)
|
||||||
|
|
||||||
|
self.debug = debug
|
||||||
|
self.info = info
|
||||||
|
self.warning = warning
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
|
||||||
|
class ImagesInterface:
|
||||||
|
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||||
|
def save(
|
||||||
|
image: Image,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||||
|
metadata: Optional[MetadataField] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""
|
||||||
|
Saves an image, returning its DTO.
|
||||||
|
|
||||||
|
If the current queue item has a workflow, it is automatically saved with the image.
|
||||||
|
|
||||||
|
:param image: The image to save, as a PIL image.
|
||||||
|
:param board_id: The board ID to add the image to, if it should be added.
|
||||||
|
:param image_category: The category of the image. Only the GENERAL category is added to the gallery.
|
||||||
|
:param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \
|
||||||
|
from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \
|
||||||
|
override or provide metadata manually.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If the invocation inherits metadata, use that. Else, use the metadata passed in.
|
||||||
|
metadata_ = (
|
||||||
|
context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
return services.images.create(
|
||||||
|
image=image,
|
||||||
|
is_intermediate=context_data.invocation.is_intermediate,
|
||||||
|
image_category=image_category,
|
||||||
|
board_id=board_id,
|
||||||
|
metadata=metadata_,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
workflow=context_data.workflow,
|
||||||
|
session_id=context_data.session_id,
|
||||||
|
node_id=context_data.invocation.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_pil(image_name: str) -> Image:
|
||||||
|
"""
|
||||||
|
Gets an image as a PIL Image object.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get.
|
||||||
|
"""
|
||||||
|
return services.images.get_pil_image(image_name)
|
||||||
|
|
||||||
|
def get_metadata(image_name: str) -> Optional[MetadataField]:
|
||||||
|
"""
|
||||||
|
Gets an image's metadata, if it has any.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get the metadata for.
|
||||||
|
"""
|
||||||
|
return services.images.get_metadata(image_name)
|
||||||
|
|
||||||
|
def get_dto(image_name: str) -> ImageDTO:
|
||||||
|
"""
|
||||||
|
Gets an image as an ImageDTO object.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get.
|
||||||
|
"""
|
||||||
|
return services.images.get_dto(image_name)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
image_name: str,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
is_intermediate: Optional[bool] = False,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""
|
||||||
|
Updates an image, returning its updated DTO.
|
||||||
|
|
||||||
|
It is not suggested to update images saved by earlier nodes, as this can cause confusion for users.
|
||||||
|
|
||||||
|
If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to
|
||||||
|
get the updated image.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to update.
|
||||||
|
:param board_id: The board ID to add the image to, if it should be added.
|
||||||
|
:param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
|
||||||
|
"""
|
||||||
|
if is_intermediate is not None:
|
||||||
|
services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
|
||||||
|
if board_id is None:
|
||||||
|
services.board_images.remove_image_from_board(image_name)
|
||||||
|
else:
|
||||||
|
services.board_images.add_image_to_board(image_name, board_id)
|
||||||
|
return services.images.get_dto(image_name)
|
||||||
|
|
||||||
|
self.save = save
|
||||||
|
self.get_pil = get_pil
|
||||||
|
self.get_metadata = get_metadata
|
||||||
|
self.get_dto = get_dto
|
||||||
|
self.update = update
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsKind(str, Enum):
|
||||||
|
IMAGE = "image"
|
||||||
|
NOISE = "noise"
|
||||||
|
MASK = "mask"
|
||||||
|
MASKED_IMAGE = "masked_image"
|
||||||
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsInterface:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
services: InvocationServices,
|
||||||
|
context_data: InvocationContextData,
|
||||||
|
) -> None:
|
||||||
|
def save(tensor: Tensor) -> str:
|
||||||
|
"""
|
||||||
|
Saves a latents tensor, returning its name.
|
||||||
|
|
||||||
|
:param tensor: The latents tensor to save.
|
||||||
|
"""
|
||||||
|
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}"
|
||||||
|
services.latents.save(
|
||||||
|
name=name,
|
||||||
|
data=tensor,
|
||||||
|
)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def get(latents_name: str) -> Tensor:
|
||||||
|
"""
|
||||||
|
Gets a latents tensor by name.
|
||||||
|
|
||||||
|
:param latents_name: The name of the latents tensor to get.
|
||||||
|
"""
|
||||||
|
return services.latents.get(latents_name)
|
||||||
|
|
||||||
|
self.save = save
|
||||||
|
self.get = get
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningInterface:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
services: InvocationServices,
|
||||||
|
context_data: InvocationContextData,
|
||||||
|
) -> None:
|
||||||
|
def save(conditioning_data: ConditioningFieldData) -> str:
|
||||||
|
"""
|
||||||
|
Saves a conditioning data object, returning its name.
|
||||||
|
|
||||||
|
:param conditioning_data: The conditioning data to save.
|
||||||
|
"""
|
||||||
|
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
|
||||||
|
services.latents.save(
|
||||||
|
name=name,
|
||||||
|
data=conditioning_data, # type: ignore [arg-type]
|
||||||
|
)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def get(conditioning_name: str) -> Tensor:
|
||||||
|
"""
|
||||||
|
Gets conditioning data by name.
|
||||||
|
|
||||||
|
:param conditioning_name: The name of the conditioning data to get.
|
||||||
|
"""
|
||||||
|
return services.latents.get(conditioning_name)
|
||||||
|
|
||||||
|
self.save = save
|
||||||
|
self.get = get
|
||||||
|
|
||||||
|
|
||||||
|
class ModelsInterface:
|
||||||
|
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||||
|
def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if a model exists.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to check.
|
||||||
|
:param base_model: The base model of the model to check.
|
||||||
|
:param model_type: The type of the model to check.
|
||||||
|
"""
|
||||||
|
return services.model_manager.model_exists(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
def load(
|
||||||
|
model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||||
|
) -> ModelInfo:
|
||||||
|
"""
|
||||||
|
Loads a model, returning its `ModelInfo` object.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to get.
|
||||||
|
:param base_model: The base model of the model to get.
|
||||||
|
:param model_type: The type of the model to get.
|
||||||
|
:param submodel: The submodel of the model to get.
|
||||||
|
"""
|
||||||
|
return services.model_manager.get_model(
|
||||||
|
model_name, base_model, model_type, submodel, context_data=context_data
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Gets a model's info, an dict-like object.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to get.
|
||||||
|
:param base_model: The base model of the model to get.
|
||||||
|
:param model_type: The type of the model to get.
|
||||||
|
"""
|
||||||
|
return services.model_manager.model_info(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
self.exists = exists
|
||||||
|
self.load = load
|
||||||
|
self.get_info = get_info
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigInterface:
|
||||||
|
def __init__(self, services: InvocationServices) -> None:
|
||||||
|
def get() -> InvokeAIAppConfig:
|
||||||
|
"""
|
||||||
|
Gets the app's config.
|
||||||
|
"""
|
||||||
|
# The config can be changed at runtime. We don't want nodes doing this, so we make a
|
||||||
|
# frozen copy..
|
||||||
|
config = services.configuration.get_config()
|
||||||
|
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
|
||||||
|
return frozen_config
|
||||||
|
|
||||||
|
self.get = get
|
||||||
|
|
||||||
|
|
||||||
|
class UtilInterface:
|
||||||
|
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||||
|
def sd_step_callback(
|
||||||
|
intermediate_state: PipelineIntermediateState,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
The step callback emits a progress event with the current step, the total number of
|
||||||
|
steps, a preview image, and some other internal metadata.
|
||||||
|
|
||||||
|
This should be called after each step of the diffusion process.
|
||||||
|
|
||||||
|
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||||
|
:param base_model: The base model for the current denoising step.
|
||||||
|
"""
|
||||||
|
stable_diffusion_step_callback(
|
||||||
|
context_data=context_data,
|
||||||
|
intermediate_state=intermediate_state,
|
||||||
|
base_model=base_model,
|
||||||
|
invocation_queue=services.queue,
|
||||||
|
events=services.events,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sd_step_callback = sd_step_callback
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationContext:
|
||||||
|
"""
|
||||||
|
The invocation context provides access to various services and data about the current invocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
images: ImagesInterface,
|
||||||
|
latents: LatentsInterface,
|
||||||
|
models: ModelsInterface,
|
||||||
|
config: ConfigInterface,
|
||||||
|
logger: LoggerInterface,
|
||||||
|
data: InvocationContextData,
|
||||||
|
util: UtilInterface,
|
||||||
|
conditioning: ConditioningInterface,
|
||||||
|
) -> None:
|
||||||
|
self.images = images
|
||||||
|
"Provides methods to save, get and update images and their metadata."
|
||||||
|
self.logger = logger
|
||||||
|
"Provides access to the app logger."
|
||||||
|
self.latents = latents
|
||||||
|
"Provides methods to save and get latents tensors, including image, noise, masks, and masked images."
|
||||||
|
self.conditioning = conditioning
|
||||||
|
"Provides methods to save and get conditioning data."
|
||||||
|
self.models = models
|
||||||
|
"Provides methods to check if a model exists, get a model, and get a model's info."
|
||||||
|
self.config = config
|
||||||
|
"Provides access to the app's config."
|
||||||
|
self.data = data
|
||||||
|
"Provides data about the current queue item and invocation."
|
||||||
|
self.util = util
|
||||||
|
"Provides utility methods."
|
||||||
|
|
||||||
|
|
||||||
|
def build_invocation_context(
|
||||||
|
services: InvocationServices,
|
||||||
|
context_data: InvocationContextData,
|
||||||
|
) -> InvocationContext:
|
||||||
|
"""
|
||||||
|
Builds the invocation context. This is a wrapper around the invocation services that provides
|
||||||
|
a more convenient (and less dangerous) interface for nodes to use.
|
||||||
|
|
||||||
|
:param invocation_services: The invocation services to wrap.
|
||||||
|
:param invocation_context_data: The invocation context data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger = LoggerInterface(services=services)
|
||||||
|
images = ImagesInterface(services=services, context_data=context_data)
|
||||||
|
latents = LatentsInterface(services=services, context_data=context_data)
|
||||||
|
models = ModelsInterface(services=services, context_data=context_data)
|
||||||
|
config = ConfigInterface(services=services)
|
||||||
|
util = UtilInterface(services=services, context_data=context_data)
|
||||||
|
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||||
|
|
||||||
|
ctx = InvocationContext(
|
||||||
|
images=images,
|
||||||
|
logger=logger,
|
||||||
|
config=config,
|
||||||
|
latents=latents,
|
||||||
|
models=models,
|
||||||
|
data=context_data,
|
||||||
|
util=util,
|
||||||
|
conditioning=conditioning,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ctx
|
@ -1,12 +1,25 @@
|
|||||||
|
from typing import Protocol
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
||||||
|
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||||
|
|
||||||
from ...backend.model_management.models import BaseModelType
|
from ...backend.model_management.models import BaseModelType
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
|
||||||
|
|
||||||
|
class StepCallback(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
intermediate_state: PipelineIntermediateState,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||||
@ -25,13 +38,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
|
|||||||
|
|
||||||
|
|
||||||
def stable_diffusion_step_callback(
|
def stable_diffusion_step_callback(
|
||||||
context: InvocationContext,
|
context_data: InvocationContextData,
|
||||||
intermediate_state: PipelineIntermediateState,
|
intermediate_state: PipelineIntermediateState,
|
||||||
node: dict,
|
|
||||||
source_node_id: str,
|
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
):
|
invocation_queue: InvocationQueueABC,
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
events: EventServiceBase,
|
||||||
|
) -> None:
|
||||||
|
if invocation_queue.is_canceled(context_data.session_id):
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
# Some schedulers report not only the noisy latents at the current timestep,
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
@ -108,13 +121,13 @@ def stable_diffusion_step_callback(
|
|||||||
|
|
||||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
events.emit_generator_progress(
|
||||||
queue_id=context.queue_id,
|
queue_id=context_data.queue_id,
|
||||||
queue_item_id=context.queue_item_id,
|
queue_item_id=context_data.queue_item_id,
|
||||||
queue_batch_id=context.queue_batch_id,
|
queue_batch_id=context_data.batch_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context_data.session_id,
|
||||||
node=node,
|
node_id=context_data.invocation.id,
|
||||||
source_node_id=source_node_id,
|
source_node_id=context_data.source_node_id,
|
||||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||||
step=intermediate_state.step,
|
step=intermediate_state.step,
|
||||||
order=intermediate_state.order,
|
order=intermediate_state.order,
|
||||||
|
Loading…
Reference in New Issue
Block a user