Get naive latent space regional prompting working. Next, need to support areas in addition to masks.

This commit is contained in:
Ryan Dick 2024-02-20 19:02:19 -05:00
parent cb6c5c23ce
commit 32f602ab2a
2 changed files with 96 additions and 42 deletions

View File

@ -73,7 +73,7 @@ class IPAdapterConditioningInfo:
@dataclass @dataclass
class ConditioningData: class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
text_embeddings: list[TextConditioningInfoWithMask] text_embeddings: list[TextConditioningInfoWithMask]
""" """
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).

View File

@ -5,13 +5,16 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torchvision
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningData, ConditioningData,
ExtraConditioningInfo, ExtraConditioningInfo,
IPAdapterConditioningInfo,
PostprocessingSettings, PostprocessingSettings,
SDXLConditioningInfo, SDXLConditioningInfo,
) )
@ -217,35 +220,87 @@ class InvokeAIDiffuserComponent:
) )
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control or self.sequential_guidance: cond_next_xs = []
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention uncond_next_x = None
# control is currently only supported in sequential mode. for text_conditioning in conditioning_data.text_embeddings:
( if wants_cross_attention_control or self.sequential_guidance:
unconditioned_next_x, raise NotImplementedError(
conditioned_next_x, "Sequential conditioning has not yet been updated to work with multiple text embeddings."
) = self._apply_standard_conditioning_sequentially( )
x=sample, # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
sigma=timestep, # control is currently only supported in sequential mode.
conditioning_data=conditioning_data, # (
cross_attention_control_types_to_do=cross_attention_control_types_to_do, # unconditioned_next_x,
down_block_additional_residuals=down_block_additional_residuals, # conditioned_next_x,
mid_block_additional_residual=mid_block_additional_residual, # ) = self._apply_standard_conditioning_sequentially(
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # x=sample,
) # sigma=timestep,
else: # conditioning_data=conditioning_data,
( # cross_attention_control_types_to_do=cross_attention_control_types_to_do,
unconditioned_next_x, # down_block_additional_residuals=down_block_additional_residuals,
conditioned_next_x, # mid_block_additional_residual=mid_block_additional_residual,
) = self._apply_standard_conditioning( # down_intrablock_additional_residuals=down_intrablock_additional_residuals,
x=sample, # )
sigma=timestep, else:
conditioning_data=conditioning_data, (
down_block_additional_residuals=down_block_additional_residuals, unconditioned_next_x,
mid_block_additional_residual=mid_block_additional_residual, conditioned_next_x,
down_intrablock_additional_residuals=down_intrablock_additional_residuals, ) = self._apply_standard_conditioning(
) x=sample,
sigma=timestep,
cond_text_embedding=text_conditioning.text_conditioning_info,
uncond_text_embedding=conditioning_data.unconditioned_embeddings,
ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
)
cond_next_xs.append(conditioned_next_x)
# HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
# once.
uncond_next_x = unconditioned_next_x
return unconditioned_next_x, conditioned_next_x # TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
# doesn't matter, as I'm sure there are many other places where we don't properly support batching.
cond_out = torch.concat(cond_next_xs, dim=0)
# Initialize count to 1e-9 to avoid division by zero.
cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
_, _, height, width = cond_out.shape
for te_idx, te in enumerate(conditioning_data.text_embeddings):
mask = te.mask
if mask is not None:
# Resize if necessary.
tf = torchvision.transforms.Resize(
(height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
mask = tf(mask)
# TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
# begin with?
mask = mask.to(cond_out.device, cond_out.dtype)
# Make sure that all mask values are either 0.0 or 1.0.
# HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
# the mask in the passed data structures.
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask *= te.mask_strength
else:
# mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
# values).
mask = 1.0
# Apply the mask and update the count.
cond_out[te_idx, ...] *= mask[0]
cond_count += mask[0]
# Combine the masked conditionings.
cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
return uncond_next_x, cond_out
def do_latent_postprocessing( def do_latent_postprocessing(
self, self,
@ -313,7 +368,9 @@ class InvokeAIDiffuserComponent:
self, self,
x, x,
sigma, sigma,
conditioning_data: ConditioningData, cond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -324,43 +381,40 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
assert len(conditioning_data.text_embeddings) == 1
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
cross_attention_kwargs = None cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None: if ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = { cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [ "ip_adapter_image_prompt_embeds": [
torch.stack( torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds] [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
) )
for ipa_conditioning in conditioning_data.ip_adapter_conditioning for ipa_conditioning in ip_adapter_conditioning
] ]
} }
added_cond_kwargs = None added_cond_kwargs = None
if type(text_embeddings) is SDXLConditioningInfo: if type(cond_text_embedding) is SDXLConditioningInfo:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat( "text_embeds": torch.cat(
[ [
# TODO: how to pad? just by zeros? or even truncate? # TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds, uncond_text_embedding.pooled_embeds,
text_embeddings.pooled_embeds, cond_text_embedding.pooled_embeds,
], ],
dim=0, dim=0,
), ),
"time_ids": torch.cat( "time_ids": torch.cat(
[ [
conditioning_data.unconditioned_embeddings.add_time_ids, uncond_text_embedding.add_time_ids,
text_embeddings.add_time_ids, cond_text_embedding.add_time_ids,
], ],
dim=0, dim=0,
), ),
} }
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds uncond_text_embedding.embeds, cond_text_embedding.embeds
) )
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, x_twice,
@ -385,7 +439,7 @@ class InvokeAIDiffuserComponent:
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ) -> tuple[torch.Tensor, torch.Tensor]:
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed. slower execution speed.
""" """