From d3a40c5b2b3c2231ca3156bfdc277d101210c561 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 5 Mar 2024 18:10:01 -0500 Subject: [PATCH] Rough hacky implementation of DenseDiffusion. --- .../diffusion/custom_attention.py | 46 +++++++++++++------ .../diffusion/regional_prompt_data.py | 10 ++-- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index c0ac6d63e1..8dc402928f 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -1,3 +1,4 @@ +import math from typing import Optional import torch @@ -105,24 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): dtype=hidden_states.dtype, device=hidden_states.device ) + attn_mask_weight = 1.0 else: # self-attention prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask( query_seq_len=query_seq_len, percent_through=percent_through, ) - - if attention_mask is None: - attention_mask = prompt_region_attention_mask - else: - attention_mask = prompt_region_attention_mask + attention_mask - - # Start unmodified block from AttnProcessor2_0. - # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + attn_mask_weight = 0.3 if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) @@ -146,6 +136,36 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if regional_prompt_data is not None: + prompt_region_attention_mask = attn.prepare_attention_mask( + prompt_region_attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + prompt_region_attention_mask = prompt_region_attention_mask.view( + batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1] + ) + + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_weight = query @ key.transpose(-2, -1) * scale_factor + m_pos = attn_weight.max() - attn_weight + m_neg = attn_weight - attn_weight.min() + + prompt_region_attention_mask = attn_mask_weight * ( + m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask) + ) + + if attention_mask is None: + attention_mask = prompt_region_attention_mask + else: + attention_mask = prompt_region_attention_mask + attention_mask + # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index b7b7e61768..244b8a3276 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -145,9 +145,11 @@ class RegionalPromptData: * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx] ) - attn_mask_min = attn_mask[batch_idx].min() + attn_mask[attn_mask > 0.5] = 1.0 + attn_mask[attn_mask <= 0.5] = 0.0 + # attn_mask_min = attn_mask[batch_idx].min() - # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not. - if abs(attn_mask_min) > 0.0001: - attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min + # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not. + # if abs(attn_mask_min) > 0.0001: + # attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min return attn_mask