Mirror of https://github.com/invoke-ai/InvokeAI
Very experimental: tuning various regional prompting params.
This commit is contained in:
parent
942efa011e
commit
ad18429fe3
@@ -97,6 +97,23 @@ class TextConditioningData:
        self.cond_text = cond_text
        self.uncond_regions = uncond_regions
        self.cond_regions = cond_regions
        # All params:
        # negative_cross_attn_mask_score: -10000 (recommended to leave this at -10000 to prevent leakage into the rest of the image)
        # positive_cross_attn_mask_score: 0.0 (relative weighting of masks)
        # positive_self_attn_mask_score: 0.3
        # negative_self_attn_mask_score: This doesn't really make sense; it would effectively have the same effect as further increasing positive_self_attn_mask_score.
        # cross_attn_start_step
        # self_attn_mask_begin_step_percent: 0.0
        # self_attn_mask_end_step_percent: 0.5
        # Should we allow cross_attn_mask_begin_step_percent and cross_attn_mask_end_step_percent? Probably not; this seems like more control than necessary, and it would be easy to add in the future.
        self.negative_cross_attn_mask_score = -10000
        self.positive_cross_attn_mask_score = 0.0
        self.positive_self_attn_mask_score = 0.3
        self.self_attn_mask_end_step_percent = 0.5
        # mask_weight: float = Field(
        #     default=1.0,
        #     description="The weight to apply to the mask. This weight controls the relative weighting of overlapping masks. This weight gets added to the attention map logits before applying a pixelwise softmax.",
        # )
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
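For intuition about these scores, here is a minimal sketch (not InvokeAI code; tensor names are hypothetical) of how additive mask scores interact with an attention softmax: a large negative score effectively zeroes out the masked positions, while 0.0 leaves the logits untouched.

```python
import torch

# Minimal sketch: additive mask scores applied to attention logits.
# A large negative score (-10000) drives the softmax weight of masked
# positions to ~0, preventing leakage; a score of 0.0 is a no-op.
logits = torch.randn(1, 4, 4)  # (batch, query_len, key_len)
region_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])  # 1.0 inside the region

negative_score = -10000.0  # cf. negative_cross_attn_mask_score
positive_score = 0.0       # cf. positive_cross_attn_mask_score

additive = torch.where(
    region_mask >= 0.5,
    torch.full_like(region_mask, positive_score),
    torch.full_like(region_mask, negative_score),
)
weights = torch.softmax(logits + additive, dim=-1)
print(weights[0, 0])  # keys 2 and 3 receive ~0 attention weight
```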
@@ -58,6 +58,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
        scale: float = 1.0,
        # For regional prompting:
        regional_prompt_data: Optional[RegionalPromptData] = None,
        percent_through: Optional[torch.FloatTensor] = None,
        # For IP-Adapter:
        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.FloatTensor:
@@ -93,6 +94,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

        # Handle regional prompt attention masks.
        if regional_prompt_data is not None:
            assert percent_through is not None
            _, query_seq_len, _ = hidden_states.shape
            if is_cross_attention:
                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
@@ -102,16 +104,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                prompt_region_attention_mask = prompt_region_attention_mask.to(
                    dtype=hidden_states.dtype, device=hidden_states.device
                )
                prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
                prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0

            else:  # self-attention
                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len)
                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                    query_seq_len=query_seq_len,
                    percent_through=percent_through,
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )

            # TODO(ryand): Avoid redundant type/device conversion here.
            prompt_region_attention_mask = prompt_region_attention_mask.to(
                dtype=hidden_states.dtype, device=hidden_states.device
            )
            # prompt_region_attention_mask = prompt_region_attention_mask.to(
            #     dtype=hidden_states.dtype, device=hidden_states.device
            # )
            # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
            # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
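Once binarized to {negative_cross_attn_mask_score, 0.0}, a mask of this kind can be passed straight to PyTorch's scaled dot-product attention as an additive float mask. A minimal standalone sketch (not the processor's actual call site; shapes are made up):

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, k_len, dim = 1, 8, 64, 77, 40
query = torch.randn(batch, heads, q_len, dim)
key = torch.randn(batch, heads, k_len, dim)
value = torch.randn(batch, heads, k_len, dim)

# A float attn_mask is *added* to the attention logits before softmax,
# so -10000.0 suppresses a query/key pair and 0.0 leaves it alone.
attn_mask = torch.zeros(batch, 1, q_len, k_len)
attn_mask[..., 20:] = -10000.0  # hypothetical: block trailing text tokens

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
```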
@@ -41,6 +41,17 @@ class RegionalPromptData:
            regions, max_downscale_factor
        )

        # TODO: These should be indexed by batch sample index and prompt index.
        # Next:
        # - Add support for setting these on nodes. Might just need the positive cross-attention mask score. Being able to downweight the global prompt might help a lot.
        # - Scale by region size.
        self.negative_cross_attn_mask_score = -10000
        self.positive_cross_attn_mask_score = 0.0
        self.positive_self_attn_mask_score = 2.0
        self.self_attn_mask_end_step_percent = 0.3
        # This one is for regional prompting in general, so it should be set on the DenoiseLatents node.
        self.self_attn_score_range = 3.0

    def _prepare_spatial_masks(
        self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
    ) -> list[dict[int, torch.Tensor]]:
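_prepare_spatial_masks (signature above) resizes each region's binary mask to every attention resolution the UNet uses, keyed by query sequence length. Its body isn't shown in this diff, so here is only a hedged sketch of the general approach, with made-up names:

```python
import torch
import torch.nn.functional as F

def prepare_spatial_masks_sketch(
    mask: torch.Tensor, max_downscale_factor: int = 8
) -> dict[int, torch.Tensor]:
    """Hypothetical helper: resize a (1, 1, H, W) binary region mask to each
    UNet attention resolution and flatten it to a query-seq-len vector."""
    masks_by_seq_len: dict[int, torch.Tensor] = {}
    downscale = 1
    while downscale <= max_downscale_factor:
        h, w = mask.shape[2] // downscale, mask.shape[3] // downscale
        resized = F.interpolate(mask, size=(h, w), mode="nearest")
        masks_by_seq_len[h * w] = resized.flatten(2)  # (1, 1, h*w); h*w == query_seq_len
        downscale *= 2
    return masks_by_seq_len
```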
@@ -179,9 +190,14 @@ class RegionalPromptData:
            0, prompt_idx, :, :
        ]

        pos_mask = attn_mask >= 0.5
        attn_mask[~pos_mask] = self.negative_cross_attn_mask_score
        attn_mask[pos_mask] = self.positive_cross_attn_mask_score
        return attn_mask

    def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
    def get_self_attn_mask(
        self, query_seq_len: int, percent_through: float, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Get the self-attention mask for the given query sequence length.

        Args:
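For the cross-attention case, the mask covers every (image query, text key) pair: queries outside a prompt's region are blocked from attending to that prompt's token span. A simplified sketch under assumed shapes (the real class tracks per-prompt embedding ranges; `token_spans` here is hypothetical):

```python
import torch

def build_cross_attn_mask_sketch(
    region_masks: torch.Tensor,          # (num_prompts, query_seq_len), 1.0 inside each region
    token_spans: list[tuple[int, int]],  # hypothetical per-prompt (start, end) in the text sequence
    key_seq_len: int,
    neg_score: float = -10000.0,
    pos_score: float = 0.0,
) -> torch.Tensor:
    num_prompts, query_seq_len = region_masks.shape
    # Start fully blocked; open up each prompt's token span for queries inside its region.
    attn_mask = torch.full((query_seq_len, key_seq_len), neg_score)
    for prompt_idx, (start, end) in enumerate(token_spans):
        inside = region_masks[prompt_idx] >= 0.5  # (query_seq_len,) bool
        attn_mask[inside, start:end] = pos_score
    return attn_mask
```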
@@ -193,11 +209,15 @@ class RegionalPromptData:
            dtype: float
        The mask is a binary mask with values of 0.0 and 1.0.
        """
        # TODO(ryand): Manage dtype and device properly. There's a lot of inefficient copying, conversion, and
        # unnecessary CPU operations happening in this class.
        batch_size = len(self._spatial_masks_by_seq_len)
        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
        batch_spatial_masks = [
            self._spatial_masks_by_seq_len[b][query_seq_len].to(device=device, dtype=dtype) for b in range(batch_size)
        ]

        # Create an empty attention mask with the correct shape.
        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len))
        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=dtype, device=device)

        for batch_idx in range(batch_size):
            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
@@ -207,11 +227,32 @@ class RegionalPromptData:
            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))

            for prompt_idx in range(num_prompts):
                if percent_through > self.self_attn_mask_end_step_percent:
                    continue
                prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0]  # Shape: (query_seq_len,)
                # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
                # query_seq_len) mask.
                attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5
                attn_mask[batch_idx, :, :] += (
                    prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
                )

            # Since we were adding masks in the previous loop, we need to clamp the values to 1.0.
            # attn_mask[attn_mask > 0.5] = 1.0
            # attn_mask_min = attn_mask[batch_idx].min()
            # attn_mask_max = attn_mask[batch_idx].max()
            # attn_mask_range = attn_mask_max - attn_mask_min

            # if abs(attn_mask_range) < 0.0001:
            #     # All attn_mask values in this batch sample are the same; set the attn_mask to 0.0s (to avoid divide-by-zero
            #     # in the normalization).
            #     attn_mask[batch_idx] = attn_mask[batch_idx] * 0.0
            # else:
            #     # Normalize from range [attn_mask_min, attn_mask_max] to [0, self.self_attn_score_range].
            #     attn_mask[batch_idx] = (
            #         (attn_mask[batch_idx] - attn_mask_min) / attn_mask_range * self.self_attn_score_range
            #     )

            attn_mask_min = attn_mask[batch_idx].min()

            # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
            if abs(attn_mask_min) > 0.0001:
                attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
        return attn_mask
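The core trick above: the outer product of a region's flattened spatial mask with itself boosts attention between pixels of the same region. A small self-contained illustration (hypothetical sizes, mirroring the diff's logic but not copied from it):

```python
import torch

query_seq_len = 6
positive_self_attn_mask_score = 2.0

# 1.0 where the pixel belongs to the region, 0.0 elsewhere.
prompt_query_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])

# Outer product: entry (i, j) is 1.0 iff pixels i and j are both in the region.
pair_mask = prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1)
attn_mask = pair_mask * positive_self_attn_mask_score  # (6, 6)

# Shift so the minimum is 0.0, mirroring the adjustment in the diff.
attn_mask = attn_mask - attn_mask.min()
```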
@@ -200,9 +200,9 @@ class InvokeAIDiffuserComponent:
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ):
        percent_through = step_index / total_step_count
        cross_attention_control_types_to_do = []
        if self.cross_attention_control_context is not None:
            percent_through = step_index / total_step_count
            cross_attention_control_types_to_do = (
                self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
            )
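percent_through is simply the fraction of denoising completed, which is what lets thresholds like self_attn_mask_end_step_percent switch the masks off partway through sampling. A toy illustration of the schedule (assumed step counts):

```python
# Toy schedule: apply the self-attention mask only for the first 50% of steps.
total_step_count = 20
self_attn_mask_end_step_percent = 0.5

for step_index in range(total_step_count):
    percent_through = step_index / total_step_count
    apply_self_attn_mask = percent_through <= self_attn_mask_end_step_percent
    # Steps 0-10 apply the mask; later steps leave self-attention unmodified.
```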
@@ -219,6 +219,7 @@ class InvokeAIDiffuserComponent:
            sigma=timestep,
            conditioning_data=conditioning_data,
            ip_adapter_conditioning=ip_adapter_conditioning,
            percent_through=percent_through,
            cross_attention_control_types_to_do=cross_attention_control_types_to_do,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
@@ -232,6 +233,7 @@ class InvokeAIDiffuserComponent:
            x=sample,
            sigma=timestep,
            conditioning_data=conditioning_data,
            percent_through=percent_through,
            ip_adapter_conditioning=ip_adapter_conditioning,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
@@ -293,6 +295,7 @@ class InvokeAIDiffuserComponent:
        sigma,
        conditioning_data: TextConditioningData,
        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        percent_through: float,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -326,8 +329,8 @@ class InvokeAIDiffuserComponent:
        )

        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
            # TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
            # masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
@@ -345,8 +348,8 @@ class InvokeAIDiffuserComponent:
            )
            regions.append(r)

        _, key_seq_len, _ = both_conditionings.shape
        cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
        cross_attention_kwargs["percent_through"] = percent_through

        both_results = self.model_forward_callback(
            x_twice,
@@ -369,6 +372,7 @@ class InvokeAIDiffuserComponent:
        conditioning_data: TextConditioningData,
        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        cross_attention_control_types_to_do: list[CrossAttentionType],
        percent_through: float,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -439,10 +443,10 @@ class InvokeAIDiffuserComponent:

        # Prepare prompt regions for the unconditioned pass.
        if conditioning_data.uncond_regions is not None:
            _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
                regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=[conditioning_data.uncond_regions]
            )
            cross_attention_kwargs["percent_through"] = percent_through

        # Run unconditioned UNet denoising (i.e. negative prompt).
        unconditioned_next_x = self.model_forward_callback(
@@ -485,10 +489,10 @@ class InvokeAIDiffuserComponent:

        # Prepare prompt regions for the conditioned pass.
        if conditioning_data.cond_regions is not None:
            _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
                regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=[conditioning_data.cond_regions]
            )
            cross_attention_kwargs["percent_through"] = percent_through

        # Run conditioned UNet denoising (i.e. positive prompt).
        conditioned_next_x = self.model_forward_callback(
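In this sequential path, the unconditioned and conditioned passes each get their own RegionalPromptData, and the two predictions are then blended by classifier-free guidance, per the guidance-scale comment quoted earlier. A generic sketch of that combination (standard CFG, not this class's exact code):

```python
import torch

def classifier_free_guidance(
    unconditioned_next_x: torch.Tensor,
    conditioned_next_x: torch.Tensor,
    guidance_scale: float,
) -> torch.Tensor:
    # `w` from eq. 2 of the Imagen paper: scale the (cond - uncond)
    # direction; guidance is active when guidance_scale > 1.
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)
```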