First draft of area-based latent-space regional prompting.

Ryan Dick 2024-02-22 07:50:11 -08:00
parent 32f602ab2a
commit bc869cad62
5 changed files with 128 additions and 86 deletions

View File

@@ -9,46 +9,75 @@ from invokeai.app.invocations.baseinvocation import (
     WithMetadata,
     invocation,
 )
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
+from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, DenoisingArea, ImageField, ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+# @invocation(
+#     "add_conditioning_mask",
+#     title="Add Conditioning Mask",
+#     tags=["conditioning"],
+#     category="conditioning",
+#     version="1.0.0",
+# )
+# class AddConditioningMaskInvocation(BaseInvocation):
+#     """Add a mask to an existing conditioning tensor."""
+#     conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
+#     mask: ImageField = InputField(
+#         description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
+#         "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
+#     )
+#     mask_strength: float = InputField(
+#         description="The strength of the mask to apply to the conditioning tensor.", default=1.0
+#     )
+#     @staticmethod
+#     def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
+#         """Convert a PIL image to a uint8 mask tensor."""
+#         np_image = np.array(image)
+#         torch_image = torch.from_numpy(np_image[:, :, 0])
+#         mask = torch_image >= 128
+#         return mask.to(dtype=torch.uint8)
+#     def invoke(self, context: InvocationContext) -> ConditioningOutput:
+#         image = context.services.images.get_pil_image(self.mask.image_name)
+#         mask = self.convert_image_to_mask(image)
+#         mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
+#         context.services.latents.save(mask_name, mask)
+#         self.conditioning.mask_name = mask_name
+#         self.conditioning.mask_strength = self.mask_strength
+#         return ConditioningOutput(conditioning=self.conditioning)
 @invocation(
-    "add_conditioning_mask",
-    title="Add Conditioning Mask",
+    "add_conditioning_area",
+    title="Add Conditioning Area",
     tags=["conditioning"],
     category="conditioning",
     version="1.0.0",
 )
-class AddConditioningMaskInvocation(BaseInvocation):
-    """Add a mask to an existing conditioning tensor."""
+class AddConditioningAreaInvocation(BaseInvocation):
+    """Add a denoising area to an existing conditioning tensor."""
     conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
-    mask: ImageField = InputField(
-        description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
-        "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
-    )
+    # mask: ImageField = InputField(
+    #     description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
+    #     "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
+    # )
     mask_strength: float = InputField(
         description="The strength of the mask to apply to the conditioning tensor.", default=1.0
     )
+    height: int = InputField(description="The height of the area in latent space.")
+    width: int = InputField(description="The width of the area in latent space.")
+    y_top: int = InputField(description="The top of the area in latent space.")
+    x_left: int = InputField(description="The left of the area in latent space.")
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        image = context.services.images.get_pil_image(self.mask.image_name)
-        mask = self.convert_image_to_mask(image)
-        mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
-        context.services.latents.save(mask_name, mask)
-        self.conditioning.mask_name = mask_name
         self.conditioning.mask_strength = self.mask_strength
+        self.conditioning.denoising_area = DenoisingArea(height=self.height, width=self.width, top_y=self.y_top, left_x=self.x_left)
         return ConditioningOutput(conditioning=self.conditioning)
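
Note: the new node replaces the image-mask input with four integers describing a rectangle. Below is a minimal standalone sketch of what invoke() now attaches to the conditioning primitive; the classes are simplified stand-ins, not the real InvokeAI models. One caveat: the InputField descriptions say "in latent space", but the diffuser change later in this commit divides these values by 8, which suggests they are really image-space pixel coordinates.

from typing import Optional
from pydantic import BaseModel, Field

class DenoisingArea(BaseModel):  # simplified stand-in for the new model
    top_y: int = Field(default=0)
    left_x: int = Field(default=0)
    height: int = Field(default=512)
    width: int = Field(default=512)

class ConditioningField(BaseModel):  # simplified stand-in
    conditioning_name: str
    denoising_area: Optional[DenoisingArea] = None
    mask_strength: float = 1.0

# Roughly what AddConditioningAreaInvocation.invoke() boils down to: tag an
# existing conditioning reference with a rectangle and a blend weight.
cond = ConditioningField(conditioning_name="prompt_0")
cond.denoising_area = DenoisingArea(top_y=0, left_x=0, height=512, width=256)
cond.mask_strength = 1.0
print(cond.model_dump())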

View File

@@ -43,7 +43,7 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningData,
     IPAdapterConditioningInfo,
-    TextConditioningInfoWithMask,
+    TextConditioningInfoWithArea,
 )
 from ...backend.model_management.lora import ModelPatcher
@@ -339,17 +339,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if not isinstance(positive_conditioning_list, list):
             positive_conditioning_list = [positive_conditioning_list]
-        text_embeddings: list[TextConditioningInfoWithMask] = []
+        text_embeddings: list[TextConditioningInfoWithArea] = []
         for positive_conditioning in positive_conditioning_list:
             positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
-            mask_name = positive_conditioning.mask_name
-            mask = None
-            if mask_name is not None:
-                mask = context.services.latents.get(mask_name)
             text_embeddings.append(
-                TextConditioningInfoWithMask(
+                TextConditioningInfoWithArea(
                     text_conditioning_info=positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype),
-                    mask=mask,
+                    area=positive_conditioning.denoising_area,
                     mask_strength=positive_conditioning.mask_strength,
                 )
             )
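
Note: the gathering loop above now pairs each positive prompt's embedding with its denoising_area instead of loading a mask tensor from the latents service. A hedged sketch of the same data flow with the service stubbed out (load_tensor stands in for context.services.latents.get; the dataclass is a simplified mirror, not the real class):

from dataclasses import dataclass
from typing import Any

@dataclass
class TextConditioningInfoWithArea:  # simplified mirror of the renamed class
    text_conditioning_info: Any
    area: Any
    mask_strength: float

def gather_text_embeddings(positive_conditioning_list, load_tensor):
    # Accept a single ConditioningField or a list of them.
    if not isinstance(positive_conditioning_list, list):
        positive_conditioning_list = [positive_conditioning_list]
    return [
        TextConditioningInfoWithArea(
            text_conditioning_info=load_tensor(c.conditioning_name),
            area=c.denoising_area,
            mask_strength=c.mask_strength,
        )
        for c in positive_conditioning_list
    ]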

View File

@@ -423,16 +423,23 @@ class ColorInvocation(BaseInvocation):
 # region Conditioning
+class DenoisingArea(BaseModel):
+    top_y: int = Field(description="TODO")
+    left_x: int = Field(description="TODO")
+    height: int = Field(description="TODO")
+    width: int = Field(description="TODO")
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
     conditioning_name: str = Field(description="The name of conditioning tensor")
-    mask_name: Optional[str] = Field(
-        default=None,
-        description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
-        "regions should be set to 1.",
-    )
+    # mask_name: Optional[str] = Field(
+    #     default=None,
+    #     description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
+    #     "regions should be set to 1.",
+    # )
+    denoising_area: DenoisingArea = Field(description="The area to apply this conditioning to. Coordinates are in latent space.")
     mask_strength: float = Field(
         default=1.0,
         description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "

View File

@@ -5,6 +5,8 @@ from typing import Any, List, Optional, Union
 import torch
+from invokeai.app.invocations.primitives import DenoisingArea
 from .cross_attention_control import Arguments
@@ -39,15 +41,15 @@ class SDXLConditioningInfo(BasicConditioningInfo):
         return super().to(device=device, dtype=dtype)
-class TextConditioningInfoWithMask:
+class TextConditioningInfoWithArea:
     def __init__(
         self,
         text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
-        mask: Optional[torch.Tensor],
+        area: DenoisingArea,
         mask_strength: float,
     ):
         self.text_conditioning_info = text_conditioning_info
-        self.mask = mask
+        self.area = area
         self.mask_strength = mask_strength
@@ -74,7 +76,7 @@ class IPAdapterConditioningInfo:
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
-    text_embeddings: list[TextConditioningInfoWithMask]
+    text_embeddings: list[TextConditioningInfoWithArea]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
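
Note: the net effect of the renames in this file is that ConditioningData still carries one shared unconditioned embedding, while each positive embedding now arrives tagged with a rectangle and a weight, so guidance can be computed per region and blended afterwards. A toy two-region instantiation with stand-in types (placeholder strings instead of real embedding tensors):

from dataclasses import dataclass
from typing import Any

@dataclass
class TextConditioningInfoWithArea:  # stand-in
    text_conditioning_info: Any
    area: dict
    mask_strength: float

@dataclass
class ConditioningData:  # stand-in; the real dataclass has more fields
    unconditioned_embeddings: Any
    text_embeddings: list

data = ConditioningData(
    unconditioned_embeddings="<uncond embedding>",
    text_embeddings=[
        TextConditioningInfoWithArea(
            "<'a cat' embedding>", {"top_y": 0, "left_x": 0, "height": 512, "width": 256}, 1.0
        ),
        TextConditioningInfoWithArea(
            "<'a dog' embedding>", {"top_y": 0, "left_x": 256, "height": 512, "width": 256}, 1.0
        ),
    ],
)
print(len(data.text_embeddings), "regional prompts")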

View File

@@ -8,6 +8,7 @@ import torch
 import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
+from invokeai.app.invocations.primitives import DenoisingArea
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
@@ -220,9 +221,21 @@ class InvokeAIDiffuserComponent:
         )
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
-        cond_next_xs = []
-        uncond_next_x = None
-        for text_conditioning in conditioning_data.text_embeddings:
+        latent_shape = sample.shape
+        cond_next_xs = torch.zeros(
+            (latent_shape[0], len(conditioning_data.text_embeddings)) + latent_shape[1:],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
+        uncond_next_xs = torch.zeros(
+            (latent_shape[0], len(conditioning_data.text_embeddings)) + latent_shape[1:],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
+        # Initialize counts to 1e-10 to avoid divide-by-zero.
+        cond_count = torch.ones_like(sample) * 1e-10
+        for i, text_conditioning in enumerate(conditioning_data.text_embeddings):
             if wants_cross_attention_control or self.sequential_guidance:
                 raise NotImplementedError(
                     "Sequential conditioning has not yet been updated to work with multiple text embeddings."
@@ -242,11 +255,26 @@
                 #     down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 # )
             else:
+                area = text_conditioning.area
+                # TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
+                latent_area = DenoisingArea(
+                    height=max(1, area.height // 8),
+                    width=max(1, area.width // 8),
+                    top_y=area.top_y // 8,
+                    left_x=area.left_x // 8,
+                )
+                area_sample = sample[
+                    :,
+                    :,
+                    latent_area.top_y : latent_area.top_y + latent_area.height,
+                    latent_area.left_x : latent_area.left_x + latent_area.width,
+                ]
                 (
                     unconditioned_next_x,
                     conditioned_next_x,
                 ) = self._apply_standard_conditioning(
-                    x=sample,
+                    x=area_sample,
                     sigma=timestep,
                     cond_text_embedding=text_conditioning.text_conditioning_info,
                     uncond_text_embedding=conditioning_data.unconditioned_embeddings,
@@ -255,52 +283,32 @@
                     mid_block_additional_residual=mid_block_additional_residual,
                     down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )
-            cond_next_xs.append(conditioned_next_x)
-            # HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
-            # once.
-            uncond_next_x = unconditioned_next_x
-        # TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
-        # doesn't matter, as I'm sure there are many other places where we don't properly support batching.
-        cond_out = torch.concat(cond_next_xs, dim=0)
-        # Initialize count to 1e-9 to avoid division by zero.
-        cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
+            # TODO(ryand): Apply mask here.
+            cond_next_xs[
+                :,
+                i,
+                :,
+                latent_area.top_y : latent_area.top_y + latent_area.height,
+                latent_area.left_x : latent_area.left_x + latent_area.width,
+            ] = conditioned_next_x * text_conditioning.mask_strength
+            uncond_next_xs[
+                :,
+                i,
+                :,
+                latent_area.top_y : latent_area.top_y + latent_area.height,
+                latent_area.left_x : latent_area.left_x + latent_area.width,
+            ] = unconditioned_next_x * text_conditioning.mask_strength
+            cond_count[
+                :,
+                :,
+                latent_area.top_y : latent_area.top_y + latent_area.height,
+                latent_area.left_x : latent_area.left_x + latent_area.width,
+            ] += text_conditioning.mask_strength
-        _, _, height, width = cond_out.shape
-        for te_idx, te in enumerate(conditioning_data.text_embeddings):
-            mask = te.mask
-            if mask is not None:
-                # Resize if necessary.
-                tf = torchvision.transforms.Resize(
-                    (height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-                )
-                mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
-                mask = tf(mask)
-                # TODO(ryand): Do other apps apply the same mask weight and count to the unconditioned output?
-                # TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
-                # begin with?
-                mask = mask.to(cond_out.device, cond_out.dtype)
-                # Make sure that all mask values are either 0.0 or 1.0.
-                # HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
-                # the mask in the passed data structures.
-                mask[mask < 0.5] = 0.0
-                mask[mask >= 0.5] = 1.0
-                mask *= te.mask_strength
-            else:
-                # mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
-                # values).
-                mask = 1.0
-            # Apply the mask and update the count.
-            cond_out[te_idx, ...] *= mask[0]
-            cond_count += mask[0]
-        # Combine the masked conditionings.
-        cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
+        cond_next_x = cond_next_xs.sum(dim=1) / cond_count
+        uncond_next_x = uncond_next_xs.sum(dim=1) / cond_count
-        return uncond_next_x, cond_out
+        return uncond_next_x, cond_next_x

     def do_latent_postprocessing(
         self,
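
Note: end to end, the new scheme crops the latents to each prompt's rectangle, runs standard conditioning on the crop, writes the strength-weighted noise prediction back into a per-prompt buffer, and finally divides the per-pixel sum by the accumulated strengths, i.e. a weighted average wherever rectangles overlap. A self-contained toy with a fake denoiser (everything below is illustrative, not InvokeAI API) that makes the averaging easy to verify:

import torch

LATENT_SCALE_FACTOR = 8  # the hard-coded 8 in the hunk above

def fake_denoiser(x, prompt_idx):
    # Stand-in for _apply_standard_conditioning: returns a constant
    # "noise prediction" per prompt so the blend is easy to inspect.
    return torch.full_like(x, float(prompt_idx + 1))

def blend_areas(sample, areas_and_strengths):
    b, c, h, w = sample.shape
    n = len(areas_and_strengths)
    cond_next_xs = torch.zeros((b, n, c, h, w), dtype=sample.dtype)
    cond_count = torch.ones_like(sample) * 1e-10
    for i, (area, strength) in enumerate(areas_and_strengths):
        # Convert image-space pixels to latent coordinates, as in the diff.
        top = area["top_y"] // LATENT_SCALE_FACTOR
        left = area["left_x"] // LATENT_SCALE_FACTOR
        height = max(1, area["height"] // LATENT_SCALE_FACTOR)
        width = max(1, area["width"] // LATENT_SCALE_FACTOR)
        crop = sample[:, :, top : top + height, left : left + width]
        pred = fake_denoiser(crop, i)
        cond_next_xs[:, i, :, top : top + height, left : left + width] = pred * strength
        cond_count[:, :, top : top + height, left : left + width] += strength
    # Weighted average over all prompts that touched each pixel.
    return cond_next_xs.sum(dim=1) / cond_count

sample = torch.randn(1, 4, 64, 64)  # latents for a 512x512 image
areas = [
    ({"top_y": 0, "left_x": 0, "height": 512, "width": 320}, 1.0),    # prompt 0: left strip
    ({"top_y": 0, "left_x": 192, "height": 512, "width": 320}, 1.0),  # prompt 1: right strip
]
out = blend_areas(sample, areas)
# Left-only pixels keep prompt 0's prediction (1.0), right-only pixels keep
# prompt 1's (2.0), and the 16-latent-pixel overlap averages to 1.5:
print(out[0, 0, 0, 0].item(), out[0, 0, 0, 32].item(), out[0, 0, 0, 63].item())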