mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Rough hacky implementation of DenseDiffusion.
This commit is contained in:
parent 57266d36a2
commit d3a40c5b2b
@@ -1,3 +1,4 @@
+import math
 from typing import Optional
 
 import torch
@@ -105,24 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                     dtype=hidden_states.dtype, device=hidden_states.device
                 )
 
+                attn_mask_weight = 1.0
             else:  # self-attention
                 prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                     query_seq_len=query_seq_len,
                     percent_through=percent_through,
                 )
+                attn_mask_weight = 0.3
-            if attention_mask is None:
-                attention_mask = prompt_region_attention_mask
-            else:
-                attention_mask = prompt_region_attention_mask + attention_mask
-
-        # Start unmodified block from AttnProcessor2_0.
-        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -146,6 +136,36 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if regional_prompt_data is not None:
+            prompt_region_attention_mask = attn.prepare_attention_mask(
+                prompt_region_attention_mask, sequence_length, batch_size
+            )
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            prompt_region_attention_mask = prompt_region_attention_mask.view(
+                batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
+            )
+
+            scale_factor = 1 / math.sqrt(query.size(-1))
+            attn_weight = query @ key.transpose(-2, -1) * scale_factor
+            m_pos = attn_weight.max() - attn_weight
+            m_neg = attn_weight - attn_weight.min()
+
+            prompt_region_attention_mask = attn_mask_weight * (
+                m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
+            )
+
+            if attention_mask is None:
+                attention_mask = prompt_region_attention_mask
+            else:
+                attention_mask = prompt_region_attention_mask + attention_mask
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
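Note: the block added to CustomAttnProcessor2_0 above is a DenseDiffusion-style modulation of the raw attention scores. Instead of masking regions out with -inf, it adds a bounded bias that pushes query/key pairs belonging to the same prompt region toward the global score maximum and pairs outside the region toward the global minimum, scaled by attn_mask_weight (1.0 for cross-attention, 0.3 for self-attention). The sketch below restates that arithmetic as a standalone function for clarity; the function name and the assumed tensor shapes are illustrative and not part of the commit.

import math

import torch


def densediffusion_bias(
    query: torch.Tensor,        # assumed shape: (batch, heads, query_seq_len, head_dim)
    key: torch.Tensor,          # assumed shape: (batch, heads, key_seq_len, head_dim)
    region_mask: torch.Tensor,  # assumed shape: (batch, heads, query_seq_len, key_seq_len); 1.0 = same region, 0.0 = different
    weight: float,
) -> torch.Tensor:
    """Illustrative restatement of the committed arithmetic: returns an additive attention bias."""
    # Raw scaled dot-product attention scores, as in the diff.
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor

    # Headroom to the global max/min of the raw scores: a positive push of m_pos
    # can at most reach the current maximum, a negative push of m_neg can at
    # most reach the current minimum, so the bias stays bounded.
    m_pos = attn_weight.max() - attn_weight
    m_neg = attn_weight - attn_weight.min()

    # Boost same-region pairs, suppress cross-region pairs.
    return weight * (m_pos * region_mask - m_neg * (1.0 - region_mask))

In the diff, this bias is summed with any existing attention_mask and passed to F.scaled_dot_product_attention, so the query/key scores are computed twice per call: once to derive the bias and once again inside SDPA.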
@@ -145,9 +145,11 @@ class RegionalPromptData:
                     * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
                 )
 
-            attn_mask_min = attn_mask[batch_idx].min()
+            attn_mask[attn_mask > 0.5] = 1.0
+            attn_mask[attn_mask <= 0.5] = 0.0
+            # attn_mask_min = attn_mask[batch_idx].min()
 
-            # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
-            if abs(attn_mask_min) > 0.0001:
-                attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
+            # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
+            # if abs(attn_mask_min) > 0.0001:
+            #     attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
         return attn_mask
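Note: the RegionalPromptData change replaces the previous min-shift normalization of the self-attention mask with a hard binarization at 0.5, so the mask fed into the score modulation above is strictly 0/1. A minimal illustration with a made-up tensor:

import torch

# Hypothetical example values; only the thresholding lines come from the diff.
attn_mask = torch.tensor([[0.2, 0.7],
                          [0.9, 0.4]])
attn_mask[attn_mask > 0.5] = 1.0
attn_mask[attn_mask <= 0.5] = 0.0
# -> tensor([[0., 1.],
#            [1., 0.]])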