mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use the correct device / dtype for RegionalPromptData calculations.
This commit is contained in:
parent
bcfb43e5f0
commit
41e1a9f202
@ -109,17 +109,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
|
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
|
||||||
query_seq_len=query_seq_len,
|
query_seq_len=query_seq_len,
|
||||||
percent_through=percent_through,
|
percent_through=percent_through,
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(ryand): Avoid redundant type/device conversion here.
|
|
||||||
# prompt_region_attention_mask = prompt_region_attention_mask.to(
|
|
||||||
# dtype=hidden_states.dtype, device=hidden_states.device
|
|
||||||
# )
|
|
||||||
# prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
|
|
||||||
# prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
|
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = prompt_region_attention_mask
|
attention_mask = prompt_region_attention_mask
|
||||||
else:
|
else:
|
||||||
|
@ -7,30 +7,37 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
|
|
||||||
|
|
||||||
class RegionalPromptData:
|
class RegionalPromptData:
|
||||||
def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
|
def __init__(
|
||||||
|
self,
|
||||||
|
regions: list[TextConditioningRegions],
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
max_downscale_factor: int = 8,
|
||||||
|
):
|
||||||
"""Initialize a `RegionalPromptData` object.
|
"""Initialize a `RegionalPromptData` object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||||
batch.
|
batch.
|
||||||
max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks.
|
device (torch.device): The device to use for the attention masks.
|
||||||
|
dtype (torch.dtype): The data type to use for the attention masks.
|
||||||
|
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||||
|
in steps of 2x.
|
||||||
"""
|
"""
|
||||||
self._regions = regions
|
self._regions = regions
|
||||||
|
self._device = device
|
||||||
|
self._dtype = dtype
|
||||||
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
||||||
# sequence length of s.
|
# sequence length of s.
|
||||||
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
||||||
regions, max_downscale_factor
|
regions, max_downscale_factor
|
||||||
)
|
)
|
||||||
|
self._negative_cross_attn_mask_score = 0.0
|
||||||
self.negative_cross_attn_mask_score = -10000
|
|
||||||
|
|
||||||
def _prepare_spatial_masks(
|
def _prepare_spatial_masks(
|
||||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||||
) -> list[dict[int, torch.Tensor]]:
|
) -> list[dict[int, torch.Tensor]]:
|
||||||
"""Prepare the spatial masks for all downscaling factors."""
|
"""Prepare the spatial masks for all downscaling factors."""
|
||||||
# TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling
|
|
||||||
# levels, but I need to double check that.
|
|
||||||
|
|
||||||
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
||||||
# of s.
|
# of s.
|
||||||
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||||
@ -39,12 +46,12 @@ class RegionalPromptData:
|
|||||||
batch_sample_masks_by_seq_len.append({})
|
batch_sample_masks_by_seq_len.append({})
|
||||||
|
|
||||||
# Convert the bool masks to float masks so that max pooling can be applied.
|
# Convert the bool masks to float masks so that max pooling can be applied.
|
||||||
batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32)
|
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
|
||||||
|
|
||||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||||
downscale_factor = 1
|
downscale_factor = 1
|
||||||
while downscale_factor <= max_downscale_factor:
|
while downscale_factor <= max_downscale_factor:
|
||||||
b, num_prompts, h, w = batch_sample_masks.shape
|
b, _num_prompts, h, w = batch_sample_masks.shape
|
||||||
assert b == 1
|
assert b == 1
|
||||||
query_seq_len = h * w
|
query_seq_len = h * w
|
||||||
|
|
||||||
@ -78,7 +85,7 @@ class RegionalPromptData:
|
|||||||
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||||
|
|
||||||
# Create an empty attention mask with the correct shape.
|
# Create an empty attention mask with the correct shape.
|
||||||
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len))
|
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||||
@ -94,14 +101,12 @@ class RegionalPromptData:
|
|||||||
batch_sample_query_scores[
|
batch_sample_query_scores[
|
||||||
batch_sample_query_mask
|
batch_sample_query_mask
|
||||||
] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
|
] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
|
||||||
batch_sample_query_scores[~batch_sample_query_mask] = self.negative_cross_attn_mask_score # TODO(ryand)
|
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
|
||||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
||||||
|
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
def get_self_attn_mask(
|
def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
|
||||||
self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Get the self-attention mask for the given query sequence length.
|
"""Get the self-attention mask for the given query sequence length.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -113,15 +118,11 @@ class RegionalPromptData:
|
|||||||
dtype: float
|
dtype: float
|
||||||
The mask is a binary mask with values of 0.0 and 1.0.
|
The mask is a binary mask with values of 0.0 and 1.0.
|
||||||
"""
|
"""
|
||||||
# TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
|
|
||||||
# unnecessary CPU operations happening in this class.
|
|
||||||
batch_size = len(self._spatial_masks_by_seq_len)
|
batch_size = len(self._spatial_masks_by_seq_len)
|
||||||
batch_spatial_masks = [
|
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||||
self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create an empty attention mask with the correct shape.
|
# Create an empty attention mask with the correct shape.
|
||||||
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)
|
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||||
|
@ -351,7 +351,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
regions.append(r)
|
regions.append(r)
|
||||||
|
|
||||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
|
regions=regions, device=x.device, dtype=x.dtype
|
||||||
|
)
|
||||||
cross_attention_kwargs["percent_through"] = percent_through
|
cross_attention_kwargs["percent_through"] = percent_through
|
||||||
|
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
@ -447,7 +449,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
# Prepare prompt regions for the unconditioned pass.
|
# Prepare prompt regions for the unconditioned pass.
|
||||||
if conditioning_data.uncond_regions is not None:
|
if conditioning_data.uncond_regions is not None:
|
||||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
regions=[conditioning_data.uncond_regions]
|
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
|
||||||
)
|
)
|
||||||
cross_attention_kwargs["percent_through"] = percent_through
|
cross_attention_kwargs["percent_through"] = percent_through
|
||||||
|
|
||||||
@ -493,7 +495,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
# Prepare prompt regions for the conditioned pass.
|
# Prepare prompt regions for the conditioned pass.
|
||||||
if conditioning_data.cond_regions is not None:
|
if conditioning_data.cond_regions is not None:
|
||||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
regions=[conditioning_data.cond_regions]
|
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
|
||||||
)
|
)
|
||||||
cross_attention_kwargs["percent_through"] = percent_through
|
cross_attention_kwargs["percent_through"] = percent_through
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user