diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index f7043761a8..d7a2809ce5 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -5,25 +5,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningRegions,
 )
 
-# Stages:
-# - Convert image masks to spatial masks at all downsampling factors.
-#   - Decision: Max pooling? Nearest? Other?
-#   - Should definitely be shared across all denoising steps - that should be easy.
-# - Convert spatial masks to cross-attention masks.
-#   - This should ideally be shared across all denoising steps, but preparing the masks requires knowing the max_key_seq_len.
-#   - Could create it just-in-time and them cache the result
-# - Convert spatial masks to self-attention masks.
-#   - This should be shared across all denoising steps.
-#   - Shape depends only on spatial resolution and downsampling factors.
-# - Convert cross-attention binary mask to score mask.
-# - Convert self-attention binary mask to score mask.
-#
-# If we wanted a time schedule for level of attenuation, we would apply that in the attention layer.
-
-
-# Pre-compute the spatial masks, because that's easy.
-# Compute the other stuff as it's requested. Add caching if we find that it's slow.
-
 
 class RegionalPromptData:
     def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
@@ -78,75 +59,6 @@ class RegionalPromptData:
                     batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
 
         return batch_sample_masks_by_seq_len
-        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
-        # for query_seq_len in batch_sample_masks_by_seq_len[0].keys():
-        #     masks_by_seq_len[query_seq_len] = torch.cat(
-        #         [batch_sample_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_sample_masks_by_seq_len))]
-        #     )
-
-        # return masks_by_seq_len
-
-    # @classmethod
-    # def from_regions(
-    #     cls,
-    #     regions: list[TextConditioningRegions],
-    #     key_seq_len: int,
-    #     max_downscale_factor: int = 8,
-    # ):
-    #     """Construct a `RegionalPromptData` object.
-
-    #     Args:
-    #         regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
-    #             batch.
-    #     """
-    #     attn_masks_by_seq_len = {}
-
-    #     # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
-    #     # length of s.
-    #     batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
-    #     for batch_sample_regions in regions:
-    #         batch_attn_masks_by_seq_len.append({})
-
-    #         # Convert the bool masks to float masks so that max pooling can be applied.
-    #         batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
-
-    #         # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
-    #         downscale_factor = 1
-    #         while downscale_factor <= max_downscale_factor:
-    #             _, num_prompts, h, w = batch_masks.shape
-    #             query_seq_len = h * w
-
-    #             # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
-    #             batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
-
-    #             # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
-    #             # `encoder_hidden_states`.
-    #             # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
-    #             # TODO(ryand): What device / dtype should this be?
-    #             attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
-
-    #             for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
-    #                 attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
-    #                     :, prompt_idx, :, :
-    #                 ]
-
-    #             batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
-
-    #             downscale_factor *= 2
-    #             if downscale_factor <= max_downscale_factor:
-    #                 # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
-    #                 # regions to be lost entirely.
-    #                 # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
-    #                 # potentially use a weighted mask rather than a binary mask.
-    #                 batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
-
-    #     # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
-    #     for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
-    #         attn_masks_by_seq_len[query_seq_len] = torch.cat(
-    #             [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
-    #         )
-
-    #     return cls(attn_masks_by_seq_len)
 
     def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
         """Get the cross-attention mask for the given query sequence length.
@@ -232,20 +144,6 @@ class RegionalPromptData:
                     * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
                 )
 
-            # attn_mask_min = attn_mask[batch_idx].min()
-            # attn_mask_max = attn_mask[batch_idx].max()
-            # attn_mask_range = attn_mask_max - attn_mask_min
-
-            # if abs(attn_mask_range) < 0.0001:
-            #     # All attn_mask value in this batch sample are the same, set the attn_mask to 0.0s (to avoid divide by
-            #     # zero in the normalization).
-            #     attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
-            # else:
-            #     # Normalize from range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
-            #     attn_mask[batch_idx] = (
-            #         (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
-            #     )
-
             attn_mask_min = attn_mask[batch_idx].min()
 
             # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
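Below is a small, self-contained sketch (not part of the patch) of the cross-attention masking scheme that the removed `from_regions` draft described: per-prompt binary masks are max-pooled toward the latent resolution (as the retained mask-preparation code does), flattened over the spatial dimensions, and used to let each query position attend only to the token range of the prompt whose region covers it. The mask size, `prompt_ranges`, and `key_seq_len` values below are invented for illustration only.

import torch
import torch.nn.functional as F

num_prompts = 2
key_seq_len = 77 * num_prompts  # e.g. two concatenated 77-token prompt embeddings
prompt_ranges = [(0, 77), (77, 154)]  # token range owned by each regional prompt

# Binary spatial masks, one per prompt, at the full latent resolution: (1, num_prompts, h, w).
masks = torch.zeros((1, num_prompts, 64, 64))
masks[0, 0, :, :32] = 1.0  # prompt 0 controls the left half of the image
masks[0, 1, :, 32:] = 1.0  # prompt 1 controls the right half of the image

# Downscale with max pooling so that small regions are not lost entirely, then flatten the
# spatial dimensions to get one mask value per query position.
masks = F.max_pool2d(masks, kernel_size=2, stride=2)
_, _, h, w = masks.shape
query_seq_len = h * w
query_masks = masks.reshape((1, num_prompts, query_seq_len, 1))

# Cross-attention mask of shape (1, query_seq_len, key_seq_len): each query position may only
# attend to the embedding range of the prompt whose region covers it.
attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, (start, end) in enumerate(prompt_ranges):
    attn_mask[0, :, start:end] = query_masks[0, prompt_idx, :, :]

print(attn_mask.shape)  # torch.Size([1, 1024, 154])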