Add positive_self_attn_mask_score and self_attn_adjustment_end_step_percent to the prompt nodes.

This commit is contained in:
Ryan Dick 2024-03-04 15:34:26 -05:00
parent d313e5eb70
commit a665f20fb5
6 changed files with 64 additions and 60 deletions

View File

@ -63,6 +63,8 @@ class CompelInvocation(BaseInvocation):
default=None, description="A mask defining the region that this conditioning prompt applies to." default=None, description="A mask defining the region that this conditioning prompt applies to."
) )
positive_cross_attn_mask_score: float = InputField(default=0.0, description="") positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -135,6 +137,8 @@ class CompelInvocation(BaseInvocation):
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
mask=self.mask, mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
) )
) )
@ -278,6 +282,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
default=None, description="A mask defining the region that this conditioning prompt applies to." default=None, description="A mask defining the region that this conditioning prompt applies to."
) )
positive_cross_attn_mask_score: float = InputField(default=0.0, description="") positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -345,6 +351,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
mask=self.mask, mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
) )
) )

View File

@ -241,6 +241,8 @@ class ConditioningField(BaseModel):
# TODO(ryand): Add more details to this description # TODO(ryand): Add more details to this description
description="The weight of this conditioning tensor's mask relative to overlapping masks.", description="The weight of this conditioning tensor's mask relative to overlapping masks.",
) )
positive_self_attn_mask_score: float = Field(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="")
class MetadataField(RootModel): class MetadataField(RootModel):

View File

@ -391,22 +391,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _get_text_embeddings_and_masks( def _get_text_embeddings_and_masks(
self, self,
cond_field: Union[ConditioningField, list[ConditioningField]], cond_list: list[ConditioningField],
context: InvocationContext, context: InvocationContext,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
) -> tuple[ ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]], list[float]
]:
"""Get the text embeddings and masks from the input conditioning fields.""" """Get the text embeddings and masks from the input conditioning fields."""
# Normalize cond_field to a list.
cond_list = cond_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = []
positive_cross_attn_mask_scores: list[float] = []
for cond in cond_list: for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name) cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
@ -416,8 +408,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask = context.tensors.load(mask.mask_name) mask = context.tensors.load(mask.mask_name)
text_embeddings_masks.append(mask) text_embeddings_masks.append(mask)
positive_cross_attn_mask_scores.append(cond.positive_cross_attn_mask_score) return text_embeddings, text_embeddings_masks
return text_embeddings, text_embeddings_masks, positive_cross_attn_mask_scores
def _preprocess_regional_prompt_mask( def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int self, mask: Optional[torch.Tensor], target_height: int, target_width: int
@ -445,7 +436,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
self, self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]], masks: Optional[list[Optional[torch.Tensor]]],
positive_cross_attn_mask_scores: list[float], conditioning_fields: list[ConditioningField],
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
@ -466,7 +457,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
embedding_ranges = [] embedding_ranges = []
extra_conditioning = None extra_conditioning = None
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True): for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if ( if (
text_embedding_info.extra_conditioning is not None text_embedding_info.extra_conditioning is not None
and text_embedding_info.extra_conditioning.wants_cross_attention_control and text_embedding_info.extra_conditioning.wants_cross_attention_control
@ -508,7 +500,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
regions = TextConditioningRegions( regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1), masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges, ranges=embedding_ranges,
positive_cross_attn_mask_scores=positive_cross_attn_mask_scores, positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields],
positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields],
self_attn_adjustment_end_step_percents=[
x.self_attn_adjustment_end_step_percent for x in conditioning_fields
],
) )
if extra_conditioning is not None and len(text_conditionings) > 1: if extra_conditioning is not None and len(text_conditionings) > 1:
@ -536,27 +532,32 @@ class DenoiseLatentsInvocation(BaseInvocation):
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
) -> TextConditioningData: ) -> TextConditioningData:
( # Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_text_embeddings, cond_list = self.positive_conditioning
cond_text_embedding_masks, if not isinstance(cond_list, list):
cond_positive_cross_attn_mask_scores, cond_list = [cond_list]
) = self._get_text_embeddings_and_masks(self.positive_conditioning, context, unet.device, unet.dtype) uncond_list = self.negative_conditioning
( if not isinstance(uncond_list, list):
uncond_text_embeddings, uncond_list = [uncond_list]
uncond_text_embedding_masks,
uncond_positive_cross_attn_mask_scores, cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
) = self._get_text_embeddings_and_masks(self.negative_conditioning, context, unet.device, unet.dtype) cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings( cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings, text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks, masks=cond_text_embedding_masks,
positive_cross_attn_mask_scores=cond_positive_cross_attn_mask_scores, conditioning_fields=cond_list,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
) )
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings( uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings, text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks, masks=uncond_text_embedding_masks,
positive_cross_attn_mask_scores=uncond_positive_cross_attn_mask_scores, conditioning_fields=uncond_list,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
) )

View File

