From 41e1a9f2026856ed330d6916c57ec3e29981395d Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 5 Mar 2024 15:19:58 -0500
Subject: [PATCH] Use the correct device / dtype for RegionalPromptData
 calculations.

---
 .../diffusion/custom_attention.py          |  9 ----
 .../diffusion/regional_prompt_data.py      | 41 ++++++++++---------
 .../diffusion/shared_invokeai_diffusion.py |  8 ++--
 3 files changed, 26 insertions(+), 32 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 1b2c570ca0..c0ac6d63e1 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -109,17 +109,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                 query_seq_len=query_seq_len,
                 percent_through=percent_through,
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
             )
 
-            # TODO(ryand): Avoid redundant type/device conversion here.
-            # prompt_region_attention_mask = prompt_region_attention_mask.to(
-            #     dtype=hidden_states.dtype, device=hidden_states.device
-            # )
-            # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
-            # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
-
             if attention_mask is None:
                 attention_mask = prompt_region_attention_mask
             else:
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index d7a2809ce5..b7b7e61768 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -7,30 +7,37 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 
 
 class RegionalPromptData:
-    def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
+    def __init__(
+        self,
+        regions: list[TextConditioningRegions],
+        device: torch.device,
+        dtype: torch.dtype,
+        max_downscale_factor: int = 8,
+    ):
         """Initialize a `RegionalPromptData` object.
 
         Args:
             regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in
                 the batch.
-            max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks.
+            device (torch.device): The device to use for the attention masks.
+            dtype (torch.dtype): The data type to use for the attention masks.
+            max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
+                in steps of 2x.
         """
         self._regions = regions
+        self._device = device
+        self._dtype = dtype
         # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
         # sequence length of s.
         self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
             regions, max_downscale_factor
         )
-
-        self.negative_cross_attn_mask_score = -10000
+        self._negative_cross_attn_mask_score = 0.0
 
     def _prepare_spatial_masks(
         self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
     ) -> list[dict[int, torch.Tensor]]:
         """Prepare the spatial masks for all downscaling factors."""
-        # TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling
-        # levels, but I need to double check that.
-
         # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
         # of s.
        batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
@@ -39,12 +46,12 @@ class RegionalPromptData:
             batch_sample_masks_by_seq_len.append({})
 
             # Convert the bool masks to float masks so that max pooling can be applied.
-            batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32)
+            batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
 
             # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
             downscale_factor = 1
             while downscale_factor <= max_downscale_factor:
-                b, num_prompts, h, w = batch_sample_masks.shape
+                b, _num_prompts, h, w = batch_sample_masks.shape
                 assert b == 1
                 query_seq_len = h * w
 
@@ -78,7 +85,7 @@ class RegionalPromptData:
         batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
 
         # Create an empty attention mask with the correct shape.
-        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len))
+        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
 
         for batch_idx in range(batch_size):
             batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
@@ -94,14 +101,12 @@ class RegionalPromptData:
                 batch_sample_query_scores[
                     batch_sample_query_mask
                 ] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
-                batch_sample_query_scores[~batch_sample_query_mask] = self.negative_cross_attn_mask_score  # TODO(ryand)
+                batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
                 attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
 
         return attn_mask
 
-    def get_self_attn_mask(
-        self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
-    ) -> torch.Tensor:
+    def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
         """Get the self-attention mask for the given query sequence length.
 
         Args:
@@ -113,15 +118,11 @@ class RegionalPromptData:
             dtype: float
             The mask is a binary mask with values of 0.0 and 1.0.
         """
-        # TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
-        # unnecessary CPU operations happening in this class.
         batch_size = len(self._spatial_masks_by_seq_len)
-        batch_spatial_masks = [
-            self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
-        ]
+        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
 
         # Create an empty attention mask with the correct shape.
-        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)
+        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
 
         for batch_idx in range(batch_size):
             batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index baaf70e65c..80fb48d252 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -351,7 +351,9 @@ class InvokeAIDiffuserComponent:
                 )
                 regions.append(r)
 
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
             cross_attention_kwargs["percent_through"] = percent_through
 
         both_results = self.model_forward_callback(
@@ -447,7 +449,7 @@ class InvokeAIDiffuserComponent:
         # Prepare prompt regions for the unconditioned pass.
         if conditioning_data.uncond_regions is not None:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
-                regions=[conditioning_data.uncond_regions]
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
             )
             cross_attention_kwargs["percent_through"] = percent_through
 
@@ -493,7 +495,7 @@ class InvokeAIDiffuserComponent:
         # Prepare prompt regions for the conditioned pass.
         if conditioning_data.cond_regions is not None:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
-                regions=[conditioning_data.cond_regions]
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
             )
             cross_attention_kwargs["percent_through"] = percent_through
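
Taken together, the patch moves the device/dtype decision out of each get_self_attn_mask() / get_cross_attn_mask() call and into the RegionalPromptData constructor: the call sites in shared_invokeai_diffusion.py pass x.device / x.dtype from the latent tensor, and every mask tensor is then created directly on that device with that dtype. Below is a minimal, self-contained sketch of this allocation pattern; the RegionalPromptDataSketch class and its trivial mask logic are illustrative stand-ins, not the real implementation.

import torch


class RegionalPromptDataSketch:
    """Toy stand-in for RegionalPromptData that shows the device/dtype handling pattern."""

    def __init__(self, masks: torch.Tensor, device: torch.device, dtype: torch.dtype):
        # Capture the target device/dtype once, at construction time.
        self._device = device
        self._dtype = dtype
        # Convert the bool region masks a single time, up front, instead of
        # converting on every attention call (the old per-call .to(...)).
        self._masks = masks.to(device=device, dtype=dtype)

    def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
        # Allocate the mask directly with the stored device/dtype, so the
        # attention processor no longer passes device/dtype on each call.
        batch_size = self._masks.shape[0]
        return torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)


# Usage mirrors the updated call sites in shared_invokeai_diffusion.py, where
# device/dtype come from the latent tensor x.
x = torch.randn(1, 4, 8, 8)  # stand-in for the latent sample
region_masks = torch.zeros(1, 2, 8, 8, dtype=torch.bool)  # stand-in region masks
data = RegionalPromptDataSketch(region_masks, device=x.device, dtype=x.dtype)
mask = data.get_self_attn_mask(query_seq_len=64)
assert mask.device == x.device and mask.dtype == x.dtype

The payoff matches the TODOs the patch deletes: no redundant tensor copies inside the attention processor, and no masks silently built on the CPU in float32.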