Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
(minor) Remove commented code.

commit bcfb43e5f0 (parent a665f20fb5)
@@ -5,25 +5,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningRegions,
 )
 
-# Stages:
-# - Convert image masks to spatial masks at all downsampling factors.
-#   - Decision: Max pooling? Nearest? Other?
-#   - Should definitely be shared across all denoising steps - that should be easy.
-# - Convert spatial masks to cross-attention masks.
-#   - This should ideally be shared across all denoising steps, but preparing the masks requires knowing the max_key_seq_len.
-#   - Could create it just-in-time and them cache the result
-# - Convert spatial masks to self-attention masks.
-#   - This should be shared across all denoising steps.
-#   - Shape depends only on spatial resolution and downsampling factors.
-# - Convert cross-attention binary mask to score mask.
-# - Convert self-attention binary mask to score mask.
-#
-# If we wanted a time schedule for level of attenuation, we would apply that in the attention layer.
-
-
-# Pre-compute the spatial masks, because that's easy.
-# Compute the other stuff as it's requested. Add caching if we find that it's slow.
-
 
 class RegionalPromptData:
     def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
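For reference, a minimal sketch of the first stage the removed planning comment describes (image masks to spatial masks at every downsampling factor), assuming boolean masks of shape (num_prompts, h, w). The helper name build_spatial_masks and its signature are illustrative, not InvokeAI API.

```python
import torch
import torch.nn.functional as F


def build_spatial_masks(masks: torch.Tensor, max_downscale_factor: int = 8) -> dict[int, torch.Tensor]:
    """Return {downscale_factor: masks} for factors 1, 2, 4, ..., max_downscale_factor."""
    # Float conversion is required because max_pool2d does not accept bool tensors.
    masks = masks.to(dtype=torch.float32).unsqueeze(0)  # (1, num_prompts, h, w)
    masks_by_factor: dict[int, torch.Tensor] = {}
    factor = 1
    while factor <= max_downscale_factor:
        masks_by_factor[factor] = masks
        # Max pooling keeps small prompt regions alive at low resolutions,
        # where nearest-neighbor downsampling could drop them entirely.
        masks = F.max_pool2d(masks, kernel_size=2, stride=2)
        factor *= 2
    return masks_by_factor
```

Max pooling matches the choice the surviving code makes below; the removed comment's open question ("Nearest? Other?") would slot in at the pooling call.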
@@ -78,75 +59,6 @@ class RegionalPromptData:
                 batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
 
         return batch_sample_masks_by_seq_len
-        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
-        # for query_seq_len in batch_sample_masks_by_seq_len[0].keys():
-        #     masks_by_seq_len[query_seq_len] = torch.cat(
-        #         [batch_sample_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_sample_masks_by_seq_len))]
-        #     )
-
-        # return masks_by_seq_len
-
-    # @classmethod
-    # def from_regions(
-    #     cls,
-    #     regions: list[TextConditioningRegions],
-    #     key_seq_len: int,
-    #     max_downscale_factor: int = 8,
-    # ):
-    #     """Construct a `RegionalPromptData` object.
-
-    #     Args:
-    #         regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
-    #             batch.
-    #     """
-    #     attn_masks_by_seq_len = {}
-
-    #     # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
-    #     # length of s.
-    #     batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
-    #     for batch_sample_regions in regions:
-    #         batch_attn_masks_by_seq_len.append({})
-
-    #         # Convert the bool masks to float masks so that max pooling can be applied.
-    #         batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
-
-    #         # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
-    #         downscale_factor = 1
-    #         while downscale_factor <= max_downscale_factor:
-    #             _, num_prompts, h, w = batch_masks.shape
-    #             query_seq_len = h * w
-
-    #             # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
-    #             batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
-
-    #             # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
-    #             # `encoder_hidden_states`.
-    #             # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
-    #             # TODO(ryand): What device / dtype should this be?
-    #             attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
-
-    #             for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
-    #                 attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
-    #                     :, prompt_idx, :, :
-    #                 ]
-
-    #             batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
-
-    #             downscale_factor *= 2
-    #             if downscale_factor <= max_downscale_factor:
-    #                 # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
-    #                 # regions to be lost entirely.
-    #                 # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
-    #                 # potentially use a weighted mask rather than a binary mask.
-    #                 batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
-
-    #     # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
-    #     for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
-    #         attn_masks_by_seq_len[query_seq_len] = torch.cat(
-    #             [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
-    #         )
-
-    #     return cls(attn_masks_by_seq_len)
-
 
     def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
         """Get the cross-attention mask for the given query sequence length.
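The commented-out `from_regions` draft removed here built per-resolution cross-attention masks. As a reference for what that computation does, a self-contained sketch at a single resolution, where `ranges` is a hypothetical list of (start, end) slices into the concatenated text embedding, standing in for TextConditioningRegions.ranges:

```python
import torch


def cross_attn_mask_from_spatial_masks(
    spatial_masks: torch.Tensor,  # (num_prompts, h, w), float 0/1 region masks
    ranges: list[tuple[int, int]],  # per-prompt [start, end) slices of the key sequence
    key_seq_len: int,
) -> torch.Tensor:
    """Build a (query_seq_len, key_seq_len) mask that lets each latent position
    attend only to the text embeddings of the prompts covering it."""
    num_prompts, h, w = spatial_masks.shape
    query_seq_len = h * w
    # One column of mask values per flattened query position, per prompt.
    query_masks = spatial_masks.reshape(num_prompts, query_seq_len, 1)
    attn_mask = torch.zeros((query_seq_len, key_seq_len))
    for prompt_idx, (start, end) in enumerate(ranges):
        # Broadcast the prompt's spatial mask across its embedding range.
        attn_mask[:, start:end] = query_masks[prompt_idx]
    return attn_mask
```

Repeating this at each downsampling factor and concatenating across the batch reproduces the shape of the deleted draft, minus the open device/dtype question it flagged.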
@@ -232,20 +144,6 @@ class RegionalPromptData:
                         * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
                     )
 
-            # attn_mask_min = attn_mask[batch_idx].min()
-            # attn_mask_max = attn_mask[batch_idx].max()
-            # attn_mask_range = attn_mask_max - attn_mask_min
-
-            # if abs(attn_mask_range) < 0.0001:
-            #     # All attn_mask value in this batch sample are the same, set the attn_mask to 0.0s (to avoid divide by
-            #     # zero in the normalization).
-            #     attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
-            # else:
-            #     # Normalize from range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
-            #     attn_mask[batch_idx] = (
-            #         (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
-            #     )
-
             attn_mask_min = attn_mask[batch_idx].min()
 
             # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
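The block removed in this hunk rescaled each sample's mask scores from [min, max] to [0, self_attn_score_range]; the live code keeps only the min-shift visible in the context lines above. For comparison, a sketch of the removed normalization with self_attn_score_range passed as a plain float rather than the instance attribute the comment references:

```python
import torch


def normalize_self_attn_mask(attn_mask: torch.Tensor, self_attn_score_range: float) -> torch.Tensor:
    """Rescale one sample's mask scores from [min, max] to [0, self_attn_score_range]."""
    attn_mask_min = attn_mask.min()
    attn_mask_range = attn_mask.max() - attn_mask_min
    if abs(attn_mask_range) < 1e-4:
        # All values in this sample are equal: return zeros instead of dividing by zero.
        return torch.zeros_like(attn_mask)
    return (attn_mask - attn_mask_min) / attn_mask_range * self_attn_score_range
```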