mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make regional prompting work with sequential conditioning.
This commit is contained in:
parent
5f49e7ae26
commit
e132afb705
@ -126,19 +126,21 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
assert regional_prompt_data is not None
|
||||
assert attention_mask is not None
|
||||
if encoder_hidden_states is not None and regional_prompt_data is not None:
|
||||
# If encoder_hidden_states is not None, that means we are doing cross-attention case.
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
|
||||
# TODO(ryand): Avoid redundant type/device conversion here.
|
||||
prompt_region_attention_mask = prompt_region_attention_mask.to(
|
||||
dtype=attention_mask.dtype, device=attention_mask.device
|
||||
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
|
||||
)
|
||||
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
|
||||
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
|
||||
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
@ -448,6 +448,15 @@ class InvokeAIDiffuserComponent:
|
||||
"time_ids": conditioning_data.uncond_text.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the unconditioned pass.
|
||||
if conditioning_data.uncond_regions is not None:
|
||||
_, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
|
||||
cross_attention_kwargs = {
|
||||
"regional_prompt_data": RegionalPromptData.from_regions(
|
||||
regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
|
||||
)
|
||||
}
|
||||
|
||||
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||
unconditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
@ -489,6 +498,15 @@ class InvokeAIDiffuserComponent:
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the conditioned pass.
|
||||
if conditioning_data.cond_regions is not None:
|
||||
_, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
|
||||
cross_attention_kwargs = {
|
||||
"regional_prompt_data": RegionalPromptData.from_regions(
|
||||
regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
|
||||
)
|
||||
}
|
||||
|
||||
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
|
Loading…
Reference in New Issue
Block a user