Rough hacky implementation of DenseDiffusion.

This commit is contained in:
Ryan Dick 2024-03-05 18:10:01 -05:00
parent 57266d36a2
commit d3a40c5b2b
2 changed files with 39 additions and 17 deletions

View File

@ -1,3 +1,4 @@
import math
from typing import Optional from typing import Optional
import torch import torch
@ -105,24 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
dtype=hidden_states.dtype, device=hidden_states.device dtype=hidden_states.dtype, device=hidden_states.device
) )
attn_mask_weight = 1.0
else: # self-attention else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask( prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
query_seq_len=query_seq_len, query_seq_len=query_seq_len,
percent_through=percent_through, percent_through=percent_through,
) )
attn_mask_weight = 0.3
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None: if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@ -146,6 +136,36 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if regional_prompt_data is not None:
prompt_region_attention_mask = attn.prepare_attention_mask(
prompt_region_attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
prompt_region_attention_mask = prompt_region_attention_mask.view(
batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
)
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
m_pos = attn_weight.max() - attn_weight
m_neg = attn_weight - attn_weight.min()
prompt_region_attention_mask = attn_mask_weight * (
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
# the output of sdp = (batch, num_heads, seq_len, head_dim) # the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1 # TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention( hidden_states = F.scaled_dot_product_attention(

View File

@ -145,9 +145,11 @@ class RegionalPromptData:
* batch_sample_regions.positive_self_attn_mask_scores[prompt_idx] * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
) )
attn_mask_min = attn_mask[batch_idx].min() attn_mask[attn_mask > 0.5] = 1.0
attn_mask[attn_mask <= 0.5] = 0.0
# attn_mask_min = attn_mask[batch_idx].min()
# Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not. # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
if abs(attn_mask_min) > 0.0001: # if abs(attn_mask_min) > 0.0001:
attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min # attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
return attn_mask return attn_mask