Add support for mask weights, and only mask the tokens associated with the prompts (not the entire 77-token embedding).

This commit is contained in:
Ryan Dick 2024-03-07 14:30:51 -05:00
parent 969982b789
commit ff950bc5cd
6 changed files with 42 additions and 55 deletions
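
For orientation, here is a minimal self-contained sketch of the token-masking idea in the message above (the helper name and the BOS placement are assumptions, not taken from the files below): only the slots occupied by the prompt's own tokens are marked, not all 77 slots of the padded embedding.

import torch

def prompt_token_mask(prompt_token_count: int, padded_len: int = 77) -> torch.Tensor:
    # Mark only the prompt's own token positions inside the padded text embedding.
    # Position 0 is assumed to hold BOS; EOS/padding positions stay False.
    mask = torch.zeros(padded_len, dtype=torch.bool)
    mask[1 : 1 + prompt_token_count] = True
    return mask

print(prompt_token_mask(5).nonzero().squeeze(-1))  # tensor([1, 2, 3, 4, 5])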

View File

@@ -44,7 +44,7 @@ from .model import ClipField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.1.0",
version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@@ -62,9 +62,7 @@ class CompelInvocation(BaseInvocation):
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -136,9 +134,7 @@ class CompelInvocation(BaseInvocation):
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
mask_weight=self.mask_weight,
)
)
@@ -254,7 +250,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.1.0",
version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -281,9 +277,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -350,9 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
mask_weight=self.mask_weight,
)
)
@@ -403,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
@invocation_output("clip_skip_output")
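
As a simplified stand-in (plain dataclasses rather than the project's pydantic models; names invented), the payload these prompt invocations now emit carries an optional region mask plus a single mask_weight in place of the three removed attention scores, while the refiner path hard-codes mask_weight=1.0 with no mask.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ConditioningFieldSketch:
    conditioning_name: str
    mask: Optional[torch.Tensor] = None  # stand-in for MaskField
    mask_weight: float = 1.0             # single knob replacing the three removed scores

masked_prompt = ConditioningFieldSketch("cond_0", mask=torch.ones(1, 64, 64, dtype=torch.bool))
refiner_prompt = ConditioningFieldSketch("cond_refiner", mask=None, mask_weight=1.0)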

View File

@@ -236,13 +236,7 @@ class ConditioningField(BaseModel):
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
positive_cross_attn_mask_score: float = Field(
default=0.0,
# TODO(ryand): Add more details to this description
description="The weight of this conditioning tensor's mask relative to overlapping masks.",
)
positive_self_attn_mask_score: float = Field(default=1.0, description="")
self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="")
mask_weight: float = Field(description="")
class MetadataField(RootModel):

View File

@@ -466,9 +466,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
# embedding_ranges.append(
# Range(
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
# )
# )
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
# in the count.
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
start=cur_text_embedding_len + 1,
end=cur_text_embedding_len
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
)
)
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
@@ -483,11 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields],
positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields],
self_attn_adjustment_end_step_percents=[
x.self_attn_adjustment_end_step_percent for x in conditioning_fields
],
mask_weights=[x.mask_weight for x in conditioning_fields],
)
if extra_conditioning is not None and len(text_conditionings) > 1:
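
A rough, self-contained sketch of the range bookkeeping in this hunk (inputs simplified to plain tensors and ints; the real code reads them from the conditioning info objects): each prompt's range starts one slot past the running offset, mirroring the BOS/EOS counting quirk flagged in the HACK comment, and the offset advances by the full padded embedding length.

import torch
from dataclasses import dataclass

@dataclass
class Range:
    start: int
    end: int

def assemble_ranges(embeds_list: list[torch.Tensor], token_counts: list[int]) -> list[Range]:
    # Illustrative only: token_counts[i] plays the role of tokens_count_including_eos_bos
    # for prompt i in the hunk above.
    ranges: list[Range] = []
    offset = 0
    for embeds, count in zip(embeds_list, token_counts):
        # Skip the first slot and stop at the reported token count, as in the hunk above.
        ranges.append(Range(start=offset + 1, end=offset + count))
        offset += embeds.shape[1]  # advance by the full padded length (e.g. 77)
    return ranges

ranges = assemble_ranges([torch.zeros(1, 77, 768), torch.zeros(1, 77, 768)], [6, 9])
print([(r.start, r.end) for r in ranges])  # [(1, 6), (78, 86)]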

View File

@@ -74,9 +74,7 @@ class TextConditioningRegions:
self,
masks: torch.Tensor,
ranges: list[Range],
positive_cross_attn_mask_scores: list[float],
positive_self_attn_mask_scores: list[float],
self_attn_adjustment_end_step_percents: list[float],
mask_weights: list[float],
):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
@@ -87,19 +85,9 @@ class TextConditioningRegions:
# ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges
# The following fields control the cross-attention and self-attention handling of region masks. They are indexed
# by prompt / mask index.
self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
self.positive_self_attn_mask_scores = positive_self_attn_mask_scores
self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents
self.mask_weights = mask_weights
assert (
self.masks.shape[1]
== len(self.ranges)
== len(self.positive_cross_attn_mask_scores)
== len(self.positive_self_attn_mask_scores)
== len(self.self_attn_adjustment_end_step_percents)
)
assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
class TextConditioningData:
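
To make the shape contract behind that assertion concrete, a tiny illustrative example (all values invented): one mask channel, one token range, and one weight per prompt.

import torch

num_prompts, height, width = 2, 64, 64
masks = torch.ones(1, num_prompts, height, width, dtype=torch.bool)  # (1, num_prompts, height, width)
ranges = [(1, 6), (78, 86)]        # stand-in for list[Range]
mask_weights = [1.0, 0.5]          # one weight per prompt / mask

assert masks.shape[1] == len(ranges) == len(mask_weights)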

View File

@@ -33,6 +33,7 @@ class RegionalPromptData:
regions, max_downscale_factor
)
self._negative_cross_attn_mask_score = 0.0
self._size_weight = 1.0
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
@@ -96,13 +97,16 @@ class RegionalPromptData:
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
size = size.to(dtype=batch_sample_query_scores.dtype)
batch_sample_query_mask = batch_sample_query_scores > 0.5
batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
batch_sample_query_scores[~batch_sample_query_mask] = 0.0
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# size = size.to(dtype=batch_sample_query_scores.dtype)
# batch_sample_query_mask = batch_sample_query_scores > 0.5
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
mask_weight + self._size_weight * (1 - size)
)
return attn_mask
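
A compact sketch of the new cross-attention scoring (shapes and names simplified): each region's flattened query mask is scaled by its mask_weight plus a size-dependent bonus, so smaller regions receive proportionally stronger scores.

import torch

def region_cross_attn_scores(query_mask: torch.Tensor, mask_weight: float, size_weight: float = 1.0) -> torch.Tensor:
    # query_mask: (query_seq_len,) float mask for one prompt, 1.0 inside the region.
    size = query_mask.sum() / query_mask.numel()                  # fraction of the image covered
    return query_mask * (mask_weight + size_weight * (1 - size))

quarter = torch.zeros(4096)
quarter[:1024] = 1.0  # region covering 25% of a 64x64 latent
print(region_cross_attn_scores(quarter, mask_weight=1.0).max())  # tensor(1.7500)
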
@@ -136,11 +140,15 @@ class RegionalPromptData:
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
size = prompt_query_mask.sum() / prompt_query_mask.numel()
size = size.to(dtype=prompt_query_mask.dtype)
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask.
# TODO(ryand): Is += really the best option here?
attn_mask[batch_idx, :, :] += (
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size)
# TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
attn_mask[batch_idx, :, :] = torch.maximum(
attn_mask[batch_idx, :, :],
prompt_query_mask.unsqueeze(0)
* prompt_query_mask.unsqueeze(1)
* (mask_weight + self._size_weight * (1 - size)),
)
# if attn_mask[batch_idx].max() < 0.01:
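
And a matching sketch for the self-attention path: the outer product of a region's query mask marks query/key pairs that both fall inside the region, and overlapping regions are now merged with an elementwise maximum instead of being summed.

import torch

def accumulate_self_attn_mask(attn_mask: torch.Tensor, prompt_query_mask: torch.Tensor,
                              mask_weight: float, size_weight: float = 1.0) -> torch.Tensor:
    # attn_mask: (query_seq_len, query_seq_len) accumulated so far.
    # prompt_query_mask: (query_seq_len,) float mask for one prompt.
    size = prompt_query_mask.sum() / prompt_query_mask.numel()
    region_pairs = prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1)
    return torch.maximum(attn_mask, region_pairs * (mask_weight + size_weight * (1 - size)))

mask = torch.zeros(16)
mask[:4] = 1.0
acc = accumulate_self_attn_mask(torch.zeros(16, 16), mask, mask_weight=1.0)
print(acc.max())  # tensor(1.7500)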

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import time
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
@@ -345,9 +346,7 @@ class InvokeAIDiffuserComponent:
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])],
positive_cross_attn_mask_scores=[0.0],
positive_self_attn_mask_scores=[0.0],
self_attn_adjustment_end_step_percents=[0.0],
mask_weights=[0.0],
)
regions.append(r)
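
Reading this fallback together with the scoring above suggests why mask_weight is 0.0 here: an all-True full-image mask has size 1, so the scaling term mask_weight + size_weight * (1 - size) evaluates to 0 and a prompt without a region mask contributes no regional attention bias. A quick check under those assumptions (size_weight fixed at 1.0, as in RegionalPromptData):

import torch

full_mask = torch.ones(1, 1, 64, 64, dtype=torch.bool).flatten().float()
size = full_mask.sum() / full_mask.numel()   # 1.0 for a full-image mask
scale = 0.0 + 1.0 * (1 - size)               # mask_weight=0.0, size_weight=1.0
print((full_mask * scale).abs().max())       # tensor(0.)
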
@@ -355,6 +354,7 @@ class InvokeAIDiffuserComponent:
regions=regions, device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
time.sleep(1.0)
both_results = self.model_forward_callback(
x_twice,