Add ability to control regional prompt region weights.

This commit is contained in:
Ryan Dick 2024-03-03 12:55:07 -05:00
parent ad18429fe3
commit 5fad379192
6 changed files with 46 additions and 17 deletions

View File

@ -21,9 +21,11 @@ class AddConditioningMaskInvocation(BaseInvocation):
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.") mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
self.conditioning.mask = self.mask self.conditioning.mask = self.mask
self.conditioning.positive_cross_attn_mask_score = self.positive_cross_attn_mask_score
return ConditioningOutput(conditioning=self.conditioning) return ConditioningOutput(conditioning=self.conditioning)

View File

@ -236,6 +236,11 @@ class ConditioningField(BaseModel):
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.", "included regions should be set to True.",
) )
positive_cross_attn_mask_score: float = Field(
default=0.0,
# TODO(ryand): Add more details to this description
description="The weight of this conditioning tensor's mask relative to overlapping masks.",
)
class MetadataField(RootModel): class MetadataField(RootModel):

View File

@ -381,7 +381,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext, context: InvocationContext,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]: ) -> tuple[
Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]], list[float]
]:
"""Get the text embeddings and masks from the input conditioning fields.""" """Get the text embeddings and masks from the input conditioning fields."""
# Normalize cond_field to a list. # Normalize cond_field to a list.
cond_list = cond_field cond_list = cond_field
@ -390,6 +392,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = []
positive_cross_attn_mask_scores: list[float] = []
for cond in cond_list: for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name) cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
@ -399,7 +402,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask = context.tensors.load(mask.mask_name) mask = context.tensors.load(mask.mask_name)
text_embeddings_masks.append(mask) text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks positive_cross_attn_mask_scores.append(cond.positive_cross_attn_mask_score)
return text_embeddings, text_embeddings_masks, positive_cross_attn_mask_scores
def _preprocess_regional_prompt_mask( def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int self, mask: Optional[torch.Tensor], target_height: int, target_width: int
@ -427,6 +431,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
self, self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]], masks: Optional[list[Optional[torch.Tensor]]],
positive_cross_attn_mask_scores: list[float],
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
@ -486,7 +491,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
regions = None regions = None
if not all_masks_are_none: if not all_masks_are_none:
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges) regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
positive_cross_attn_mask_scores=positive_cross_attn_mask_scores,
)
if extra_conditioning is not None and len(text_conditionings) > 1: if extra_conditioning is not None and len(text_conditionings) > 1:
raise ValueError( raise ValueError(
@ -513,21 +522,27 @@ class DenoiseLatentsInvocation(BaseInvocation):
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
) -> TextConditioningData: ) -> TextConditioningData:
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( (
self.positive_conditioning, context, unet.device, unet.dtype cond_text_embeddings,
) cond_text_embedding_masks,
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( cond_positive_cross_attn_mask_scores,
self.negative_conditioning, context, unet.device, unet.dtype ) = self._get_text_embeddings_and_masks(self.positive_conditioning, context, unet.device, unet.dtype)
) (
uncond_text_embeddings,
uncond_text_embedding_masks,
uncond_positive_cross_attn_mask_scores,
) = self._get_text_embeddings_and_masks(self.negative_conditioning, context, unet.device, unet.dtype)
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings( cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings, text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks, masks=cond_text_embedding_masks,
positive_cross_attn_mask_scores=cond_positive_cross_attn_mask_scores,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
) )
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings( uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings, text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks, masks=uncond_text_embedding_masks,
positive_cross_attn_mask_scores=uncond_positive_cross_attn_mask_scores,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
) )

View File

@ -70,7 +70,7 @@ class Range:
class TextConditioningRegions: class TextConditioningRegions:
def __init__(self, masks: torch.Tensor, ranges: list[Range]): def __init__(self, masks: torch.Tensor, ranges: list[Range], positive_cross_attn_mask_scores: list[float]):
# A binary mask indicating the regions of the image that the prompt should be applied to. # A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width) # Shape: (1, num_prompts, height, width)
# Dtype: torch.bool # Dtype: torch.bool
@ -80,7 +80,12 @@ class TextConditioningRegions:
# ranges[i] contains the embedding range for the i'th prompt / mask. # ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges self.ranges = ranges
# A list of positive cross attention mask scores for each prompt.
# positive_cross_attn_mask_scores[i] contains the positive cross attention mask score for the i'th prompt/mask.
self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
assert self.masks.shape[1] == len(self.ranges) assert self.masks.shape[1] == len(self.ranges)
assert self.masks.shape[1] == len(self.positive_cross_attn_mask_scores)
class TextConditioningData: class TextConditioningData:

View File

@ -46,7 +46,7 @@ class RegionalPromptData:
# - Add support for setting these one nodes. Might just need positive cross-attention mask score. Being able to downweight the global prompt mighth help alot. # - Add support for setting these one nodes. Might just need positive cross-attention mask score. Being able to downweight the global prompt mighth help alot.
# - Scale by region size. # - Scale by region size.
self.negative_cross_attn_mask_score = -10000 self.negative_cross_attn_mask_score = -10000
self.positive_cross_attn_mask_score = 0.0 # self.positive_cross_attn_mask_score = 0.0
self.positive_self_attn_mask_score = 2.0 self.positive_self_attn_mask_score = 2.0
self.self_attn_mask_end_step_percent = 0.3 self.self_attn_mask_end_step_percent = 0.3
# This one is for regional prompting in general, so should be set on the DenoiseLatents node. # This one is for regional prompting in general, so should be set on the DenoiseLatents node.
@ -186,13 +186,14 @@ class RegionalPromptData:
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_masks[ batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
0, prompt_idx, :, : batch_sample_query_mask = batch_sample_query_scores > 0.5
] batch_sample_query_scores[
batch_sample_query_mask
] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
batch_sample_query_scores[~batch_sample_query_mask] = self.negative_cross_attn_mask_score # TODO(ryand)
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
pos_mask = attn_mask >= 0.5
attn_mask[~pos_mask] = self.negative_cross_attn_mask_score
attn_mask[pos_mask] = self.positive_cross_attn_mask_score
return attn_mask return attn_mask
def get_self_attn_mask( def get_self_attn_mask(

View File

@ -345,6 +345,7 @@ class InvokeAIDiffuserComponent:
r = TextConditioningRegions( r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool), masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])], ranges=[Range(start=0, end=c.embeds.shape[1])],
positive_cross_attn_mask_scores=[0.0],
) )
regions.append(r) regions.append(r)