From 2966c8de2c89c40e24ce66b2dc9198b82d1350ac Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 27 Feb 2024 18:16:01 -0500
Subject: [PATCH] Handle conditioned and unconditioned text conditioning in the
 same way for regional prompt attention.

---
 .../diffusion/conditioning_data.py           |   1 +
 .../diffusion/regional_prompt_attention.py   | 165 +++++++++------
 .../diffusion/shared_invokeai_diffusion.py   | 196 ++++++++++++++++--
 3 files changed, 275 insertions(+), 87 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index ca215a5714..83bb78e42e 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -65,6 +65,7 @@ class IPAdapterConditioningInfo:
 
 @dataclass
 class ConditioningData:
+    # TODO(ryand): Support masks for unconditioned_embeddings.
     unconditioned_embeddings: BasicConditioningInfo
     text_embeddings: list[BasicConditioningInfo]
     text_embedding_masks: list[Optional[torch.Tensor]]
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
index 2db4a24570..fe9bdbc951 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
@@ -15,78 +15,102 @@ class Range:
     end: int
 
 
-@dataclass
 class RegionalPromptData:
-    # The region masks for each prompt.
-    # shape: (batch_size, num_prompts, height, width)
-    # dtype: float*
-    # The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
-    masks: torch.Tensor
+    def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
+        self._attn_masks_by_seq_len = attn_masks_by_seq_len
 
-    # The embedding ranges for each prompt.
-    # The i'th mask is applied to the embeddings in:
-    # encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
-    embedding_ranges: list[Range]
+    @classmethod
+    def from_masks_and_ranges(
+        cls,
+        masks: list[torch.Tensor],
+        embedding_ranges: list[list[Range]],
+        key_seq_len: int,
+        # TODO(ryand): Pass in a list of downscale factors?
+        max_downscale_factor: int = 8,
+    ):
+        """Construct a `RegionalPromptData` object.
+
+        Args:
+            masks (list[torch.Tensor]): masks[i] contains the region masks for the i'th sample in the batch.
+                The shape of masks[i] is (1, num_prompts, height, width), and dtype=bool. The mask is set to True
+                in regions where the prompt should be applied, and False elsewhere.
+
+            embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
+                prompt in the i'th batch sample. masks[i][0, j, ...] is applied to the embeddings in:
+                encoder_hidden_states[i, embedding_ranges[i][j].start:embedding_ranges[i][j].end, :].
+
+            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
+                cross-attention layers).
+        """
+        attn_masks_by_seq_len = {}
+
+        # batch_attn_masks_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query
+        # sequence length of s.
+        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
+        for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
+            batch_attn_masks_by_seq_len.append({})
+
+            # Convert the bool masks to float masks so that max pooling can be applied.
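+            # (F.max_pool2d is not implemented for bool tensors. Consumers of the resulting float mask threshold it
+            # at 0.5, so no precision is lost; see RegionalPromptAttnProcessor2_0.__call__.)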
+            batch_masks = batch_masks.to(dtype=torch.float32)
+
+            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
+            downscale_factor = 1
+            while downscale_factor <= max_downscale_factor:
+                _, num_prompts, h, w = batch_masks.shape
+                query_seq_len = h * w
+
+                # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
+                batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
+
+                # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
+                # `encoder_hidden_states`.
+                # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
+                # TODO(ryand): What device / dtype should this be?
+                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
+
+                for prompt_idx, embedding_range in enumerate(batch_ranges):
+                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
+                        :, prompt_idx, :, :
+                    ]
+
+                batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
+
+                downscale_factor *= 2
+                if downscale_factor <= max_downscale_factor:
+                    # We use max pooling because we downscale to a pretty low resolution, so we don't want small
+                    # prompt regions to be lost entirely.
+                    # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and
+                    # could potentially use a weighted mask rather than a binary mask.
+                    batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
+
+        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
+        for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
+            attn_masks_by_seq_len[query_seq_len] = torch.cat(
+                [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
+            )
+
+        return cls(attn_masks_by_seq_len)
+
+    def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
+        """Get the attention mask for the given query sequence length (i.e. downscaling level).
+
+        This is called during cross-attention, where query_seq_len is the length of the flattened spatial features,
+        so it changes at each downscaling level in the model.
+
+        key_seq_len is the length of the expected prompt embeddings.
+
+        Returns:
+            torch.Tensor: The masks.
+                shape: (batch_size, query_seq_len, key_seq_len).
+                dtype: float
+                The mask is a binary mask with values of 0.0 and 1.0.
+        """
+        return self._attn_masks_by_seq_len[query_seq_len]
 
 
 class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
     """An attention processor that supports regional prompt attention for PyTorch 2.0."""
 
-    def _prepare_regional_prompt_attention_mask(
-        self,
-        regional_prompt_data: RegionalPromptData,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        orig_attn_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        # Infer the current spatial dimensions from the shape of `hidden_states`.
-        _, query_seq_len, _ = hidden_states.shape
-        per_prompt_query_masks = regional_prompt_data.masks
-        _, _, h, w = per_prompt_query_masks.shape
-
-        # Downsample by factors of 2 until the spatial dimensions match the current query sequence length.
-        scale_factor = 1
-        while h * w > query_seq_len:
-            scale_factor *= 2
-            h //= 2
-            w //= 2
-        assert h * w == query_seq_len
-
-        # Convert the bool masks to float masks.
-        per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
-
-        # Apply max-pooling to resize the masks to the target spatial dimensions.
-        # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
-        # here.
-        per_prompt_query_masks = F.max_pool2d(per_prompt_query_masks, kernel_size=scale_factor, stride=scale_factor)
-        batch_size, num_prompts, resized_h, resized_w = per_prompt_query_masks.shape
-        assert resized_h == h and resized_w == w
-
-        # Flatten the spatial dimensions of the masks.
-        # Shape after reshape: (batch_size, num_prompts, query_seq_len)
-        per_prompt_query_masks = per_prompt_query_masks.reshape((batch_size, num_prompts, -1, 1))
-
-        # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
-        # `encoder_hidden_states`.
-
-        # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
-        _, key_seq_len, _ = encoder_hidden_states.shape
-        # HACK(ryand): We are assuming the batch size.
-        attn_mask = torch.zeros((2, query_seq_len, key_seq_len), device=hidden_states.device)
-
-        for i, embedding_range in enumerate(regional_prompt_data.embedding_ranges):
-            # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. This is too
-            # fragile to merge.
-            attn_mask[1, :, embedding_range.start : embedding_range.end] = per_prompt_query_masks[:, i, :, :]
-
-        # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. We are also assuming
-        # the intent of attn_mask. And we shouldn't have to do this awkward mask type conversion.
-        orig_mask = torch.zeros_like(orig_attn_mask[0, ...])
-        orig_mask[orig_attn_mask[0, ...] > -0.5] = 1.0
-        attn_mask[0, ...] = orig_mask
-
-        return attn_mask > 0.5
-
     def __call__(
         self,
         attn: Attention,
@@ -114,9 +138,16 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
         if encoder_hidden_states is not None:
             assert regional_prompt_data is not None
             assert attention_mask is not None
-            attention_mask = self._prepare_regional_prompt_attention_mask(
-                regional_prompt_data, hidden_states, encoder_hidden_states, attention_mask
+            _, query_seq_len, _ = hidden_states.shape
+            prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
+            # TODO(ryand): Avoid redundant type/device conversion here.
+            prompt_region_attention_mask = prompt_region_attention_mask.to(
+                dtype=attention_mask.dtype, device=attention_mask.device
             )
+            prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
+            prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
+
+            attention_mask = prompt_region_attention_mask + attention_mask
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index a8e663b49d..9b7c0b7fd8 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -36,6 +36,142 @@ ModelForwardCallback: TypeAlias = Union[
 ]
 
 
+class RegionalTextConditioningInfo:
+    def __init__(
+        self,
+        text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        masks: Optional[torch.Tensor] = None,
+        embedding_ranges: Optional[list[Range]] = None,
+    ):
+        """Initialize a RegionalTextConditioningInfo.
+
+        Args:
+            text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
+                after concatenating the embeddings for all regions.
+            masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
+            embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
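+                When provided, its length must match the num_regions dimension of masks (this is asserted below).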
+ """ + self.text_conditioning = text_conditioning + self.masks = masks + self.embedding_ranges = embedding_ranges + + assert (self.masks is None) == (self.embedding_ranges is None) + if self.masks is not None: + assert self.masks.shape[1] == len(self.embedding_ranges) + + def has_region_masks(self): + if self.masks is None: + return False + return any(mask is not None for mask in self.masks) + + def is_sdxl(self): + return isinstance(self.text_conditioning, SDXLConditioningInfo) + + @classmethod + def _preprocess_regional_prompt_mask( + cls, mask: Optional[torch.Tensor], target_height: int, target_width: int + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation. + + Returns: + torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width). + """ + if mask is None: + return torch.ones((1, 1, target_height, target_width), dtype=torch.bool) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + mask = tf(mask) + + return mask + + @classmethod + def from_text_conditioning_and_masks( + cls, + text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]], + masks: Optional[list[Optional[torch.Tensor]]], + latent_height: int, + latent_width: int, + ): + if masks is None: + masks = [None] * len(text_conditionings) + assert len(text_conditionings) == len(masks) + + is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo + + all_masks_are_none = all(mask is None for mask in masks) + + text_embedding = [] + pooled_embedding = None + add_time_ids = None + processed_masks = [] + cur_text_embedding_len = 0 + embedding_ranges: list[Range] = [] + + for text_embedding_info, mask in zip(text_conditionings, masks, strict=True): + # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features. + assert ( + text_embedding_info.extra_conditioning is None + or not text_embedding_info.extra_conditioning.wants_cross_attention_control + ) + + if is_sdxl: + # We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. + # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all + # the conditioning info, then we shouldn't allow it to be passed in. + # How does Compel handle this? Options that come to mind: + # - Blend the pooled_embeds and add_time_ids from all of the text embeddings. + # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since + # this is likely the global prompt. 
+                if pooled_embedding is None:
+                    pooled_embedding = text_embedding_info.pooled_embeds
+                if add_time_ids is None:
+                    add_time_ids = text_embedding_info.add_time_ids
+
+            text_embedding.append(text_embedding_info.embeds)
+            embedding_ranges.append(
+                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
+            )
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+            if not all_masks_are_none:
+                processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
+
+        text_embedding = torch.cat(text_embedding, dim=1)
+        assert len(text_embedding.shape) == 3  # batch_size, seq_len, embedding_dim
+
+        if not all_masks_are_none:
+            processed_masks = torch.cat(processed_masks, dim=1)
+        else:
+            processed_masks = None
+            embedding_ranges = None
+
+        if is_sdxl:
+            return cls(
+                text_conditioning=SDXLConditioningInfo(
+                    embeds=text_embedding,
+                    extra_conditioning=None,
+                    pooled_embeds=pooled_embedding,
+                    add_time_ids=add_time_ids,
+                ),
+                masks=processed_masks,
+                embedding_ranges=embedding_ranges,
+            )
+        return cls(
+            text_conditioning=BasicConditioningInfo(
+                embeds=text_embedding,
+                extra_conditioning=None,
+            ),
+            masks=processed_masks,
+            embedding_ranges=embedding_ranges,
+        )
+
+
 class InvokeAIDiffuserComponent:
     """
     The aim of this component is to provide a single place for code that can be applied identically to
@@ -59,7 +195,6 @@ class InvokeAIDiffuserComponent:
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
         config = InvokeAIAppConfig.get_config()
-        self.conditioning = None
        self.model = model
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
@@ -433,14 +568,44 @@ class InvokeAIDiffuserComponent:
         # denoising step.
         cross_attention_kwargs = None
         _, _, h, w = x.shape
-        text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
-            text_embeddings=conditioning_data.text_embeddings,
+        cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
+            text_conditionings=conditioning_data.text_embeddings,
             masks=conditioning_data.text_embedding_masks,
-            target_height=h,
-            target_width=w,
+            latent_height=h,
+            latent_width=w,
         )
-        if regional_prompt_data is not None:
-            cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}
+        uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
+            text_conditionings=[conditioning_data.unconditioned_embeddings],
+            masks=[None],
+            latent_height=h,
+            latent_width=w,
+        )
+
+        if cond_text.has_region_masks() or uncond_text.has_region_masks():
+            masks = []
+            embedding_ranges = []
+            for c in [uncond_text, cond_text]:
+                if c.has_region_masks():
+                    masks.append(c.masks)
+                    embedding_ranges.append(c.embedding_ranges)
+                else:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
+                    embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])
+
+            # The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
+            # embeddings will be padded to match this length.
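+            # (The padding itself is applied when the unconditioned and conditioned embeddings are concatenated for
+            # the batch; see the _concat_conditionings_for_batch call below.)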
+            key_seq_len = 0
+            for c in [uncond_text, cond_text]:
+                _, seq_len, _ = c.text_conditioning.embeds.shape
+                if seq_len > key_seq_len:
+                    key_seq_len = seq_len
+
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
+                    masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
+                )
+            }
 
         # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
         if conditioning_data.ip_adapter_conditioning is not None:
@@ -455,27 +620,18 @@ class InvokeAIDiffuserComponent:
             }
 
         added_cond_kwargs = None
-        if type(text_embeddings) is SDXLConditioningInfo:
+        if cond_text.is_sdxl():
             added_cond_kwargs = {
                 "text_embeds": torch.cat(
-                    [
-                        # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        text_embeddings.pooled_embeds,
-                    ],
-                    dim=0,
+                    [uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
                 ),
                 "time_ids": torch.cat(
-                    [
-                        conditioning_data.unconditioned_embeddings.add_time_ids,
-                        text_embeddings.add_time_ids,
-                    ],
-                    dim=0,
+                    [uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
                 ),
             }
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
+            uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
         )
         both_results = self.model_forward_callback(
             x_twice,