Make regional prompting work with sequential conditioning.

Author: Ryan Dick, 2024-02-28 21:21:50 -05:00
Parent: 5f49e7ae26
Commit: e132afb705
2 changed files with 25 additions and 5 deletions


@@ -126,19 +126,21 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )

-        if encoder_hidden_states is not None:
-            assert regional_prompt_data is not None
-            assert attention_mask is not None
+        if encoder_hidden_states is not None and regional_prompt_data is not None:
             # If encoder_hidden_states is not None, that means we are doing cross-attention case.
             _, query_seq_len, _ = hidden_states.shape
             prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
             # TODO(ryand): Avoid redundant type/device conversion here.
             prompt_region_attention_mask = prompt_region_attention_mask.to(
-                dtype=attention_mask.dtype, device=attention_mask.device
+                dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
             )
             prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
             prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
-            attention_mask = prompt_region_attention_mask + attention_mask
+            if attention_mask is None:
+                attention_mask = prompt_region_attention_mask
+            else:
+                attention_mask = prompt_region_attention_mask + attention_mask

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
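
For reference, the masking logic this hunk converges on can be captured in a minimal standalone sketch. The helper name and signature below are illustrative only (they do not exist in the repository); it assumes a soft region mask shaped (batch, query_seq_len, key_seq_len) with values near 1.0 where a query position is allowed to attend to a prompt token:

from typing import Optional

import torch


def merge_regional_attention_mask(
    region_mask: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    # Illustrative helper, not repository code: turn a soft region mask into an
    # additive attention bias and merge it with an optional pre-existing mask.
    bias = region_mask.to(dtype=dtype, device=device)
    # Same 0.5 threshold as the hunk above: a large negative bias blocks attention,
    # zero leaves it untouched.
    bias = torch.where(bias >= 0.5, torch.zeros_like(bias), torch.full_like(bias, -10000.0))
    # The sequential-conditioning path may supply no base mask at all, so None is
    # handled here instead of being asserted away.
    return bias if attention_mask is None else bias + attention_mask

The None branch mirrors the change above: the asserts are dropped because, with sequential conditioning, cross-attention can arrive without a base attention mask.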


@@ -448,6 +448,15 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.uncond_text.add_time_ids,
             }

+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_regions(
+                    regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
+                )
+            }
+
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,

@@ -489,6 +498,15 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.cond_text.add_time_ids,
             }

+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_regions(
+                    regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
+                )
+            }
+
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
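
The two hunks in InvokeAIDiffuserComponent are symmetric, so the overall sequential-conditioning flow can be condensed into a sketch. This is not the repository's actual method: the loop structure and every callback argument other than x and the text embeddings are assumptions made for illustration, and RegionalPromptData is passed in as a parameter rather than imported to avoid guessing its module path:

def run_sequential_passes(model_forward_callback, x, sigma, conditioning_data, RegionalPromptData):
    # Hypothetical condensation of the unconditioned/conditioned passes shown in the hunks above.
    outputs = []
    for text, regions in (
        (conditioning_data.uncond_text, conditioning_data.uncond_regions),
        (conditioning_data.cond_text, conditioning_data.cond_regions),
    ):
        cross_attention_kwargs = None
        if regions is not None:
            # key_seq_len is taken from the text embedding shape so the regional mask
            # lines up with the prompt tokens during cross-attention.
            _, key_seq_len, _ = text.embeds.shape
            cross_attention_kwargs = {
                "regional_prompt_data": RegionalPromptData.from_regions(
                    regions=[regions], key_seq_len=key_seq_len
                )
            }
        outputs.append(
            model_forward_callback(
                x, sigma, text.embeds, cross_attention_kwargs=cross_attention_kwargs
            )
        )
    unconditioned_next_x, conditioned_next_x = outputs
    return unconditioned_next_x, conditioned_next_x

Because each pass builds its own RegionalPromptData from its own region list, the negative and positive prompts can carry different region sets (or none at all), which is what allows regional prompting to work with sequential conditioning.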