Fix up some details of DenseDiffusion for testing.

This commit is contained in:
Ryan Dick 2024-03-06 19:03:26 -05:00
parent b8cbff828b
commit 969982b789
2 changed files with 19 additions and 16 deletions

View File

@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
dtype=hidden_states.dtype, device=hidden_states.device dtype=hidden_states.dtype, device=hidden_states.device
) )
attn_mask_weight = 0.8 attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
else: # self-attention else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask( prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
query_seq_len=query_seq_len, query_seq_len=query_seq_len,
percent_through=percent_through, percent_through=percent_through,
) )
attn_mask_weight = 0.5 attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)
if attn.group_norm is not None: if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@ -142,7 +142,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# (batch, heads, source_length, target_length) # (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if regional_prompt_data is not None and percent_through < 0.5: if regional_prompt_data is not None and percent_through < 0.3:
# Don't apply to uncond????
prompt_region_attention_mask = attn.prepare_attention_mask( prompt_region_attention_mask = attn.prepare_attention_mask(
prompt_region_attention_mask, sequence_length, batch_size prompt_region_attention_mask, sequence_length, batch_size
) )
@ -154,8 +156,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
scale_factor = 1 / math.sqrt(query.size(-1)) scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight = query @ key.transpose(-2, -1) * scale_factor
m_pos = attn_weight.max() - attn_weight m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
m_neg = attn_weight - attn_weight.min() m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
prompt_region_attention_mask = attn_mask_weight * ( prompt_region_attention_mask = attn_mask_weight * (
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask) m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)

View File

@ -97,11 +97,11 @@ class RegionalPromptData:
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone() batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
size = size.to(dtype=batch_sample_query_scores.dtype)
batch_sample_query_mask = batch_sample_query_scores > 0.5 batch_sample_query_mask = batch_sample_query_scores > 0.5
batch_sample_query_scores[ batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
batch_sample_query_mask batch_sample_query_scores[~batch_sample_query_mask] = 0.0
] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
return attn_mask return attn_mask
@ -133,20 +133,21 @@ class RegionalPromptData:
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
if percent_through > batch_sample_regions.self_attn_adjustment_end_step_percents[prompt_idx]:
continue
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
size = prompt_query_mask.sum() / prompt_query_mask.numel()
size = size.to(dtype=prompt_query_mask.dtype)
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask. # query_seq_len) mask.
# TODO(ryand): Is += really the best option here? # TODO(ryand): Is += really the best option here?
attn_mask[batch_idx, :, :] += ( attn_mask[batch_idx, :, :] += (
prompt_query_mask.unsqueeze(0) prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size)
* prompt_query_mask.unsqueeze(1)
* batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
) )
attn_mask[attn_mask > 0.5] = 1.0 # if attn_mask[batch_idx].max() < 0.01:
attn_mask[attn_mask <= 0.5] = 0.0 # attn_mask[batch_idx, ...] = 1.0
# attn_mask[attn_mask > 0.5] = 1.0
# attn_mask[attn_mask <= 0.5] = 0.0
# attn_mask_min = attn_mask[batch_idx].min() # attn_mask_min = attn_mask[batch_idx].min()
# # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not. # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.