diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
index 963aa3aa99..861ac128ea 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
@@ -32,6 +32,58 @@ class RegionalPromptData:
 class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
     """An attention processor that supports regional prompt attention for PyTorch 2.0."""
 
+    def _prepare_regional_prompt_attention_mask(
+        self,
+        regional_prompt_data: RegionalPromptData,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        orig_attn_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        # Infer the current spatial dimensions from the shape of `hidden_states`.
+        _, query_seq_len, _ = hidden_states.shape
+        per_prompt_query_masks = regional_prompt_data.masks
+        _, _, h, w = per_prompt_query_masks.shape
+
+        # Downsample by factors of 2 until the spatial dimensions match the current query sequence length.
+        scale_factor = 1
+        while h * w > query_seq_len:
+            scale_factor *= 2
+            h //= 2
+            w //= 2
+        assert h * w == query_seq_len
+
+        # Apply max-pooling to resize the masks to the target spatial dimensions.
+        # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
+        # here.
+        per_prompt_query_masks = F.max_pool2d(per_prompt_query_masks, kernel_size=scale_factor, stride=scale_factor)
+        batch_size, num_prompts, resized_h, resized_w = per_prompt_query_masks.shape
+        assert resized_h == h and resized_w == w
+
+        # Flatten the spatial dimensions of the masks.
+        # Shape after reshape: (batch_size, num_prompts, query_seq_len, 1)
+        per_prompt_query_masks = per_prompt_query_masks.reshape((batch_size, num_prompts, -1, 1))
+
+        # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
+        # `encoder_hidden_states`.
+
+        # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
+        _, key_seq_len, _ = encoder_hidden_states.shape
+        # HACK(ryand): We are assuming the batch size.
+        attn_mask = torch.zeros((2, query_seq_len, key_seq_len), device=hidden_states.device)
+
+        for i, embedding_range in enumerate(regional_prompt_data.embedding_ranges):
+            # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. This is too fragile
+            # to merge.
+            attn_mask[1, :, embedding_range.start : embedding_range.end] = per_prompt_query_masks[:, i, :, :]
+
+        # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. We are also assuming
+        # the intent of attn_mask. And we shouldn't have to do this awkward mask type conversion.
+        orig_mask = torch.zeros_like(orig_attn_mask[0, ...])
+        orig_mask[orig_attn_mask[0, ...] > -0.5] = 1.0
+        attn_mask[0, ...] = orig_mask
+
+        return attn_mask > 0.5
+
     def __call__(
         self,
         attn: Attention,
@@ -56,6 +108,13 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
 
+        if encoder_hidden_states is not None:
+            assert regional_prompt_data is not None
+            assert attention_mask is not None
+            attention_mask = self._prepare_regional_prompt_attention_mask(
+                regional_prompt_data, hidden_states, encoder_hidden_states, attention_mask
+            )
+
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
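For illustration only, not part of the patch: a minimal, self-contained sketch of the masking scheme that the new `_prepare_regional_prompt_attention_mask` method implements, i.e. max-pool each per-prompt mask down to the query resolution, then let each query token attend only to the text-embedding slice of the prompt whose region covers it. The helper name, shapes, and `(start, end)` ranges below are assumptions for the sketch, not InvokeAI API.

# Illustrative sketch only -- assumed inputs:
#   per_prompt_query_masks: (1, num_prompts, h, w), values in {0, 1}
#   embedding_ranges: one (start, end) token range per prompt
import torch
import torch.nn.functional as F


def build_regional_attn_mask(
    per_prompt_query_masks: torch.Tensor,
    embedding_ranges: list[tuple[int, int]],
    query_seq_len: int,
    key_seq_len: int,
) -> torch.Tensor:
    _, num_prompts, h, w = per_prompt_query_masks.shape

    # Halve the spatial dims until h * w matches the query sequence length.
    scale_factor = 1
    while h * w > query_seq_len:
        scale_factor *= 2
        h //= 2
        w //= 2
    assert h * w == query_seq_len

    # Max-pool so a region stays active for a query token if any pixel it covers was active.
    pooled = F.max_pool2d(per_prompt_query_masks, kernel_size=scale_factor, stride=scale_factor)
    pooled = pooled.reshape(1, num_prompts, query_seq_len)

    # Each prompt's mask gates only its own slice of the text-embedding (key) axis.
    attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
    for i, (start, end) in enumerate(embedding_ranges):
        attn_mask[0, :, start:end] = pooled[0, i, :, None]
    return attn_mask > 0.5


# Example: two prompts on a 64x64 latent grid, attention layer with 32x32 = 1024 query tokens.
masks = torch.zeros((1, 2, 64, 64))
masks[0, 0, :, :32] = 1.0  # prompt 0 owns the left half
masks[0, 1, :, 32:] = 1.0  # prompt 1 owns the right half
mask = build_regional_attn_mask(masks, [(0, 77), (77, 154)], query_seq_len=1024, key_seq_len=154)
print(mask.shape)  # torch.Size([1, 1024, 154])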