Add support for mask weights, and only mask the tokens associated with the prompts (not eh entire 77-token embedding).

This commit is contained in:
Ryan Dick 2024-03-07 14:30:51 -05:00
parent 969982b789
commit ff950bc5cd
6 changed files with 42 additions and 55 deletions

View File

@ -44,7 +44,7 @@ from .model import ClipField
title="Prompt", title="Prompt",
tags=["prompt", "compel"], tags=["prompt", "compel"],
category="conditioning", category="conditioning",
version="1.1.0", version="1.2.0",
) )
class CompelInvocation(BaseInvocation): class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -62,9 +62,7 @@ class CompelInvocation(BaseInvocation):
mask: Optional[MaskField] = InputField( mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to." default=None, description="A mask defining the region that this conditioning prompt applies to."
) )
positive_cross_attn_mask_score: float = InputField(default=0.0, description="") mask_weight: float = InputField(default=1.0, description="")
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -136,9 +134,7 @@ class CompelInvocation(BaseInvocation):
conditioning=ConditioningField( conditioning=ConditioningField(
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
mask=self.mask, mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, mask_weight=self.mask_weight,
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
) )
) )
@ -254,7 +250,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt", title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"], tags=["sdxl", "compel", "prompt"],
category="conditioning", category="conditioning",
version="1.1.0", version="1.2.0",
) )
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -281,9 +277,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
mask: Optional[MaskField] = InputField( mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to." default=None, description="A mask defining the region that this conditioning prompt applies to."
) )
positive_cross_attn_mask_score: float = InputField(default=0.0, description="") mask_weight: float = InputField(default=1.0, description="")
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -350,9 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning=ConditioningField( conditioning=ConditioningField(
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
mask=self.mask, mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, mask_weight=self.mask_weight,
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
) )
) )
@ -403,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name)) return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
@invocation_output("clip_skip_output") @invocation_output("clip_skip_output")

View File

@ -236,13 +236,7 @@ class ConditioningField(BaseModel):
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.", "included regions should be set to True.",
) )
positive_cross_attn_mask_score: float = Field( mask_weight: float = Field(description="")
default=0.0,
# TODO(ryand): Add more details to this description
description="The weight of this conditioning tensor's mask relative to overlapping masks.",
)
positive_self_attn_mask_score: float = Field(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="")
class MetadataField(RootModel): class MetadataField(RootModel):

View File

@ -466,9 +466,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embedding.append(text_embedding_info.embeds) text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none: if not all_masks_are_none:
# embedding_ranges.append(
# Range(
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
# )
# )
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
# in the count.
embedding_ranges.append( embedding_ranges.append(
Range( Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] start=cur_text_embedding_len + 1,
end=cur_text_embedding_len
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
) )
) )
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width)) processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
@ -483,11 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
regions = TextConditioningRegions( regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1), masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges, ranges=embedding_ranges,
positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields], mask_weights=[x.mask_weight for x in conditioning_fields],
positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields],
self_attn_adjustment_end_step_percents=[
x.self_attn_adjustment_end_step_percent for x in conditioning_fields
],
) )
if extra_conditioning is not None and len(text_conditionings) > 1: if extra_conditioning is not None and len(text_conditionings) > 1:

View File

@ -74,9 +74,7 @@ class TextConditioningRegions:
self, self,
masks: torch.Tensor, masks: torch.Tensor,
ranges: list[Range], ranges: list[Range],
positive_cross_attn_mask_scores: list[float], mask_weights: list[float],
positive_self_attn_mask_scores: list[float],
self_attn_adjustment_end_step_percents: list[float],
): ):
# A binary mask indicating the regions of the image that the prompt should be applied to. # A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width) # Shape: (1, num_prompts, height, width)
@ -87,19 +85,9 @@ class TextConditioningRegions:
# ranges[i] contains the embedding range for the i'th prompt / mask. # ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges self.ranges = ranges
# The following fields control the cross-attention and self-attention handling of region masks. They are indexed self.mask_weights = mask_weights
# by prompt / mask index.
self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
self.positive_self_attn_mask_scores = positive_self_attn_mask_scores
self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents
assert ( assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
self.masks.shape[1]
== len(self.ranges)
== len(self.positive_cross_attn_mask_scores)
== len(self.positive_self_attn_mask_scores)
== len(self.self_attn_adjustment_end_step_percents)
)
class TextConditioningData: class TextConditioningData:

View File

@ -33,6 +33,7 @@ class RegionalPromptData:
regions, max_downscale_factor regions, max_downscale_factor
) )
self._negative_cross_attn_mask_score = 0.0 self._negative_cross_attn_mask_score = 0.0
self._size_weight = 1.0
def _prepare_spatial_masks( def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
@ -96,13 +97,16 @@ class RegionalPromptData:
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone() batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel() size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
size = size.to(dtype=batch_sample_query_scores.dtype) mask_weight = batch_sample_regions.mask_weights[prompt_idx]
batch_sample_query_mask = batch_sample_query_scores > 0.5 # size = size.to(dtype=batch_sample_query_scores.dtype)
batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size) # batch_sample_query_mask = batch_sample_query_scores > 0.5
batch_sample_query_scores[~batch_sample_query_mask] = 0.0 # batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores # batch_sample_query_scores[~batch_sample_query_mask] = 0.0
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
mask_weight + self._size_weight * (1 - size)
)
return attn_mask return attn_mask
@ -136,11 +140,15 @@ class RegionalPromptData:
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
size = prompt_query_mask.sum() / prompt_query_mask.numel() size = prompt_query_mask.sum() / prompt_query_mask.numel()
size = size.to(dtype=prompt_query_mask.dtype) size = size.to(dtype=prompt_query_mask.dtype)
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask. # query_seq_len) mask.
# TODO(ryand): Is += really the best option here? # TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
attn_mask[batch_idx, :, :] += ( attn_mask[batch_idx, :, :] = torch.maximum(
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size) attn_mask[batch_idx, :, :],
prompt_query_mask.unsqueeze(0)
* prompt_query_mask.unsqueeze(1)
* (mask_weight + self._size_weight * (1 - size)),
) )
# if attn_mask[batch_idx].max() < 0.01: # if attn_mask[batch_idx].max() < 0.01:

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import math import math
import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
@ -345,9 +346,7 @@ class InvokeAIDiffuserComponent:
r = TextConditioningRegions( r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool), masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])], ranges=[Range(start=0, end=c.embeds.shape[1])],
positive_cross_attn_mask_scores=[0.0], mask_weights=[0.0],
positive_self_attn_mask_scores=[0.0],
self_attn_adjustment_end_step_percents=[0.0],
) )
regions.append(r) regions.append(r)
@ -355,6 +354,7 @@ class InvokeAIDiffuserComponent:
regions=regions, device=x.device, dtype=x.dtype regions=regions, device=x.device, dtype=x.dtype
) )
cross_attention_kwargs["percent_through"] = percent_through cross_attention_kwargs["percent_through"] = percent_through
time.sleep(1.0)
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, x_twice,