From ad18429fe371e40bba73311b1d2f81b6041944e5 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 2 Mar 2024 17:43:21 -0500
Subject: [PATCH] Early experimentation with various regional prompting tuning params.

---
 .../diffusion/conditioning_data.py          | 17 ++++++
 .../diffusion/custom_attention.py           | 17 +++---
 .../diffusion/regional_prompt_data.py       | 53 ++++++++++++++++---
 .../diffusion/shared_invokeai_diffusion.py  | 24 +++++----
 4 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index e4b27ad04d..8d7ff09be1 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -97,6 +97,23 @@ class TextConditioningData:
         self.cond_text = cond_text
         self.uncond_regions = uncond_regions
         self.cond_regions = cond_regions
+        # All params:
+        # negative_cross_attn_mask_score: -10000 (recommended to leave this as -10000 to prevent leakage to the rest of the image)
+        # positive_cross_attn_mask_score: 0.0 (relative weighting of masks)
+        # positive_self_attn_mask_score: 0.3
+        # negative_self_attn_mask_score: This doesn't really make sense. It would effectively be the same as further increasing positive_self_attn_mask_score.
+        # cross_attn_start_step
+        # self_attn_mask_begin_step_percent: 0.0
+        # self_attn_mask_end_step_percent: 0.5
+        # Should we allow cross_attn_mask_begin_step_percent and cross_attn_mask_end_step_percent? Probably not; this seems like more control than necessary, and it would be easy to add in the future.
+        self.negative_cross_attn_mask_score = -10000
+        self.positive_cross_attn_mask_score = 0.0
+        self.positive_self_attn_mask_score = 0.3
+        self.self_attn_mask_end_step_percent = 0.5
+        # mask_weight: float = Field(
+        #     default=1.0,
+        #     description="The weight to apply to the mask. This weight controls the relative weighting of overlapping masks. This weight gets added to the attention map logits before applying a pixelwise softmax.",
+        # )
         # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
         # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 632d6beeb0..1b2c570ca0 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -58,6 +58,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         scale: float = 1.0,
         # For regional prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
+        percent_through: Optional[torch.FloatTensor] = None,
         # For IP-Adapter:
         ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.FloatTensor:
@@ -93,6 +94,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
         # Handle regional prompt attention masks.
         if regional_prompt_data is not None:
+            assert percent_through is not None
             _, query_seq_len, _ = hidden_states.shape
             if is_cross_attention:
                 prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
@@ -102,16 +104,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 prompt_region_attention_mask = prompt_region_attention_mask.to(
                     dtype=hidden_states.dtype, device=hidden_states.device
                 )
-                prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
-                prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
             else:  # self-attention
-                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len)
+                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
+                    query_seq_len=query_seq_len,
+                    percent_through=percent_through,
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
 
                 # TODO(ryand): Avoid redundant type/device conversion here.
-                prompt_region_attention_mask = prompt_region_attention_mask.to(
-                    dtype=hidden_states.dtype, device=hidden_states.device
-                )
+                # prompt_region_attention_mask = prompt_region_attention_mask.to(
+                #     dtype=hidden_states.dtype, device=hidden_states.device
+                # )
 
                 # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
                 # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
 
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index b5966c8b56..749e3e0224 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -41,6 +41,17 @@ class RegionalPromptData:
             regions, max_downscale_factor
         )
 
+        # TODO: These should be indexed by batch sample index and prompt index.
+        # Next:
+        # - Add support for setting these on nodes. Might just need the positive cross-attention mask score. Being able to downweight the global prompt might help a lot.
+        # - Scale by region size.
+        self.negative_cross_attn_mask_score = -10000
+        self.positive_cross_attn_mask_score = 0.0
+        self.positive_self_attn_mask_score = 2.0
+        self.self_attn_mask_end_step_percent = 0.3
+        # This one is for regional prompting in general, so should be set on the DenoiseLatents node.
+        self.self_attn_score_range = 3.0
+
     def _prepare_spatial_masks(
         self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
     ) -> list[dict[int, torch.Tensor]]:
@@ -179,9 +190,14 @@ class RegionalPromptData:
                     0, prompt_idx, :, :
                 ]
 
+        pos_mask = attn_mask >= 0.5
+        attn_mask[~pos_mask] = self.negative_cross_attn_mask_score
+        attn_mask[pos_mask] = self.positive_cross_attn_mask_score
         return attn_mask
 
-    def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
+    def get_self_attn_mask(
+        self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
         """Get the self-attention mask for the given query sequence length.
 
         Args:
@@ -193,11 +209,15 @@ class RegionalPromptData:
                 dtype: float
                 The mask is a binary mask with values of 0.0 and 1.0.
         """
+        # TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
+        # unnecessary CPU operations happening in this class.
         batch_size = len(self._spatial_masks_by_seq_len)
-        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
+        batch_spatial_masks = [
+            self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
+        ]
 
         # Create an empty attention mask with the correct shape.
-        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len))
+        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)
 
         for batch_idx in range(batch_size):
             batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
@@ -207,11 +227,32 @@ class RegionalPromptData:
             batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
 
             for prompt_idx in range(num_prompts):
+                if percent_through > self.self_attn_mask_end_step_percent:
+                    continue
                 prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0]  # Shape: (query_seq_len,)
                 # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
                 # query_seq_len) mask.
-                attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5
+                attn_mask[batch_idx, :, :] += (
+                    prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
+                )
 
-            # Since we were adding masks in the previous loop, we need to clamp the values to 1.0.
-            # attn_mask[attn_mask > 0.5] = 1.0
+            # attn_mask_min = attn_mask[batch_idx].min()
+            # attn_mask_max = attn_mask[batch_idx].max()
+            # attn_mask_range = attn_mask_max - attn_mask_min
+
+            # if abs(attn_mask_range) < 0.0001:
+            #     # All attn_mask values in this batch sample are the same; set the attn_mask to 0.0s (to avoid divide by
+            #     # zero in the normalization).
+            #     attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
+            # else:
+            #     # Normalize from range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
+            #     attn_mask[batch_idx] = (
+            #         (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
+            #     )
+
+            attn_mask_min = attn_mask[batch_idx].min()
+
+            # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
+            if abs(attn_mask_min) > 0.0001:
+                attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
         return attn_mask
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index fd4806d024..9c716e5440 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -200,9 +200,9 @@ class InvokeAIDiffuserComponent:
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
@@ -219,6 +219,7 @@ class InvokeAIDiffuserComponent:
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -232,6 +233,7 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                percent_through=percent_through,
                 ip_adapter_conditioning=ip_adapter_conditioning,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -293,6 +295,7 @@ class InvokeAIDiffuserComponent:
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -326,8 +329,8 @@ class InvokeAIDiffuserComponent:
         )
 
         if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
-            # TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
-            # masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
             # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
             # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
             # awkward to handle both standard conditioning and sequential conditioning further up the stack.
@@ -345,8 +348,8 @@ class InvokeAIDiffuserComponent:
                 )
                 regions.append(r)
 
-            _, key_seq_len, _ = both_conditionings.shape
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
+            cross_attention_kwargs["percent_through"] = percent_through
 
         both_results = self.model_forward_callback(
             x_twice,
@@ -369,6 +372,7 @@ class InvokeAIDiffuserComponent:
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        cross_attention_control_types_to_do: list[CrossAttentionType],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -439,10 +443,10 @@ class InvokeAIDiffuserComponent:
 
         # Prepare prompt regions for the unconditioned pass.
         if conditioning_data.uncond_regions is not None:
-            _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
-                regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions]
             )
+            cross_attention_kwargs["percent_through"] = percent_through
 
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
@@ -485,10 +489,10 @@ class InvokeAIDiffuserComponent:
 
         # Prepare prompt regions for the conditioned pass.
         if conditioning_data.cond_regions is not None:
-            _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
-                regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions]
             )
+            cross_attention_kwargs["percent_through"] = percent_through
 
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
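
For anyone reviewing the tuning parameters above, here is a minimal standalone sketch of the intended effect. It is illustrative only: the function name and the flattened (num_prompts, query_seq_len) mask layout are simplifying assumptions and do not use the RegionalPromptData API. Cross-attention receives a large negative bias outside each prompt's region, self-attention between pixels of the same region is boosted by positive_self_attn_mask_score while percent_through is at or below self_attn_mask_end_step_percent, and the self-attention bias is then shifted so its minimum is 0.0.

import torch


def build_region_attn_biases(
    region_masks: torch.Tensor,  # (num_prompts, query_seq_len), binary {0, 1} spatial masks
    percent_through: float,
    negative_cross_attn_mask_score: float = -10000.0,
    positive_cross_attn_mask_score: float = 0.0,
    positive_self_attn_mask_score: float = 0.3,
    self_attn_mask_end_step_percent: float = 0.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (cross_attn_bias, self_attn_bias) for a single batch sample.

    cross_attn_bias: (num_prompts, query_seq_len) additive logit bias per prompt.
    self_attn_bias: (query_seq_len, query_seq_len) additive logit bias.
    """
    # Cross-attention: pixels outside a prompt's region get a large negative bias so they
    # cannot attend to that prompt's tokens; pixels inside get the positive score.
    cross_attn_bias = torch.where(
        region_masks >= 0.5,
        torch.full_like(region_masks, positive_cross_attn_mask_score),
        torch.full_like(region_masks, negative_cross_attn_mask_score),
    )

    # Self-attention: boost attention between pixels that belong to the same region,
    # but only for the early portion of the denoising schedule.
    query_seq_len = region_masks.shape[1]
    self_attn_bias = torch.zeros((query_seq_len, query_seq_len), dtype=region_masks.dtype)
    if percent_through <= self_attn_mask_end_step_percent:
        for prompt_mask in region_masks:
            # Outer product: 1.0 where both the query and key pixels are in the region.
            self_attn_bias += prompt_mask.unsqueeze(0) * prompt_mask.unsqueeze(1) * positive_self_attn_mask_score

    # Shift so the minimum bias is 0.0, mirroring the adjustment at the end of get_self_attn_mask().
    bias_min = self_attn_bias.min()
    if abs(bias_min.item()) > 0.0001:
        self_attn_bias = self_attn_bias - bias_min

    return cross_attn_bias, self_attn_bias


# Example: two regions (even vs. odd pixels) flattened to query_seq_len = 4.
masks = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
cross_bias, self_bias = build_region_attn_biases(masks, percent_through=0.1)
print(cross_bias.shape, self_bias.shape)  # torch.Size([2, 4]) torch.Size([4, 4])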