diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py
index b15fdcd198..595a20e186 100644
--- a/invokeai/app/invocations/conditioning.py
+++ b/invokeai/app/invocations/conditioning.py
@@ -21,9 +21,11 @@ class AddConditioningMaskInvocation(BaseInvocation):
 
     conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
     mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
+    positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
 
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         self.conditioning.mask = self.mask
+        self.conditioning.positive_cross_attn_mask_score = self.positive_cross_attn_mask_score
         return ConditioningOutput(conditioning=self.conditioning)
 
diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 63572b5151..300d3c0a0f 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -236,6 +236,11 @@ class ConditioningField(BaseModel):
         description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
         "included regions should be set to True.",
     )
+    positive_cross_attn_mask_score: float = Field(
+        default=0.0,
+        # TODO(ryand): Add more details to this description
+        description="The weight of this conditioning tensor's mask relative to overlapping masks.",
+    )
 
 
 class MetadataField(RootModel):
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index f414fb7770..3243cc0fd3 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -381,7 +381,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+    ) -> tuple[
+        Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]], list[float]
+    ]:
         """Get the text embeddings and masks from the input conditioning fields."""
         # Normalize cond_field to a list.
         cond_list = cond_field
@@ -390,6 +392,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
         text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        positive_cross_attn_mask_scores: list[float] = []
         for cond in cond_list:
             cond_data = context.conditioning.load(cond.conditioning_name)
             text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
@@ -399,7 +402,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 mask = context.tensors.load(mask.mask_name)
             text_embeddings_masks.append(mask)
-
-        return text_embeddings, text_embeddings_masks
+            positive_cross_attn_mask_scores.append(cond.positive_cross_attn_mask_score)
+
+        return text_embeddings, text_embeddings_masks, positive_cross_attn_mask_scores
 
     def _preprocess_regional_prompt_mask(
         self, mask: Optional[torch.Tensor], target_height: int, target_width: int
@@ -427,6 +431,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         self,
         text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
         masks: Optional[list[Optional[torch.Tensor]]],
+        positive_cross_attn_mask_scores: list[float],
         latent_height: int,
         latent_width: int,
     ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
@@ -486,7 +491,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         regions = None
         if not all_masks_are_none:
-            regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
+            regions = TextConditioningRegions(
+                masks=torch.cat(processed_masks, dim=1),
+                ranges=embedding_ranges,
+                positive_cross_attn_mask_scores=positive_cross_attn_mask_scores,
+            )
 
         if extra_conditioning is not None and len(text_conditionings) > 1:
             raise ValueError(
@@ -513,21 +522,27 @@ class DenoiseLatentsInvocation(BaseInvocation):
         latent_height: int,
         latent_width: int,
     ) -> TextConditioningData:
-        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
-            self.positive_conditioning, context, unet.device, unet.dtype
-        )
-        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
-            self.negative_conditioning, context, unet.device, unet.dtype
-        )
+        (
+            cond_text_embeddings,
+            cond_text_embedding_masks,
+            cond_positive_cross_attn_mask_scores,
+        ) = self._get_text_embeddings_and_masks(self.positive_conditioning, context, unet.device, unet.dtype)
+        (
+            uncond_text_embeddings,
+            uncond_text_embedding_masks,
+            uncond_positive_cross_attn_mask_scores,
+        ) = self._get_text_embeddings_and_masks(self.negative_conditioning, context, unet.device, unet.dtype)
         cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
             text_conditionings=cond_text_embeddings,
             masks=cond_text_embedding_masks,
+            positive_cross_attn_mask_scores=cond_positive_cross_attn_mask_scores,
             latent_height=latent_height,
             latent_width=latent_width,
         )
         uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
+            positive_cross_attn_mask_scores=uncond_positive_cross_attn_mask_scores,
             latent_height=latent_height,
             latent_width=latent_width,
         )
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 8d7ff09be1..4b144155b6 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -70,7 +70,7 @@ class Range:
 
 
 class TextConditioningRegions:
-    def __init__(self, masks: torch.Tensor, ranges: list[Range]):
+    def __init__(self, masks: torch.Tensor, ranges: list[Range], positive_cross_attn_mask_scores: list[float]):
         # A binary mask indicating the regions of the image that the prompt should be applied to.
         # Shape: (1, num_prompts, height, width)
         # Dtype: torch.bool
@@ -80,7 +80,12 @@ class TextConditioningRegions:
         # ranges[i] contains the embedding range for the i'th prompt / mask.
         self.ranges = ranges
 
+        # A list of positive cross attention mask scores for each prompt.
+        # positive_cross_attn_mask_scores[i] contains the positive cross attention mask score for the i'th prompt/mask.
+        self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
+
         assert self.masks.shape[1] == len(self.ranges)
+        assert self.masks.shape[1] == len(self.positive_cross_attn_mask_scores)
 
 
 class TextConditioningData:
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index 749e3e0224..57b956924c 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -46,7 +46,7 @@ class RegionalPromptData:
         #   - Add support for setting these one nodes. Might just need positive cross-attention mask score. Being able to downweight the global prompt mighth help alot.
         #   - Scale by region size.
         self.negative_cross_attn_mask_score = -10000
-        self.positive_cross_attn_mask_score = 0.0
+        # self.positive_cross_attn_mask_score = 0.0
         self.positive_self_attn_mask_score = 2.0
         self.self_attn_mask_end_step_percent = 0.3  # This one is for regional prompting in general, so should be set on the DenoiseLatents node.
@@ -186,13 +186,14 @@ class RegionalPromptData:
             batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
 
             for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
-                attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_masks[
-                    0, prompt_idx, :, :
-                ]
+                batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
+                batch_sample_query_mask = batch_sample_query_scores > 0.5
+                batch_sample_query_scores[
+                    batch_sample_query_mask
+                ] = batch_sample_regions.positive_cross_attn_mask_scores[prompt_idx]
+                batch_sample_query_scores[~batch_sample_query_mask] = self.negative_cross_attn_mask_score  # TODO(ryand)
+                attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
 
-        pos_mask = attn_mask >= 0.5
-        attn_mask[~pos_mask] = self.negative_cross_attn_mask_score
-        attn_mask[pos_mask] = self.positive_cross_attn_mask_score
         return attn_mask
 
     def get_self_attn_mask(
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 9c716e5440..60eeb9fadb 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -345,6 +345,7 @@ class InvokeAIDiffuserComponent:
                 r = TextConditioningRegions(
                     masks=torch.ones((1, 1, h, w), dtype=torch.bool),
                     ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    positive_cross_attn_mask_scores=[0.0],
                 )
                 regions.append(r)
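
The regional_prompt_data.py hunk is the behavioral core of this diff: instead of clamping every masked region to one shared positive_cross_attn_mask_score, each prompt's region now contributes its own per-prompt score when the additive cross-attention bias is built, while pixels outside the region still receive a large negative score. The sketch below illustrates that idea outside the InvokeAI plumbing; build_cross_attn_bias and its parameters are hypothetical illustration (not part of this PR), and the (num_prompts, query_seq_len) mask layout is simplified from the diff's (1, num_prompts, query_seq_len, 1) view.

# Hypothetical sketch (not part of the PR): build an additive cross-attention bias
# from per-prompt region masks and per-prompt positive scores, mirroring the logic
# added to RegionalPromptData in this diff.
import torch


def build_cross_attn_bias(
    query_masks: torch.Tensor,  # (num_prompts, query_seq_len), bool: True where prompt i applies
    ranges: list[tuple[int, int]],  # (start, end) token range of prompt i in the concatenated embedding
    positive_scores: list[float],  # per-prompt bias inside the masked region (0.0 = neutral)
    negative_score: float = -10000.0,  # large negative bias outside the region
    key_seq_len: int = 77,
) -> torch.Tensor:
    """Return an additive bias of shape (query_seq_len, key_seq_len)."""
    num_prompts, query_seq_len = query_masks.shape
    bias = torch.zeros((query_seq_len, key_seq_len))
    for i, (start, end) in enumerate(ranges):
        # Inside prompt i's region use its own positive score; elsewhere use the negative score.
        scores = torch.where(
            query_masks[i].unsqueeze(-1),  # (query_seq_len, 1)
            torch.tensor(positive_scores[i]),
            torch.tensor(negative_score),
        )
        # Broadcast the per-query column over this prompt's token range.
        bias[:, start:end] = scores
    return bias


# Example: two prompts over a 4-pixel latent, 3 tokens each; the first prompt is upweighted.
masks = torch.tensor([[True, True, False, False], [False, False, True, True]])
bias = build_cross_attn_bias(masks, ranges=[(0, 3), (3, 6)], positive_scores=[0.5, 0.0], key_seq_len=6)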