More experimentation with various regional prompting tuning params.

Ryan Dick 2024-03-02 17:43:21 -05:00
parent 942efa011e
commit ad18429fe3
4 changed files with 89 additions and 22 deletions

View File

@@ -97,6 +97,23 @@ class TextConditioningData:
         self.cond_text = cond_text
         self.uncond_regions = uncond_regions
         self.cond_regions = cond_regions
+        # All params:
+        # negative_cross_attn_mask_score: -10000 (recommended to leave this at -10000 to prevent leakage to the rest of the image)
+        # positive_cross_attn_mask_score: 0.0 (relative weighting of masks)
+        # positive_self_attn_mask_score: 0.3
+        # negative_self_attn_mask_score: This doesn't really make sense. It would effectively have the same effect as further increasing positive_self_attn_mask_score.
+        # cross_attn_start_step
+        # self_attn_mask_begin_step_percent: 0.0
+        # self_attn_mask_end_step_percent: 0.5
+        # Should we allow cross_attn_mask_begin_step_percent and cross_attn_mask_end_step_percent? Probably not; this seems like more control than necessary, and it would be easy to add in the future.
+        self.negative_cross_attn_mask_score = -10000
+        self.positive_cross_attn_mask_score = 0.0
+        self.positive_self_attn_mask_score = 0.3
+        self.self_attn_mask_end_step_percent = 0.5
+        # mask_weight: float = Field(
+        #     default=1.0,
+        #     description="The weight to apply to the mask. This weight controls the relative weighting of overlapping masks. This weight gets added to the attention map logits before applying a pixelwise softmax.",
+        # )
         # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
         # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
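For context on the scores above: the region masks are applied additively to the attention logits before the softmax, so -10000 drives the post-softmax weight of out-of-region tokens to effectively zero, while 0.0 leaves in-region logits untouched. A minimal, self-contained sketch of that mechanism (not InvokeAI code; names and shapes are illustrative assumptions):

```python
import torch

NEG_SCORE = -10000.0  # negative_cross_attn_mask_score
POS_SCORE = 0.0       # positive_cross_attn_mask_score


def masked_cross_attn_probs(q: torch.Tensor, k: torch.Tensor, region_mask: torch.Tensor) -> torch.Tensor:
    """q: (query_seq_len, dim); k: (key_seq_len, dim);
    region_mask: (query_seq_len, key_seq_len), 1.0 where a latent pixel may attend to a prompt token.
    """
    logits = (q @ k.T) / (q.shape[-1] ** 0.5)
    # Binarize the mask into additive scores (same pattern as the thresholding below):
    scores = torch.full_like(logits, NEG_SCORE)
    scores[region_mask >= 0.5] = POS_SCORE
    # exp(logit - 10000) is ~0, so masked-out prompt tokens receive ~0 attention
    # weight after the softmax and cannot leak into the rest of the image.
    return (logits + scores).softmax(dim=-1)
```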

View File

@@ -58,6 +58,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         scale: float = 1.0,
         # For regional prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
+        percent_through: Optional[torch.FloatTensor] = None,
         # For IP-Adapter:
         ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.FloatTensor:
@@ -93,6 +94,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # Handle regional prompt attention masks.
         if regional_prompt_data is not None:
+            assert percent_through is not None
             _, query_seq_len, _ = hidden_states.shape
             if is_cross_attention:
                 prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
@@ -102,16 +104,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 prompt_region_attention_mask = prompt_region_attention_mask.to(
                     dtype=hidden_states.dtype, device=hidden_states.device
                 )
-                prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
-                prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
             else: # self-attention
-                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len)
+                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
+                    query_seq_len=query_seq_len,
+                    percent_through=percent_through,
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
                 # TODO(ryand): Avoid redundant type/device conversion here.
-                prompt_region_attention_mask = prompt_region_attention_mask.to(
-                    dtype=hidden_states.dtype, device=hidden_states.device
-                )
+                # prompt_region_attention_mask = prompt_region_attention_mask.to(
+                #     dtype=hidden_states.dtype, device=hidden_states.device
+                # )
                 # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
                 # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
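A note on how the masks built here are ultimately consumed: `AttnProcessor2_0` passes the attention mask to torch's scaled-dot-product attention, where a float mask is added to the logits, so the additive scores above flow straight through. A small sketch under that assumption (shapes are illustrative, not taken from this diff):

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, k_len, dim = 1, 8, 64, 77, 40
query = torch.randn(batch, heads, q_len, dim)
key = torch.randn(batch, heads, k_len, dim)
value = torch.randn(batch, heads, k_len, dim)

# A float attn_mask is added to the attention logits before softmax; shape
# (batch, 1, q_len, k_len) broadcasts across heads.
attn_mask = torch.zeros(batch, 1, q_len, k_len)
attn_mask[..., :, 40:] = -10000.0  # e.g. block another region's prompt tokens

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 8, 64, 40])
```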

View File

@@ -41,6 +41,17 @@ class RegionalPromptData:
             regions, max_downscale_factor
         )
+        # TODO: These should be indexed by batch sample index and prompt index.
+        # Next:
+        # - Add support for setting these on nodes. Might just need the positive cross-attention mask score. Being able to downweight the global prompt might help a lot.
+        # - Scale by region size.
+        self.negative_cross_attn_mask_score = -10000
+        self.positive_cross_attn_mask_score = 0.0
+        self.positive_self_attn_mask_score = 2.0
+        self.self_attn_mask_end_step_percent = 0.3
+        # This one is for regional prompting in general, so it should be set on the DenoiseLatents node.
+        self.self_attn_score_range = 3.0
+
     def _prepare_spatial_masks(
         self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
     ) -> list[dict[int, torch.Tensor]]:
@@ -179,9 +190,14 @@ class RegionalPromptData:
                 0, prompt_idx, :, :
             ]
+        pos_mask = attn_mask >= 0.5
+        attn_mask[~pos_mask] = self.negative_cross_attn_mask_score
+        attn_mask[pos_mask] = self.positive_cross_attn_mask_score
         return attn_mask

-    def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
+    def get_self_attn_mask(
+        self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
         """Get the self-attention mask for the given query sequence length.

         Args:
@@ -193,11 +209,15 @@ class RegionalPromptData:
             dtype: float
             The mask is a binary mask with values of 0.0 and 1.0.
         """
+        # TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
+        # unnecessary CPU operations happening in this class.
         batch_size = len(self._spatial_masks_by_seq_len)
-        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
+        batch_spatial_masks = [
+            self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
+        ]

         # Create an empty attention mask with the correct shape.
-        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len))
+        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)

         for batch_idx in range(batch_size):
             batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
@@ -207,11 +227,32 @@ class RegionalPromptData:
             batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))

             for prompt_idx in range(num_prompts):
+                if percent_through > self.self_attn_mask_end_step_percent:
+                    continue
                 prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
                 # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
                 # query_seq_len) mask.
-                attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5
+                attn_mask[batch_idx, :, :] += (
+                    prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
+                )

-        # Since we were adding masks in the previous loop, we need to clamp the values to 1.0.
-        # attn_mask[attn_mask > 0.5] = 1.0
+            # attn_mask_min = attn_mask[batch_idx].min()
+            # attn_mask_max = attn_mask[batch_idx].max()
+            # attn_mask_range = attn_mask_max - attn_mask_min
+            # if abs(attn_mask_range) < 0.0001:
+            #     # All attn_mask values in this batch sample are the same; set the attn_mask to 0.0s (to avoid divide by
+            #     # zero in the normalization).
+            #     attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
+            # else:
+            #     # Normalize from the range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
+            #     attn_mask[batch_idx] = (
+            #         (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
+            #     )

+            attn_mask_min = attn_mask[batch_idx].min()
+            # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
+            if abs(attn_mask_min) > 0.0001:
+                attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
         return attn_mask
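The construction above boils down to one outer product per prompt region, then a shift so the minimum score is 0.0 (which only changes anything when every pixel is covered by some region). A standalone sketch of just that math, with assumed, simplified shapes:

```python
import torch

query_seq_len = 16  # e.g. a 4x4 latent, flattened
positive_self_attn_mask_score = 0.3

# Binary per-prompt spatial mask over flattened latent positions (assumed input).
prompt_query_mask = torch.zeros(query_seq_len)
prompt_query_mask[:8] = 1.0  # this prompt's region covers the first half

attn_mask = torch.zeros(query_seq_len, query_seq_len)
# (1, q) * (q, 1) -> (q, q): 1.0 exactly where both the query and the key
# position fall inside the same region, so in-region pixels get a bonus for
# attending to each other.
attn_mask += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * positive_self_attn_mask_score

# Shift so the minimum is 0.0, mirroring the adjustment in get_self_attn_mask.
# Here the min is already 0.0 (some pixels are uncovered); when overlapping
# regions cover everything, this keeps the baseline at 0.0 instead of > 0.
attn_mask -= attn_mask.min()
```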

View File

@@ -200,9 +200,9 @@ class InvokeAIDiffuserComponent:
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
@@ -219,6 +219,7 @@ class InvokeAIDiffuserComponent:
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -232,6 +233,7 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                percent_through=percent_through,
                 ip_adapter_conditioning=ip_adapter_conditioning,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -293,6 +295,7 @@ class InvokeAIDiffuserComponent:
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -326,8 +329,8 @@ class InvokeAIDiffuserComponent:
         )

         if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
-            # TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
-            # masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
             # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
             # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
             # awkward to handle both standard conditioning and sequential conditioning further up the stack.
@@ -345,8 +348,8 @@ class InvokeAIDiffuserComponent:
                 )
                 regions.append(r)

-            _, key_seq_len, _ = both_conditionings.shape
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
+            cross_attention_kwargs["percent_through"] = percent_through

         both_results = self.model_forward_callback(
             x_twice,
@@ -369,6 +372,7 @@ class InvokeAIDiffuserComponent:
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         cross_attention_control_types_to_do: list[CrossAttentionType],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -439,10 +443,10 @@ class InvokeAIDiffuserComponent:
         # Prepare prompt regions for the unconditioned pass.
         if conditioning_data.uncond_regions is not None:
-            _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
-                regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions]
             )
+            cross_attention_kwargs["percent_through"] = percent_through

         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
@@ -485,10 +489,10 @@ class InvokeAIDiffuserComponent:
         # Prepare prompt regions for the conditioned pass.
         if conditioning_data.cond_regions is not None:
-            _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
-                regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions]
             )
+            cross_attention_kwargs["percent_through"] = percent_through

         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
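Finally, the step-gating convention threaded through this commit, sketched in isolation: percent_through advances from 0.0 toward 1.0 across the denoising schedule, and the regional self-attention mask is only applied while it is at or below self_attn_mask_end_step_percent. The values below are illustrative; the defaults in this diff are 0.5 (node-level) and 0.3 (RegionalPromptData).

```python
total_step_count = 30
self_attn_mask_end_step_percent = 0.5

for step_index in range(total_step_count):
    percent_through = step_index / total_step_count
    # Mirrors the `continue` guard in get_self_attn_mask: the regional
    # self-attention bonus only shapes the early, layout-defining steps,
    # after which the UNet refines the image without regional constraints.
    apply_self_attn_mask = percent_through <= self_attn_mask_end_step_percent
```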