mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix up some details of the DenseDiffusion implementation for testing.
This commit is contained in:
parent
b8cbff828b
commit
969982b789
@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
attn_mask_weight = 0.8
|
||||
attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
|
||||
else: # self-attention
|
||||
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
|
||||
query_seq_len=query_seq_len,
|
||||
percent_through=percent_through,
|
||||
)
|
||||
attn_mask_weight = 0.5
|
||||
attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
@ -142,7 +142,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if regional_prompt_data is not None and percent_through < 0.5:
|
||||
if regional_prompt_data is not None and percent_through < 0.3:
|
||||
# Don't apply to uncond????
|
||||
|
||||
prompt_region_attention_mask = attn.prepare_attention_mask(
|
||||
prompt_region_attention_mask, sequence_length, batch_size
|
||||
)
|
||||
@ -154,8 +156,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
m_pos = attn_weight.max() - attn_weight
|
||||
m_neg = attn_weight - attn_weight.min()
|
||||
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
|
||||
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
|
||||
|
||||
prompt_region_attention_mask = attn_mask_weight * (
|
||||
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
|
||||
|
@ -97,11 +97,11 @@ class RegionalPromptData:
|
||||
|
||||
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
|
||||
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
|
||||
size = size.to(dtype=batch_sample_query_scores.dtype)
|
||||
batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||
batch_sample_query_scores[
|
||||
batch_sample_query_mask
|
||||
] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
|
||||
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
|
||||
batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
||||
batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
||||
|
||||
return attn_mask
|
||||
@ -133,20 +133,21 @@ class RegionalPromptData:
|
||||
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||
|
||||
for prompt_idx in range(num_prompts):
|
||||
if percent_through > batch_sample_regions.self_attn_adjustment_end_step_percents[prompt_idx]:
|
||||
continue
|
||||
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
||||
size = prompt_query_mask.sum() / prompt_query_mask.numel()
|
||||
size = size.to(dtype=prompt_query_mask.dtype)
|
||||
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
||||
# query_seq_len) mask.
|
||||
# TODO(ryand): Is += really the best option here?
|
||||
attn_mask[batch_idx, :, :] += (
|
||||
prompt_query_mask.unsqueeze(0)
|
||||
* prompt_query_mask.unsqueeze(1)
|
||||
* batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
|
||||
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size)
|
||||
)
|
||||
|
||||
attn_mask[attn_mask > 0.5] = 1.0
|
||||
attn_mask[attn_mask <= 0.5] = 0.0
|
||||
# if attn_mask[batch_idx].max() < 0.01:
|
||||
# attn_mask[batch_idx, ...] = 1.0
|
||||
|
||||
# attn_mask[attn_mask > 0.5] = 1.0
|
||||
# attn_mask[attn_mask <= 0.5] = 0.0
|
||||
# attn_mask_min = attn_mask[batch_idx].min()
|
||||
|
||||
# # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
|
||||
|
Loading…
Reference in New Issue
Block a user