(minor) Remove commented code.

This commit is contained in:
Ryan Dick 2024-03-05 09:12:17 -05:00
parent a665f20fb5
commit bcfb43e5f0

View File

@ -5,25 +5,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
# Stages:
# - Convert image masks to spatial masks at all downsampling factors.
# - Decision: Max pooling? Nearest? Other?
# - Should definitely be shared across all denoising steps - that should be easy.
# - Convert spatial masks to cross-attention masks.
# - This should ideally be shared across all denoising steps, but preparing the masks requires knowing the max_key_seq_len.
# - Could create it just-in-time and them cache the result
# - Convert spatial masks to self-attention masks.
# - This should be shared across all denoising steps.
# - Shape depends only on spatial resolution and downsampling factors.
# - Convert cross-attention binary mask to score mask.
# - Convert self-attention binary mask to score mask.
#
# If we wanted a time schedule for level of attenuation, we would apply that in the attention layer.
# Pre-compute the spatial masks, because that's easy.
# Compute the other stuff as it's requested. Add caching if we find that it's slow.
class RegionalPromptData:
def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
@ -78,75 +59,6 @@ class RegionalPromptData:
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
return batch_sample_masks_by_seq_len
# Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
# for query_seq_len in batch_sample_masks_by_seq_len[0].keys():
# masks_by_seq_len[query_seq_len] = torch.cat(
# [batch_sample_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_sample_masks_by_seq_len))]
# )
# return masks_by_seq_len
# @classmethod
# def from_regions(
# cls,
# regions: list[TextConditioningRegions],
# key_seq_len: int,
# max_downscale_factor: int = 8,
# ):
# """Construct a `RegionalPromptData` object.
# Args:
# regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
# batch.
# """
# attn_masks_by_seq_len = {}
# # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
# # length of s.
# batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
# for batch_sample_regions in regions:
# batch_attn_masks_by_seq_len.append({})
# # Convert the bool masks to float masks so that max pooling can be applied.
# batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
# # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
# downscale_factor = 1
# while downscale_factor <= max_downscale_factor:
# _, num_prompts, h, w = batch_masks.shape
# query_seq_len = h * w
# # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
# batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
# # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
# # `encoder_hidden_states`.
# # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
# # TODO(ryand): What device / dtype should this be?
# attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
# for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
# attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
# :, prompt_idx, :, :
# ]
# batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
# downscale_factor *= 2
# if downscale_factor <= max_downscale_factor:
# # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# # regions to be lost entirely.
# # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
# # potentially use a weighted mask rather than a binary mask.
# batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
# # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
# for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
# attn_masks_by_seq_len[query_seq_len] = torch.cat(
# [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
# )
# return cls(attn_masks_by_seq_len)
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
"""Get the cross-attention mask for the given query sequence length.
@ -232,20 +144,6 @@ class RegionalPromptData:
* batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
)
# attn_mask_min = attn_mask[batch_idx].min()
# attn_mask_max = attn_mask[batch_idx].max()
# attn_mask_range = attn_mask_max - attn_mask_min
# if abs(attn_mask_range) < 0.0001:
# # All attn_mask value in this batch sample are the same, set the attn_mask to 0.0s (to avoid divide by
# # zero in the normalization).
# attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
# else:
# # Normalize from range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
# attn_mask[batch_idx] = (
# (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
# )
attn_mask_min = attn_mask[batch_idx].min()
# Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.