From ff950bc5cd11eb255ac11311f706af578727ce7b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 7 Mar 2024 14:30:51 -0500 Subject: [PATCH] Add support for mask weights, and only mask the tokens associated with the prompts (not eh entire 77-token embedding). --- invokeai/app/invocations/compel.py | 22 +++++----------- invokeai/app/invocations/fields.py | 8 +----- invokeai/app/invocations/latent.py | 17 +++++++----- .../diffusion/conditioning_data.py | 18 +++---------- .../diffusion/regional_prompt_data.py | 26 ++++++++++++------- .../diffusion/shared_invokeai_diffusion.py | 6 ++--- 6 files changed, 42 insertions(+), 55 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 5672ee3bd9..69987a458b 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -44,7 +44,7 @@ from .model import ClipField title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.1.0", + version="1.2.0", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -62,9 +62,7 @@ class CompelInvocation(BaseInvocation): mask: Optional[MaskField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) - positive_cross_attn_mask_score: float = InputField(default=0.0, description="") - positive_self_attn_mask_score: float = InputField(default=1.0, description="") - self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="") + mask_weight: float = InputField(default=1.0, description="") @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -136,9 +134,7 @@ class CompelInvocation(BaseInvocation): conditioning=ConditioningField( conditioning_name=conditioning_name, mask=self.mask, - positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, - positive_self_attn_mask_score=self.positive_self_attn_mask_score, - self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent, + mask_weight=self.mask_weight, ) ) @@ -254,7 +250,7 @@ class SDXLPromptInvocationBase: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.1.0", + version="1.2.0", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -281,9 +277,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): mask: Optional[MaskField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) - positive_cross_attn_mask_score: float = InputField(default=0.0, description="") - positive_self_attn_mask_score: float = InputField(default=1.0, description="") - self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="") + mask_weight: float = InputField(default=1.0, description="") @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -350,9 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): conditioning=ConditioningField( conditioning_name=conditioning_name, mask=self.mask, - positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, - positive_self_attn_mask_score=self.positive_self_attn_mask_score, - self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent, + mask_weight=self.mask_weight, ) ) @@ -403,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name)) + return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0)) @invocation_output("clip_skip_output") diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 282369c241..08f009a02f 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -236,13 +236,7 @@ class ConditioningField(BaseModel): description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " "included regions should be set to True.", ) - positive_cross_attn_mask_score: float = Field( - default=0.0, - # TODO(ryand): Add more details to this description - description="The weight of this conditioning tensor's mask relative to overlapping masks.", - ) - positive_self_attn_mask_score: float = Field(default=1.0, description="") - self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="") + mask_weight: float = Field(description="") class MetadataField(RootModel): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index d967a68bba..27b482d30f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -466,9 +466,18 @@ class DenoiseLatentsInvocation(BaseInvocation): text_embedding.append(text_embedding_info.embeds) if not all_masks_are_none: + # embedding_ranges.append( + # Range( + # start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] + # ) + # ) + # HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos + # in the count. embedding_ranges.append( Range( - start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] + start=cur_text_embedding_len + 1, + end=cur_text_embedding_len + + text_embedding_info.extra_conditioning.tokens_count_including_eos_bos, ) ) processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width)) @@ -483,11 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation): regions = TextConditioningRegions( masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges, - positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields], - positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields], - self_attn_adjustment_end_step_percents=[ - x.self_attn_adjustment_end_step_percent for x in conditioning_fields - ], + mask_weights=[x.mask_weight for x in conditioning_fields], ) if extra_conditioning is not None and len(text_conditionings) > 1: diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 762e3003fd..f68a5ae12a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -74,9 +74,7 @@ class TextConditioningRegions: self, masks: torch.Tensor, ranges: list[Range], - positive_cross_attn_mask_scores: list[float], - positive_self_attn_mask_scores: list[float], - self_attn_adjustment_end_step_percents: list[float], + mask_weights: list[float], ): # A binary mask indicating the regions of the image that the prompt should be applied to. # Shape: (1, num_prompts, height, width) @@ -87,19 +85,9 @@ class TextConditioningRegions: # ranges[i] contains the embedding range for the i'th prompt / mask. self.ranges = ranges - # The following fields control the cross-attention and self-attention handling of region masks. They are indexed - # by prompt / mask index. - self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores - self.positive_self_attn_mask_scores = positive_self_attn_mask_scores - self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents + self.mask_weights = mask_weights - assert ( - self.masks.shape[1] - == len(self.ranges) - == len(self.positive_cross_attn_mask_scores) - == len(self.positive_self_attn_mask_scores) - == len(self.self_attn_adjustment_end_step_percents) - ) + assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights) class TextConditioningData: diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 76b5e83b53..3fb9c490f5 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -33,6 +33,7 @@ class RegionalPromptData: regions, max_downscale_factor ) self._negative_cross_attn_mask_score = 0.0 + self._size_weight = 1.0 def _prepare_spatial_masks( self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 @@ -96,13 +97,16 @@ class RegionalPromptData: batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): - batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone() + batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :] size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel() - size = size.to(dtype=batch_sample_query_scores.dtype) - batch_sample_query_mask = batch_sample_query_scores > 0.5 - batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size) - batch_sample_query_scores[~batch_sample_query_mask] = 0.0 - attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores + mask_weight = batch_sample_regions.mask_weights[prompt_idx] + # size = size.to(dtype=batch_sample_query_scores.dtype) + # batch_sample_query_mask = batch_sample_query_scores > 0.5 + # batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size) + # batch_sample_query_scores[~batch_sample_query_mask] = 0.0 + attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * ( + mask_weight + self._size_weight * (1 - size) + ) return attn_mask @@ -136,11 +140,15 @@ class RegionalPromptData: prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) size = prompt_query_mask.sum() / prompt_query_mask.numel() size = size.to(dtype=prompt_query_mask.dtype) + mask_weight = batch_sample_regions.mask_weights[prompt_idx] # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # query_seq_len) mask. - # TODO(ryand): Is += really the best option here? - attn_mask[batch_idx, :, :] += ( - prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size) + # TODO(ryand): Is += really the best option here? Maybe elementwise max is better? + attn_mask[batch_idx, :, :] = torch.maximum( + attn_mask[batch_idx, :, :], + prompt_query_mask.unsqueeze(0) + * prompt_query_mask.unsqueeze(1) + * (mask_weight + self._size_weight * (1 - size)), ) # if attn_mask[batch_idx].max() < 0.01: diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 80fb48d252..e52b5ae36a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import time from contextlib import contextmanager from typing import Any, Callable, Optional, Union @@ -345,9 +346,7 @@ class InvokeAIDiffuserComponent: r = TextConditioningRegions( masks=torch.ones((1, 1, h, w), dtype=torch.bool), ranges=[Range(start=0, end=c.embeds.shape[1])], - positive_cross_attn_mask_scores=[0.0], - positive_self_attn_mask_scores=[0.0], - self_attn_adjustment_end_step_percents=[0.0], + mask_weights=[0.0], ) regions.append(r) @@ -355,6 +354,7 @@ class InvokeAIDiffuserComponent: regions=regions, device=x.device, dtype=x.dtype ) cross_attention_kwargs["percent_through"] = percent_through + time.sleep(1.0) both_results = self.model_forward_callback( x_twice,