More experimentation with various regional prompting tuning params.

This commit is contained in:
Ryan Dick 2024-03-02 17:43:21 -05:00
parent 942efa011e
commit ad18429fe3
4 changed files with 89 additions and 22 deletions


@ -97,6 +97,23 @@ class TextConditioningData:
self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# All params:
# negative_cross_attn_mask_score: -10000 (recommended to leave this as -10000 to prevent leakage to the rest of the image)
# positive_cross_attn_mask_score: 0.0 (relative weighting of masks)
# positive_self_attn_mask_score: 0.3
# negative_self_attn_mask_score: This doesn't really make sense; it would have effectively the same effect as further increasing positive_self_attn_mask_score.
# cross_attn_start_step
# self_attn_mask_begin_step_percent: 0.0
# self_attn_mask_end_step_percent: 0.5
# Should we allow cross_attn_mask_begin_step_percent and cross_attn_mask_end_step_percent? Probably not; this seems like more control than necessary, and it would be easy to add in the future.
self.negative_cross_attn_mask_score = -10000
self.positive_cross_attn_mask_score = 0.0
self.positive_self_attn_mask_score = 0.3
self.self_attn_mask_end_step_percent = 0.5
# mask_weight: float = Field(
# default=1.0,
# description="The weight to apply to the mask. This weight controls the relative weighting of overlapping masks. This weight gets added to the attention map logits before applying a pixelwise softmax.",
# )
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
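
As a hedged illustration of the two cross-attention scores documented in the comment block above (a minimal standalone sketch with assumed shapes, not the repository's code): adding negative_cross_attn_mask_score = -10000 to out-of-region attention logits drives their post-softmax weight to effectively zero, while positive_cross_attn_mask_score = 0.0 leaves in-region logits untouched.

import torch

logits = torch.randn(1, 4, 6)  # hypothetical (batch, query_tokens, key_tokens) attention logits
in_region = torch.tensor([True, True, False, False, True, True])  # which key tokens belong to the region

negative_cross_attn_mask_score = -10000.0  # added to out-of-region keys
positive_cross_attn_mask_score = 0.0  # added to in-region keys

mask_scores = torch.where(
    in_region,
    torch.tensor(positive_cross_attn_mask_score),
    torch.tensor(negative_cross_attn_mask_score),
)
weights = torch.softmax(logits + mask_scores, dim=-1)
# Out-of-region keys receive ~0 attention weight; in-region weights renormalize among themselves.
assert weights[..., ~in_region].max() < 1e-4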


@ -58,6 +58,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
scale: float = 1.0,
# For regional prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[float] = None,
# For IP-Adapter:
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.FloatTensor:
@ -93,6 +94,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Handle regional prompt attention masks.
if regional_prompt_data is not None:
assert percent_through is not None
_, query_seq_len, _ = hidden_states.shape
if is_cross_attention:
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
@ -102,16 +104,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=hidden_states.dtype, device=hidden_states.device
)
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len)
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
query_seq_len=query_seq_len,
percent_through=percent_through,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=hidden_states.dtype, device=hidden_states.device
)
# prompt_region_attention_mask = prompt_region_attention_mask.to(
# dtype=hidden_states.dtype, device=hidden_states.device
# )
# prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
# prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
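
The cross-attention branch above snaps a soft spatial mask in [0, 1] to the two additive scores with in-place assignments; the order matters, because values already overwritten with -10000.0 still satisfy the < 0.5 test on the next line. A hedged equivalent that avoids the ordering subtlety (a sketch, not the processor's code):

import torch

def binarize_region_mask(mask: torch.Tensor) -> torch.Tensor:
    """Map mask values < 0.5 to -10000.0 (blocked) and >= 0.5 to 0.0 (pass-through)."""
    return torch.where(mask >= 0.5, torch.zeros_like(mask), torch.full_like(mask, -10000.0))

torch.where evaluates both branches against the original values, so it is insensitive to assignment order.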


@ -41,6 +41,17 @@ class RegionalPromptData:
regions, max_downscale_factor
)
# TODO: These should be indexed by batch sample index and prompt index.
# Next:
# - Add support for setting these on nodes. Might just need positive_cross_attn_mask_score. Being able to downweight the global prompt might help a lot.
# - Scale by region size.
self.negative_cross_attn_mask_score = -10000
self.positive_cross_attn_mask_score = 0.0
self.positive_self_attn_mask_score = 2.0
self.self_attn_mask_end_step_percent = 0.3
# This one is for regional prompting in general, so should be set on the DenoiseLatents node.
self.self_attn_score_range = 3.0
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]:
@ -179,9 +190,14 @@ class RegionalPromptData:
0, prompt_idx, :, :
]
pos_mask = attn_mask >= 0.5
attn_mask[~pos_mask] = self.negative_cross_attn_mask_score
attn_mask[pos_mask] = self.positive_cross_attn_mask_score
return attn_mask
def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
def get_self_attn_mask(
self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
"""Get the self-attention mask for the given query sequence length.
Args:
@ -193,11 +209,15 @@ class RegionalPromptData:
dtype: float
The mask is a binary mask with values of 0.0 and 1.0.
"""
# TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
# unnecessary CPU operations happening in this class.
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
batch_spatial_masks = [
self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len))
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
@ -207,11 +227,32 @@ class RegionalPromptData:
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx in range(num_prompts):
if percent_through > self.self_attn_mask_end_step_percent:
continue
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask.
attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5
attn_mask[batch_idx, :, :] += (
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
)
# Since we were adding masks in the previous loop, earlier experiments clamped or normalized the summed values (kept below for reference):
# attn_mask[attn_mask > 0.5] = 1.0
# attn_mask_min = attn_mask[batch_idx].min()
# attn_mask_max = attn_mask[batch_idx].max()
# attn_mask_range = attn_mask_max - attn_mask_min
# if abs(attn_mask_range) < 0.0001:
# # All attn_mask value in this batch sample are the same, set the attn_mask to 0.0s (to avoid divide by
# # zero in the normalization).
# attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
# else:
# # Normalize from range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
# attn_mask[batch_idx] = (
# (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
# )
attn_mask_min = attn_mask[batch_idx].min()
# Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered.
if abs(attn_mask_min) > 0.0001:
attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
return attn_mask
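
A simplified, standalone sketch of the self-attention bias built by the loop above, with hypothetical shapes and without the batch, device, and step-percentage handling: each prompt's binary spatial mask is outer-multiplied with itself so that query/key pairs falling inside the same region receive a positive additive bias, and the result is shifted so its minimum is 0.0, matching the min-subtraction at the end of get_self_attn_mask.

import torch

def build_self_attn_bias(
    spatial_masks: torch.Tensor,  # (num_prompts, query_seq_len), binary values in {0.0, 1.0}
    positive_self_attn_mask_score: float = 0.3,
) -> torch.Tensor:
    num_prompts, query_seq_len = spatial_masks.shape
    attn_mask = torch.zeros(query_seq_len, query_seq_len)
    for prompt_idx in range(num_prompts):
        m = spatial_masks[prompt_idx]  # (query_seq_len,)
        # (1, L) * (L, 1) -> (L, L): nonzero where both query and key lie inside the region.
        attn_mask += m.unsqueeze(0) * m.unsqueeze(1) * positive_self_attn_mask_score
    # Shift so the minimum is 0.0, whether or not the regions cover every pixel.
    return attn_mask - attn_mask.min()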


@ -200,9 +200,9 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
)
@ -219,6 +219,7 @@ class InvokeAIDiffuserComponent:
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
percent_through=percent_through,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
@ -232,6 +233,7 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
percent_through=percent_through,
ip_adapter_conditioning=ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
@ -293,6 +295,7 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -326,8 +329,8 @@ class InvokeAIDiffuserComponent:
)
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
# masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
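
A hedged sketch of the hoisting this TODO describes (hypothetical names, not a planned implementation): construct the object once per denoise call, then reuse the cached instance on every step.

class OncePerDenoiseCache:
    """Build a value on first access, then return the cached instance thereafter."""

    def __init__(self):
        self._value = None

    def get(self, build):
        if self._value is None:
            self._value = build()
        return self._value

# Hypothetical per-step usage:
# regional_data = cache.get(lambda: RegionalPromptData(regions=regions))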
@ -345,8 +348,8 @@ class InvokeAIDiffuserComponent:
)
regions.append(r)
_, key_seq_len, _ = both_conditionings.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
cross_attention_kwargs["percent_through"] = percent_through
both_results = self.model_forward_callback(
x_twice,
@ -369,6 +372,7 @@ class InvokeAIDiffuserComponent:
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
cross_attention_control_types_to_do: list[CrossAttentionType],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -439,10 +443,10 @@ class InvokeAIDiffuserComponent:
# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
_, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions]
)
cross_attention_kwargs["percent_through"] = percent_through
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
@ -485,10 +489,10 @@ class InvokeAIDiffuserComponent:
# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
_, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions]
)
cross_attention_kwargs["percent_through"] = percent_through
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
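
Taken together, this commit threads percent_through = step_index / total_step_count from the step loop into the attention processors via cross_attention_kwargs, where the regional self-attention bias is applied only while percent_through is at or below self_attn_mask_end_step_percent. A minimal sketch of that gating (assumed step count; not the component's code):

def self_attn_bias_active(percent_through: float, end_step_percent: float = 0.5) -> bool:
    """Mirror the early-exit in get_self_attn_mask: bias only the first fraction of steps."""
    return percent_through <= end_step_percent

total_step_count = 30
for step_index in range(total_step_count):
    percent_through = step_index / total_step_count
    if self_attn_bias_active(percent_through):
        pass  # add the regional self-attention bias on this step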