Get naive latent space regional prompting working. Next, need to support areas in addition to masks.

This commit is contained in:
Ryan Dick 2024-02-20 19:02:19 -05:00
parent cb6c5c23ce
commit 32f602ab2a
2 changed files with 96 additions and 42 deletions

View File

@ -73,7 +73,7 @@ class IPAdapterConditioningInfo:
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
text_embeddings: list[TextConditioningInfoWithMask]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).

View File

@ -5,13 +5,16 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
import torchvision
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningData,
ExtraConditioningInfo,
IPAdapterConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)
@ -217,35 +220,87 @@ class InvokeAIDiffuserComponent:
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control or self.sequential_guidance:
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
# control is currently only supported in sequential mode.
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
)
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
)
cond_next_xs = []
uncond_next_x = None
for text_conditioning in conditioning_data.text_embeddings:
if wants_cross_attention_control or self.sequential_guidance:
raise NotImplementedError(
"Sequential conditioning has not yet been updated to work with multiple text embeddings."
)
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
# control is currently only supported in sequential mode.
# (
# unconditioned_next_x,
# conditioned_next_x,
# ) = self._apply_standard_conditioning_sequentially(
# x=sample,
# sigma=timestep,
# conditioning_data=conditioning_data,
# cross_attention_control_types_to_do=cross_attention_control_types_to_do,
# down_block_additional_residuals=down_block_additional_residuals,
# mid_block_additional_residual=mid_block_additional_residual,
# down_intrablock_additional_residuals=down_intrablock_additional_residuals,
# )
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x=sample,
sigma=timestep,
cond_text_embedding=text_conditioning.text_conditioning_info,
uncond_text_embedding=conditioning_data.unconditioned_embeddings,
ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
)
cond_next_xs.append(conditioned_next_x)
# HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
# once.
uncond_next_x = unconditioned_next_x
return unconditioned_next_x, conditioned_next_x
# TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
# doesn't matter, as I'm sure there are many other places where we don't properly support batching.
cond_out = torch.concat(cond_next_xs, dim=0)
# Initialize count to 1e-9 to avoid division by zero.
cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
_, _, height, width = cond_out.shape
for te_idx, te in enumerate(conditioning_data.text_embeddings):
mask = te.mask
if mask is not None:
# Resize if necessary.
tf = torchvision.transforms.Resize(
(height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
mask = tf(mask)
# TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
# begin with?
mask = mask.to(cond_out.device, cond_out.dtype)
# Make sure that all mask values are either 0.0 or 1.0.
# HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
# the mask in the passed data structures.
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask *= te.mask_strength
else:
# mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
# values).
mask = 1.0
# Apply the mask and update the count.
cond_out[te_idx, ...] *= mask[0]
cond_count += mask[0]
# Combine the masked conditionings.
cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
return uncond_next_x, cond_out
def do_latent_postprocessing(
self,
@ -313,7 +368,9 @@ class InvokeAIDiffuserComponent:
self,
x,
sigma,
conditioning_data: ConditioningData,
cond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -324,43 +381,40 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
assert len(conditioning_data.text_embeddings) == 1
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
for ipa_conditioning in ip_adapter_conditioning
]
}
added_cond_kwargs = None
if type(text_embeddings) is SDXLConditioningInfo:
if type(cond_text_embedding) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
text_embeddings.pooled_embeds,
uncond_text_embedding.pooled_embeds,
cond_text_embedding.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
text_embeddings.add_time_ids,
uncond_text_embedding.add_time_ids,
cond_text_embedding.add_time_ids,
],
dim=0,
),
}
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
uncond_text_embedding.embeds, cond_text_embedding.embeds
)
both_results = self.model_forward_callback(
x_twice,
@ -385,7 +439,7 @@ class InvokeAIDiffuserComponent:
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
) -> tuple[torch.Tensor, torch.Tensor]:
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""