mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use the correct device / dtype for RegionalPromptData calculations.
This commit is contained in:
parent
bcfb43e5f0
commit
41e1a9f202
@ -109,17 +109,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
|
||||
query_seq_len=query_seq_len,
|
||||
percent_through=percent_through,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
# TODO(ryand): Avoid redundant type/device conversion here.
|
||||
# prompt_region_attention_mask = prompt_region_attention_mask.to(
|
||||
# dtype=hidden_states.dtype, device=hidden_states.device
|
||||
# )
|
||||
# prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
|
||||
# prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
|
@ -7,30 +7,37 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
|
||||
|
||||
class RegionalPromptData:
|
||||
def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
|
||||
def __init__(
|
||||
self,
|
||||
regions: list[TextConditioningRegions],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Initialize a `RegionalPromptData` object.
|
||||
|
||||
Args:
|
||||
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||
batch.
|
||||
max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks.
|
||||
device (torch.device): The device to use for the attention masks.
|
||||
dtype (torch.dtype): The data type to use for the attention masks.
|
||||
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||
in steps of 2x.
|
||||
"""
|
||||
self._regions = regions
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
||||
# sequence length of s.
|
||||
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
||||
regions, max_downscale_factor
|
||||
)
|
||||
|
||||
self.negative_cross_attn_mask_score = -10000
|
||||
self._negative_cross_attn_mask_score = 0.0
|
||||
|
||||
def _prepare_spatial_masks(
|
||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||
) -> list[dict[int, torch.Tensor]]:
|
||||
"""Prepare the spatial masks for all downscaling factors."""
|
||||
# TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling
|
||||
# levels, but I need to double check that.
|
||||
|
||||
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
||||
# of s.
|
||||
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||
@ -39,12 +46,12 @@ class RegionalPromptData:
|
||||
batch_sample_masks_by_seq_len.append({})
|
||||
|
||||
# Convert the bool masks to float masks so that max pooling can be applied.
|
||||
batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32)
|
||||
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
|
||||
|
||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
b, num_prompts, h, w = batch_sample_masks.shape
|
||||
b, _num_prompts, h, w = batch_sample_masks.shape
|
||||
assert b == 1
|
||||
query_seq_len = h * w
|
||||
|
||||
@ -78,7 +85,7 @@ class RegionalPromptData:
|
||||
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||
|
||||
# Create an empty attention mask with the correct shape.
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len))
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||
@ -94,14 +101,12 @@ class RegionalPromptData:
|
||||
batch_sample_query_scores[
|
||||
batch_sample_query_mask
|
||||
] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
|
||||
batch_sample_query_scores[~batch_sample_query_mask] = self.negative_cross_attn_mask_score # TODO(ryand)
|
||||
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
|
||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
||||
|
||||
return attn_mask
|
||||
|
||||
def get_self_attn_mask(
|
||||
self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
|
||||
"""Get the self-attention mask for the given query sequence length.
|
||||
|
||||
Args:
|
||||
@ -113,15 +118,11 @@ class RegionalPromptData:
|
||||
dtype: float
|
||||
The mask is a binary mask with values of 0.0 and 1.0.
|
||||
"""
|
||||
# TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
|
||||
# unnecessary CPU operations happening in this class.
|
||||
batch_size = len(self._spatial_masks_by_seq_len)
|
||||
batch_spatial_masks = [
|
||||
self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
|
||||
]
|
||||
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||
|
||||
# Create an empty attention mask with the correct shape.
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||
|
@ -351,7 +351,9 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
regions.append(r)
|
||||
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=regions, device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = percent_through
|
||||
|
||||
both_results = self.model_forward_callback(
|
||||
@ -447,7 +449,7 @@ class InvokeAIDiffuserComponent:
|
||||
# Prepare prompt regions for the unconditioned pass.
|
||||
if conditioning_data.uncond_regions is not None:
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=[conditioning_data.uncond_regions]
|
||||
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = percent_through
|
||||
|
||||
@ -493,7 +495,7 @@ class InvokeAIDiffuserComponent:
|
||||
# Prepare prompt regions for the conditioned pass.
|
||||
if conditioning_data.cond_regions is not None:
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=[conditioning_data.cond_regions]
|
||||
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = percent_through
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user