Use the correct device / dtype for RegionalPromptData calculations.

Ryan Dick 2024-03-05 15:19:58 -05:00
parent bcfb43e5f0
commit 41e1a9f202
3 changed files with 26 additions and 32 deletions

File 1 of 3 (CustomAttnProcessor2_0):

@@ -109,17 +109,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                     query_seq_len=query_seq_len,
                     percent_through=percent_through,
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
                 )
-                # TODO(ryand): Avoid redundant type/device conversion here.
-                # prompt_region_attention_mask = prompt_region_attention_mask.to(
-                #     dtype=hidden_states.dtype, device=hidden_states.device
-                # )
-                # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
-                # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
-
 
                 if attention_mask is None:
                     attention_mask = prompt_region_attention_mask
                 else:
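
For context (not part of the commit): CustomAttnProcessor2_0 builds on PyTorch's F.scaled_dot_product_attention, where a floating-point attn_mask is added to the pre-softmax attention scores, so the mask must match the query/key/value tensors' device and dtype. A minimal sketch of that constraint, with made-up shapes:

    import torch
    import torch.nn.functional as F

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # (batch, heads, seq_len, head_dim) -- arbitrary illustrative shapes.
    q = torch.randn(1, 8, 64, 40, device=device, dtype=dtype)
    k = torch.randn(1, 8, 64, 40, device=device, dtype=dtype)
    v = torch.randn(1, 8, 64, 40, device=device, dtype=dtype)

    # A float mask is added to the pre-softmax scores; a mismatched device or
    # dtype would either raise an error or force an implicit conversion.
    attn_mask = torch.zeros(1, 8, 64, 64, device=q.device, dtype=q.dtype)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)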

File 2 of 3 (RegionalPromptData):

@@ -7,30 +7,37 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 class RegionalPromptData:
-    def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
+    def __init__(
+        self,
+        regions: list[TextConditioningRegions],
+        device: torch.device,
+        dtype: torch.dtype,
+        max_downscale_factor: int = 8,
+    ):
         """Initialize a `RegionalPromptData` object.
 
         Args:
             regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
                 batch.
-            max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks.
+            device (torch.device): The device to use for the attention masks.
+            dtype (torch.dtype): The data type to use for the attention masks.
+            max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
+                in steps of 2x.
         """
         self._regions = regions
+        self._device = device
+        self._dtype = dtype
         # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
         # sequence length of s.
         self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
             regions, max_downscale_factor
         )
-        self.negative_cross_attn_mask_score = -10000
+        self._negative_cross_attn_mask_score = 0.0
 
     def _prepare_spatial_masks(
         self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
     ) -> list[dict[int, torch.Tensor]]:
         """Prepare the spatial masks for all downscaling factors."""
-        # TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling
-        # levels, but I need to double check that.
         # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
         # of s.
         batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
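
A minimal usage sketch of the new constructor (the helper name and the `latents` variable are illustrative; the real call sites appear in the InvokeAIDiffuserComponent hunks below):

    import torch

    def build_self_attn_mask(regions: list[TextConditioningRegions], latents: torch.Tensor) -> torch.Tensor:
        # Device and dtype are now fixed once at construction time, taken from
        # the latent tensor, so every mask this object produces is allocated
        # correctly up front.
        rpd = RegionalPromptData(regions=regions, device=latents.device, dtype=latents.dtype)
        # 4096 = 64 * 64, i.e. the query length for a 64x64 latent feature map.
        return rpd.get_self_attn_mask(query_seq_len=4096, percent_through=0.25)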
@@ -39,12 +46,12 @@ class RegionalPromptData:
             batch_sample_masks_by_seq_len.append({})
 
             # Convert the bool masks to float masks so that max pooling can be applied.
-            batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32)
+            batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
 
             # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
             downscale_factor = 1
             while downscale_factor <= max_downscale_factor:
-                b, num_prompts, h, w = batch_sample_masks.shape
+                b, _num_prompts, h, w = batch_sample_masks.shape
                 assert b == 1
 
                 query_seq_len = h * w
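
The pooling call itself sits outside this hunk's context, but the surrounding comments describe the intent: build a pyramid of masks, one per attention resolution, by repeated 2x max pooling. A sketch of that step under those assumptions (shapes and the dict layout are illustrative):

    import torch
    import torch.nn.functional as F

    # (1, num_prompts, h, w); bool masks are converted to float first so that
    # max pooling is defined.
    mask = (torch.rand(1, 3, 64, 64) > 0.5).to(dtype=torch.float32)
    masks_by_seq_len = {64 * 64: mask}
    downscale_factor = 1
    while downscale_factor < 8:
        # Max pooling keeps a cell "active" if any pixel in its 2x2 window was active.
        mask = F.max_pool2d(mask, kernel_size=2, stride=2)
        downscale_factor *= 2
        _, _, h, w = mask.shape
        masks_by_seq_len[h * w] = mask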
@@ -78,7 +85,7 @@ class RegionalPromptData:
         batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
 
         # Create an empty attention mask with the correct shape.
-        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len))
+        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
 
         for batch_idx in range(batch_size):
             batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
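
The motivation for this one-line change is general PyTorch behavior: torch.zeros with no dtype/device arguments allocates float32 on the CPU (under default settings), which silently mismatches a half-precision model running on the GPU:

    import torch

    m = torch.zeros((2, 3))
    print(m.dtype, m.device)  # torch.float32 cpu -- the old, implicit behavior

    # Allocating with the target dtype/device up front avoids a later copy/convert.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    m = torch.zeros((2, 3), dtype=torch.float16, device=device)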
@@ -94,14 +101,12 @@ class RegionalPromptData:
                 batch_sample_query_scores[
                     batch_sample_query_mask
                 ] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
-                batch_sample_query_scores[~batch_sample_query_mask] = self.negative_cross_attn_mask_score  # TODO(ryand)
+                batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
                 attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
 
         return attn_mask
 
-    def get_self_attn_mask(
-        self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
-    ) -> torch.Tensor:
+    def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
         """Get the self-attention mask for the given query sequence length.
 
         Args:
@@ -113,15 +118,11 @@ class RegionalPromptData:
             dtype: float
             The mask is a binary mask with values of 0.0 and 1.0.
         """
-        # TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
-        # unnecessary CPU operations happening in this class.
         batch_size = len(self._spatial_masks_by_seq_len)
-        batch_spatial_masks = [
-            self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
-        ]
+        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
 
         # Create an empty attention mask with the correct shape.
-        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)
+        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
 
         for batch_idx in range(batch_size):
             batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
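
For context on the mask scores used above: these masks act as additive biases on the pre-softmax attention scores, so a large negative value such as the old -10000 effectively removes a key from attention, while 0.0 leaves it untouched. A small numeric illustration:

    import torch

    scores = torch.tensor([2.0, 1.0, 0.5])
    print(torch.softmax(scores, dim=0))  # ~[0.63, 0.23, 0.14]: all keys participate

    biased = scores + torch.tensor([0.0, 0.0, -10000.0])
    print(torch.softmax(biased, dim=0))  # ~[0.73, 0.27, 0.00]: the biased key is suppressed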

File 3 of 3 (InvokeAIDiffuserComponent):

@@ -351,7 +351,9 @@ class InvokeAIDiffuserComponent:
                     )
                 regions.append(r)
 
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
             cross_attention_kwargs["percent_through"] = percent_through
 
         both_results = self.model_forward_callback(
@@ -447,7 +449,7 @@ class InvokeAIDiffuserComponent:
 
         # Prepare prompt regions for the unconditioned pass.
         if conditioning_data.uncond_regions is not None:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
-                regions=[conditioning_data.uncond_regions]
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
             )
             cross_attention_kwargs["percent_through"] = percent_through
@@ -493,7 +495,7 @@ class InvokeAIDiffuserComponent:
 
         # Prepare prompt regions for the conditioned pass.
         if conditioning_data.cond_regions is not None:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
-                regions=[conditioning_data.cond_regions]
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
             )
             cross_attention_kwargs["percent_through"] = percent_through