From b8cbff828b6fa08944255e0d3b3f208862f90a50 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 6 Mar 2024 10:52:35 -0500
Subject: [PATCH] wip

---
 .../diffusion/custom_attention.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 8dc402928f..017ee700d1 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -59,7 +59,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         scale: float = 1.0,
         # For regional prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[float] = None,
         # For IP-Adapter:
         ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.FloatTensor:
@@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                     dtype=hidden_states.dtype, device=hidden_states.device
                 )
-                attn_mask_weight = 1.0
+                attn_mask_weight = 0.8
             else:  # self-attention
                 prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                     query_seq_len=query_seq_len,
                     percent_through=percent_through,
                 )
-                attn_mask_weight = 0.3
+                attn_mask_weight = 0.5
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -142,7 +142,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        if regional_prompt_data is not None:
+        if regional_prompt_data is not None and percent_through < 0.5:
             prompt_region_attention_mask = attn.prepare_attention_mask(
                 prompt_region_attention_mask, sequence_length, batch_size
             )
@@ -161,10 +161,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
             )
 
-            if attention_mask is None:
-                attention_mask = prompt_region_attention_mask
+                if attention_mask is None:
+                    attention_mask = prompt_region_attention_mask
+                else:
+                    attention_mask = prompt_region_attention_mask + attention_mask
             else:
-                attention_mask = prompt_region_attention_mask + attention_mask
+                pass
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
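
For reference, below is a minimal standalone sketch (plain Python, not part of the patch and not InvokeAI's API) of the masking behaviour the hunks above move toward: the regional prompt mask is blended into the attention mask only while percent_through < 0.5, using a weight of 0.8 for cross-attention and 0.5 for self-attention. The helper name, the m_pos/m_neg values, and the example tensor shapes are illustrative assumptions.

from typing import Optional

import torch


def blend_regional_attention_mask(
    attention_mask: Optional[torch.Tensor],
    prompt_region_attention_mask: torch.Tensor,
    percent_through: float,
    is_cross_attention: bool,
) -> Optional[torch.Tensor]:
    """Blend a binary region mask into an additive attention-mask bias."""
    if percent_through >= 0.5:
        # Late in denoising: regional masking is skipped; return the base mask unchanged.
        return attention_mask

    # Weights from the patch: 0.8 for cross-attention, 0.5 for self-attention.
    attn_mask_weight = 0.8 if is_cross_attention else 0.5

    # Bias attention toward in-region tokens and away from out-of-region tokens.
    m_pos, m_neg = 1.0, 1.0  # illustrative magnitudes; not taken from the patch
    region_bias = attn_mask_weight * (
        m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
    )

    if attention_mask is None:
        return region_bias
    return region_bias + attention_mask


if __name__ == "__main__":
    # 1 batch*heads slice, 4 query tokens, 6 key tokens.
    region_mask = (torch.rand(1, 4, 6) > 0.5).float()
    early = blend_regional_attention_mask(None, region_mask, percent_through=0.2, is_cross_attention=True)
    late = blend_regional_attention_mask(None, region_mask, percent_through=0.7, is_cross_attention=True)
    print(early.shape)  # torch.Size([1, 4, 6])
    print(late)         # None -- regional masking disabled past the threshold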