@ -70,7 +70,14 @@ class Range:
class TextConditioningRegions: class TextConditioningRegions:
def __init__(self, masks: torch.Tensor, ranges: list[Range], positive_cross_attn_mask_scores: list[float]): def __init__(
self,
masks: torch.Tensor,
ranges: list[Range],
positive_cross_attn_mask_scores: list[float],
positive_self_attn_mask_scores: list[float],
self_attn_adjustment_end_step_percents: list[float],
):
# A binary mask indicating the regions of the image that the prompt should be applied to. # A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width) # Shape: (1, num_prompts, height, width)
# Dtype: torch.bool # Dtype: torch.bool
@ -80,12 +87,19 @@ class TextConditioningRegions:
# ranges[i] contains the embedding range for the i'th prompt / mask. # ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges self.ranges = ranges
# A list of positive cross attention mask scores for each prompt. # The following fields control the cross-attention and self-attention handling of region masks. They are indexed
# positive_cross_attn_mask_scores[i] contains the positive cross attention mask score for the i'th prompt/mask. # by prompt / mask index.
self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
self.positive_self_attn_mask_scores = positive_self_attn_mask_scores
self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents
assert self.masks.shape[1] == len(self.ranges) assert (
assert self.masks.shape[1] == len(self.positive_cross_attn_mask_scores) self.masks.shape[1]
== len(self.ranges)
== len(self.positive_cross_attn_mask_scores)
== len(self.positive_self_attn_mask_scores)
== len(self.self_attn_adjustment_end_step_percents)
)
class TextConditioningData: class TextConditioningData:
@ -102,23 +116,6 @@ class TextConditioningData:
self.cond_text = cond_text self.cond_text = cond_text
self.uncond_regions = uncond_regions self.uncond_regions = uncond_regions
self.cond_regions = cond_regions self.cond_regions = cond_regions
# All params:
# negative_cross_attn_mask_score: -10000 (recommended to leave this as -10000 to prevent leakage to the rest of the image)
# positive_cross_attn_mask_score: 0.0 (relative weightin of masks)
# positive_self_attn_mask_score: 0.3
# negative_self_attn_mask_score: This doesn't really make sense. It would effectively have the same effect as further increasing positive_self_attn_mask_score.
# cross_attn_start_step
# self_attn_mask_begin_step_percent: 0.0
# self_attn_mask_end_step percent: 0.5
# Should we allow cross_attn_mask_begin_step_percent and cross_attn_mask_end_step_percent? Probably not, this seems like more control than necessary. And easy to add in the future.
self.negative_cross_attn_mask_score = -10000
self.positive_cross_attn_mask_score = 0.0
self.positive_self_attn_mask_score = 0.3
self.self_attn_mask_end_step_percent = 0.5
# mask_weight: float = Field(
# default=1.0,
# description="The weight to apply to the mask. This weight controls the relative weighting of overlapping masks. This weight gets added to the attention map logits before applying a pixelwise softmax.",
# )
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate

View File

@ -41,16 +41,7 @@ class RegionalPromptData:
regions, max_downscale_factor regions, max_downscale_factor
) )
# TODO: These should be indexed by batch sample index and prompt index.
# Next:
# - Add support for setting these one nodes. Might just need positive cross-attention mask score. Being able to downweight the global prompt mighth help alot.
# - Scale by region size.
self.negative_cross_attn_mask_score = -10000 self.negative_cross_attn_mask_score = -10000
# self.positive_cross_attn_mask_score = 0.0
self.positive_self_attn_mask_score = 1.0
self.self_attn_mask_end_step_percent = 0.3
# This one is for regional prompting in general, so should be set on the DenoiseLatents node.
self.self_attn_score_range = 3.0
def _prepare_spatial_masks( def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
@ -222,20 +213,23 @@ class RegionalPromptData:
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx] batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape _, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
if percent_through > self.self_attn_mask_end_step_percent: if percent_through > batch_sample_regions.self_attn_adjustment_end_step_percents[prompt_idx]:
continue continue
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask. # query_seq_len) mask.
# TODO(ryand): Is += really the best option here? # TODO(ryand): Is += really the best option here?
attn_mask[batch_idx, :, :] += ( attn_mask[batch_idx, :, :] += (
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score prompt_query_mask.unsqueeze(0)
* prompt_query_mask.unsqueeze(1)
* batch_sample_regions.positive_self_attn_mask_scores[prompt_idx]
) )
# attn_mask_min = attn_mask[batch_idx].min() # attn_mask_min = attn_mask[batch_idx].min()

View File

@ -346,6 +346,8 @@ class InvokeAIDiffuserComponent:
masks=torch.ones((1, 1, h, w), dtype=torch.bool), masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])], ranges=[Range(start=0, end=c.embeds.shape[1])],
positive_cross_attn_mask_scores=[0.0], positive_cross_attn_mask_scores=[0.0],
positive_self_attn_mask_scores=[0.0],
self_attn_adjustment_end_step_percents=[0.0],
) )
regions.append(r) regions.append(r)