diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 017ee700d1..625fa41259 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): dtype=hidden_states.dtype, device=hidden_states.device ) - attn_mask_weight = 0.8 + attn_mask_weight = 1.0 * ((1 - percent_through) ** 5) else: # self-attention prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask( query_seq_len=query_seq_len, percent_through=percent_through, ) - attn_mask_weight = 0.5 + attn_mask_weight = 0.3 * ((1 - percent_through) ** 5) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) @@ -142,7 +142,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - if regional_prompt_data is not None and percent_through < 0.5: + if regional_prompt_data is not None and percent_through < 0.3: + # Don't apply to uncond???? + prompt_region_attention_mask = attn.prepare_attention_mask( prompt_region_attention_mask, sequence_length, batch_size ) @@ -154,8 +156,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): scale_factor = 1 / math.sqrt(query.size(-1)) attn_weight = query @ key.transpose(-2, -1) * scale_factor - m_pos = attn_weight.max() - attn_weight - m_neg = attn_weight - attn_weight.min() + m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight + m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0] prompt_region_attention_mask = attn_mask_weight * ( m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask) diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 244b8a3276..76b5e83b53 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -97,11 +97,11 @@ class RegionalPromptData: for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone() + size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel() + size = size.to(dtype=batch_sample_query_scores.dtype) batch_sample_query_mask = batch_sample_query_scores > 0.5 - batch_sample_query_scores[ - batch_sample_query_mask - ] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx] - batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score + batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size) + batch_sample_query_scores[~batch_sample_query_mask] = 0.0 attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores return attn_mask @@ -133,20 +133,21 @@ class RegionalPromptData: batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) for prompt_idx in range(num_prompts): - if percent_through > batch_sample_regions.self_attn_adjustment_end_step_percents[prompt_idx]: - continue prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) + size = prompt_query_mask.sum() / prompt_query_mask.numel() + size = size.to(dtype=prompt_query_mask.dtype) # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # query_seq_len) mask. # TODO(ryand): Is += really the best option here? attn_mask[batch_idx, :, :] += ( - prompt_query_mask.unsqueeze(0) - * prompt_query_mask.unsqueeze(1) - * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx] + prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size) ) - attn_mask[attn_mask > 0.5] = 1.0 - attn_mask[attn_mask <= 0.5] = 0.0 + # if attn_mask[batch_idx].max() < 0.01: + # attn_mask[batch_idx, ...] = 1.0 + + # attn_mask[attn_mask > 0.5] = 1.0 + # attn_mask[attn_mask <= 0.5] = 0.0 # attn_mask_min = attn_mask[batch_idx].min() # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.