From 942efa011eac947c511f64703b8948e8379f0cef Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 1 Mar 2024 18:43:32 -0500 Subject: [PATCH] Implement (very slow) self-attention regional masking. --- .../diffusion/custom_attention.py | 29 ++- .../diffusion/regional_prompt_data.py | 222 ++++++++++++++---- .../diffusion/shared_invokeai_diffusion.py | 4 +- 3 files changed, 195 insertions(+), 60 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 58aba2f709..632d6beeb0 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -92,15 +92,28 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): # End unmodified block from AttnProcessor2_0. # Handle regional prompt attention masks. - if is_cross_attention and regional_prompt_data is not None: + if regional_prompt_data is not None: _, query_seq_len, _ = hidden_states.shape - prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len) - # TODO(ryand): Avoid redundant type/device conversion here. - prompt_region_attention_mask = prompt_region_attention_mask.to( - dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device - ) - prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0 - prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0 + if is_cross_attention: + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_seq_len, key_seq_len=sequence_length + ) + # TODO(ryand): Avoid redundant type/device conversion here. + prompt_region_attention_mask = prompt_region_attention_mask.to( + dtype=hidden_states.dtype, device=hidden_states.device + ) + prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0 + prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0 + + else: # self-attention + prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len) + + # TODO(ryand): Avoid redundant type/device conversion here. + prompt_region_attention_mask = prompt_region_attention_mask.to( + dtype=hidden_states.dtype, device=hidden_states.device + ) + # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5 + # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0 if attention_mask is None: attention_mask = prompt_region_attention_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 48b558d054..b5966c8b56 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -5,60 +5,67 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( TextConditioningRegions, ) +# Stages: +# - Convert image masks to spatial masks at all downsampling factors. +# - Decision: Max pooling? Nearest? Other? +# - Should definitely be shared across all denoising steps - that should be easy. +# - Convert spatial masks to cross-attention masks. +# - This should ideally be shared across all denoising steps, but preparing the masks requires knowing the max_key_seq_len. +# - Could create it just-in-time and them cache the result +# - Convert spatial masks to self-attention masks. +# - This should be shared across all denoising steps. +# - Shape depends only on spatial resolution and downsampling factors. +# - Convert cross-attention binary mask to score mask. +# - Convert self-attention binary mask to score mask. +# +# If we wanted a time schedule for level of attenuation, we would apply that in the attention layer. + + +# Pre-compute the spatial masks, because that's easy. +# Compute the other stuff as it's requested. Add caching if we find that it's slow. + class RegionalPromptData: - def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]): - self._attn_masks_by_seq_len = attn_masks_by_seq_len - - @classmethod - def from_regions( - cls, - regions: list[TextConditioningRegions], - key_seq_len: int, - # TODO(ryand): Pass in a list of downscale factors? - max_downscale_factor: int = 8, - ): - """Construct a `RegionalPromptData` object. + def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8): + """Initialize a `RegionalPromptData` object. Args: regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the batch. - key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the - cross-attention layers). This is most likely equal to the max embedding range end, but we pass it - explicitly to be sure. + max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks. """ - attn_masks_by_seq_len = {} + self._regions = regions + # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query + # sequence length of s. + self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks( + regions, max_downscale_factor + ) + + def _prepare_spatial_masks( + self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 + ) -> list[dict[int, torch.Tensor]]: + """Prepare the spatial masks for all downscaling factors.""" + # TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling + # levels, but I need to double check that. + + # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length + # of s. + batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] - # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence - # length of s. - batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] for batch_sample_regions in regions: - batch_attn_masks_by_seq_len.append({}) + batch_sample_masks_by_seq_len.append({}) # Convert the bool masks to float masks so that max pooling can be applied. - batch_masks = batch_sample_regions.masks.to(dtype=torch.float32) + batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32) # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. downscale_factor = 1 while downscale_factor <= max_downscale_factor: - _, num_prompts, h, w = batch_masks.shape + b, num_prompts, h, w = batch_sample_masks.shape + assert b == 1 query_seq_len = h * w - # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). - batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1)) - - # Create a cross-attention mask for each prompt that selects the corresponding embeddings from - # `encoder_hidden_states`. - # attn_mask shape: (batch_size, query_seq_len, key_seq_len) - # TODO(ryand): What device / dtype should this be? - attn_mask = torch.zeros((1, query_seq_len, key_seq_len)) - - for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): - attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[ - :, prompt_idx, :, : - ] - - batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask + batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks downscale_factor *= 2 if downscale_factor <= max_downscale_factor: @@ -66,23 +73,86 @@ class RegionalPromptData: # regions to be lost entirely. # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could # potentially use a weighted mask rather than a binary mask. - batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2) + batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2) + return batch_sample_masks_by_seq_len # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len. - for query_seq_len in batch_attn_masks_by_seq_len[0].keys(): - attn_masks_by_seq_len[query_seq_len] = torch.cat( - [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))] - ) + # for query_seq_len in batch_sample_masks_by_seq_len[0].keys(): + # masks_by_seq_len[query_seq_len] = torch.cat( + # [batch_sample_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_sample_masks_by_seq_len))] + # ) - return cls(attn_masks_by_seq_len) + # return masks_by_seq_len - def get_attn_mask(self, query_seq_len: int) -> torch.Tensor: - """Get the attention mask for the given query sequence length (i.e. downscaling level). + # @classmethod + # def from_regions( + # cls, + # regions: list[TextConditioningRegions], + # key_seq_len: int, + # max_downscale_factor: int = 8, + # ): + # """Construct a `RegionalPromptData` object. - This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so - it changes at each downscaling level in the model. + # Args: + # regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the + # batch. + # """ + # attn_masks_by_seq_len = {} - key_seq_len is the length of the expected prompt embeddings. + # # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence + # # length of s. + # batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] + # for batch_sample_regions in regions: + # batch_attn_masks_by_seq_len.append({}) + + # # Convert the bool masks to float masks so that max pooling can be applied. + # batch_masks = batch_sample_regions.masks.to(dtype=torch.float32) + + # # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. + # downscale_factor = 1 + # while downscale_factor <= max_downscale_factor: + # _, num_prompts, h, w = batch_masks.shape + # query_seq_len = h * w + + # # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). + # batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1)) + + # # Create a cross-attention mask for each prompt that selects the corresponding embeddings from + # # `encoder_hidden_states`. + # # attn_mask shape: (batch_size, query_seq_len, key_seq_len) + # # TODO(ryand): What device / dtype should this be? + # attn_mask = torch.zeros((1, query_seq_len, key_seq_len)) + + # for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): + # attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[ + # :, prompt_idx, :, : + # ] + + # batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask + + # downscale_factor *= 2 + # if downscale_factor <= max_downscale_factor: + # # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt + # # regions to be lost entirely. + # # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could + # # potentially use a weighted mask rather than a binary mask. + # batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2) + + # # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len. + # for query_seq_len in batch_attn_masks_by_seq_len[0].keys(): + # attn_masks_by_seq_len[query_seq_len] = torch.cat( + # [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))] + # ) + + # return cls(attn_masks_by_seq_len) + + def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor: + """Get the cross-attention mask for the given query sequence length. + + Args: + query_seq_len: The length of the flattened spatial features at the current downscaling level. + key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention + layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure. Returns: torch.Tensor: The masks. @@ -90,4 +160,58 @@ class RegionalPromptData: dtype: float The mask is a binary mask with values of 0.0 and 1.0. """ - return self._attn_masks_by_seq_len[query_seq_len] + batch_size = len(self._spatial_masks_by_seq_len) + batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)] + + # Create an empty attention mask with the correct shape. + attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len)) + + for batch_idx in range(batch_size): + batch_sample_spatial_masks = batch_spatial_masks[batch_idx] + batch_sample_regions = self._regions[batch_idx] + + # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). + _, num_prompts, _, _ = batch_sample_spatial_masks.shape + batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) + + for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): + attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_masks[ + 0, prompt_idx, :, : + ] + + return attn_mask + + def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor: + """Get the self-attention mask for the given query sequence length. + + Args: + query_seq_len: The length of the flattened spatial features at the current downscaling level. + + Returns: + torch.Tensor: The masks. + shape: (batch_size, query_seq_len, query_seq_len). + dtype: float + The mask is a binary mask with values of 0.0 and 1.0. + """ + batch_size = len(self._spatial_masks_by_seq_len) + batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)] + + # Create an empty attention mask with the correct shape. + attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len)) + + for batch_idx in range(batch_size): + batch_sample_spatial_masks = batch_spatial_masks[batch_idx] + + # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). + _, num_prompts, _, _ = batch_sample_spatial_masks.shape + batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) + + for prompt_idx in range(num_prompts): + prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) + # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, + # query_seq_len) mask. + attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5 + + # Since we were adding masks in the previous loop, we need to clamp the values to 1.0. + # attn_mask[attn_mask > 0.5] = 1.0 + return attn_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 674a091a2f..fd4806d024 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -346,9 +346,7 @@ class InvokeAIDiffuserComponent: regions.append(r) _, key_seq_len, _ = both_conditionings.shape - cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions( - regions=regions, key_seq_len=key_seq_len - ) + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions) both_results = self.model_forward_callback( x_twice,