mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip
This commit is contained in:
parent
d3a40c5b2b
commit
b8cbff828b
@ -59,7 +59,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
scale: float = 1.0,
|
||||
# For regional prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[torch.FloatTensor] = None,
|
||||
percent_through: Optional[float] = None,
|
||||
# For IP-Adapter:
|
||||
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
attn_mask_weight = 1.0
|
||||
attn_mask_weight = 0.8
|
||||
else: # self-attention
|
||||
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
|
||||
query_seq_len=query_seq_len,
|
||||
percent_through=percent_through,
|
||||
)
|
||||
attn_mask_weight = 0.3
|
||||
attn_mask_weight = 0.5
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
@ -142,7 +142,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if regional_prompt_data is not None:
|
||||
if regional_prompt_data is not None and percent_through < 0.5:
|
||||
prompt_region_attention_mask = attn.prepare_attention_mask(
|
||||
prompt_region_attention_mask, sequence_length, batch_size
|
||||
)
|
||||
@ -161,10 +161,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
pass
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
|
Loading…
Reference in New Issue
Block a user