diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 620c1acb1e..5672ee3bd9 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -63,6 +63,8 @@ class CompelInvocation(BaseInvocation): default=None, description="A mask defining the region that this conditioning prompt applies to." ) positive_cross_attn_mask_score: float = InputField(default=0.0, description="") + positive_self_attn_mask_score: float = InputField(default=1.0, description="") + self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="") @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -135,6 +137,8 @@ class CompelInvocation(BaseInvocation): conditioning_name=conditioning_name, mask=self.mask, positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, + positive_self_attn_mask_score=self.positive_self_attn_mask_score, + self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent, ) ) @@ -278,6 +282,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): default=None, description="A mask defining the region that this conditioning prompt applies to." ) positive_cross_attn_mask_score: float = InputField(default=0.0, description="") + positive_self_attn_mask_score: float = InputField(default=1.0, description="") + self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="") @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -345,6 +351,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): conditioning_name=conditioning_name, mask=self.mask, positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, + positive_self_attn_mask_score=self.positive_self_attn_mask_score, + self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent, ) ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 300d3c0a0f..282369c241 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -241,6 +241,8 @@ class ConditioningField(BaseModel): # TODO(ryand): Add more details to this description description="The weight of this conditioning tensor's mask relative to overlapping masks.", ) + positive_self_attn_mask_score: float = Field(default=1.0, description="") + self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="") class MetadataField(RootModel): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4c766e955c..5622876c17 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -391,22 +391,14 @@ class DenoiseLatentsInvocation(BaseInvocation): def _get_text_embeddings_and_masks( self, - cond_field: Union[ConditioningField, list[ConditioningField]], + cond_list: list[ConditioningField], context: InvocationContext, device: torch.device, dtype: torch.dtype, - ) -> tuple[ - Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]], list[float] - ]: + ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]: """Get the text embeddings and masks from the input conditioning fields.""" - # Normalize cond_field to a list. - cond_list = cond_field - if not isinstance(cond_list, list): - cond_list = [cond_list] - text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = [] - positive_cross_attn_mask_scores: list[float] = [] for cond in cond_list: cond_data = context.conditioning.load(cond.conditioning_name) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) @@ -416,8 +408,7 @@ class DenoiseLatentsInvocation(BaseInvocation): mask = context.tensors.load(mask.mask_name) text_embeddings_masks.append(mask) - positive_cross_attn_mask_scores.append(cond.positive_cross_attn_mask_score) - return text_embeddings, text_embeddings_masks, positive_cross_attn_mask_scores + return text_embeddings, text_embeddings_masks def _preprocess_regional_prompt_mask( self, mask: Optional[torch.Tensor], target_height: int, target_width: int @@ -445,7 +436,7 @@ class DenoiseLatentsInvocation(BaseInvocation): self, text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], masks: Optional[list[Optional[torch.Tensor]]], - positive_cross_attn_mask_scores: list[float], + conditioning_fields: list[ConditioningField], latent_height: int, latent_width: int, ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: @@ -466,7 +457,8 @@ class DenoiseLatentsInvocation(BaseInvocation): embedding_ranges = [] extra_conditioning = None - for text_embedding_info, mask in zip(text_conditionings, masks, strict=True): + for prompt_idx, text_embedding_info in enumerate(text_conditionings): + mask = masks[prompt_idx] if ( text_embedding_info.extra_conditioning is not None and text_embedding_info.extra_conditioning.wants_cross_attention_control @@ -508,7 +500,11 @@ class DenoiseLatentsInvocation(BaseInvocation): regions = TextConditioningRegions( masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges, - positive_cross_attn_mask_scores=positive_cross_attn_mask_scores, + positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields], + positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields], + self_attn_adjustment_end_step_percents=[ + x.self_attn_adjustment_end_step_percent for x in conditioning_fields + ], ) if extra_conditioning is not None and len(text_conditionings) > 1: @@ -536,27 +532,32 @@ class DenoiseLatentsInvocation(BaseInvocation): latent_height: int, latent_width: int, ) -> TextConditioningData: - ( - cond_text_embeddings, - cond_text_embedding_masks, - cond_positive_cross_attn_mask_scores, - ) = self._get_text_embeddings_and_masks(self.positive_conditioning, context, unet.device, unet.dtype) - ( - uncond_text_embeddings, - uncond_text_embedding_masks, - uncond_positive_cross_attn_mask_scores, - ) = self._get_text_embeddings_and_masks(self.negative_conditioning, context, unet.device, unet.dtype) + # Normalize self.positive_conditioning and self.negative_conditioning to lists. + cond_list = self.positive_conditioning + if not isinstance(cond_list, list): + cond_list = [cond_list] + uncond_list = self.negative_conditioning + if not isinstance(uncond_list, list): + uncond_list = [uncond_list] + + cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( + cond_list, context, unet.device, unet.dtype + ) + + uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( + uncond_list, context, unet.device, unet.dtype + ) cond_text_embedding, cond_regions = self.concat_regional_text_embeddings( text_conditionings=cond_text_embeddings, masks=cond_text_embedding_masks, - positive_cross_attn_mask_scores=cond_positive_cross_attn_mask_scores, + conditioning_fields=cond_list, latent_height=latent_height, latent_width=latent_width, ) uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings( text_conditionings=uncond_text_embeddings, masks=uncond_text_embedding_masks, - positive_cross_attn_mask_scores=uncond_positive_cross_attn_mask_scores, + conditioning_fields=uncond_list, latent_height=latent_height, latent_width=latent_width, ) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 4b144155b6..762e3003fd 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -70,7 +70,14 @@ class Range: class TextConditioningRegions: - def __init__(self, masks: torch.Tensor, ranges: list[Range], positive_cross_attn_mask_scores: list[float]): + def __init__( + self, + masks: torch.Tensor, + ranges: list[Range], + positive_cross_attn_mask_scores: list[float], + positive_self_attn_mask_scores: list[float], + self_attn_adjustment_end_step_percents: list[float], + ): # A binary mask indicating the regions of the image that the prompt should be applied to. # Shape: (1, num_prompts, height, width) # Dtype: torch.bool @@ -80,12 +87,19 @@ class TextConditioningRegions: # ranges[i] contains the embedding range for the i'th prompt / mask. self.ranges = ranges - # A list of positive cross attention mask scores for each prompt. - # positive_cross_attn_mask_scores[i] contains the positive cross attention mask score for the i'th prompt/mask. + # The following fields control the cross-attention and self-attention handling of region masks. They are indexed + # by prompt / mask index. self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores + self.positive_self_attn_mask_scores = positive_self_attn_mask_scores + self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents - assert self.masks.shape[1] == len(self.ranges) - assert self.masks.shape[1] == len(self.positive_cross_attn_mask_scores) + assert ( + self.masks.shape[1] + == len(self.ranges) + == len(self.positive_cross_attn_mask_scores) + == len(self.positive_self_attn_mask_scores) + == len(self.self_attn_adjustment_end_step_percents) + ) class TextConditioningData: @@ -102,23 +116,6 @@ class TextConditioningData: self.cond_text = cond_text self.uncond_regions = uncond_regions self.cond_regions = cond_regions - # All params: - # negative_cross_attn_mask_score: -10000 (recommended to leave this as -10000 to prevent leakage to the rest of the image) - # positive_cross_attn_mask_score: 0.0 (relative weightin of masks) - # positive_self_attn_mask_score: 0.3 - # negative_self_attn_mask_score: This doesn't really make sense. It would effectively have the same effect as further increasing positive_self_attn_mask_score. - # cross_attn_start_step - # self_attn_mask_begin_step_percent: 0.0 - # self_attn_mask_end_step percent: 0.5 - # Should we allow cross_attn_mask_begin_step_percent and cross_attn_mask_end_step_percent? Probably not, this seems like more control than necessary. And easy to add in the future. - self.negative_cross_attn_mask_score = -10000 - self.positive_cross_attn_mask_score = 0.0 - self.positive_self_attn_mask_score = 0.3 - self.self_attn_mask_end_step_percent = 0.5 - # mask_weight: float = Field( - # default=1.0, - # description="The weight to apply to the mask. This weight controls the relative weighting of overlapping masks. This weight gets added to the attention map logits before applying a pixelwise softmax.", - # ) # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 82f07385a9..f7043761a8 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -41,16 +41,7 @@ class RegionalPromptData: regions, max_downscale_factor ) - # TODO: These should be indexed by batch sample index and prompt index. - # Next: - # - Add support for setting these one nodes. Might just need positive cross-attention mask score. Being able to downweight the global prompt mighth help alot. - # - Scale by region size. self.negative_cross_attn_mask_score = -10000 - # self.positive_cross_attn_mask_score = 0.0 - self.positive_self_attn_mask_score = 1.0 - self.self_attn_mask_end_step_percent = 0.3 - # This one is for regional prompting in general, so should be set on the DenoiseLatents node. - self.self_attn_score_range = 3.0 def _prepare_spatial_masks( self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 @@ -222,20 +213,23 @@ class RegionalPromptData: for batch_idx in range(batch_size): batch_sample_spatial_masks = batch_spatial_masks[batch_idx] + batch_sample_regions = self._regions[batch_idx] # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). _, num_prompts, _, _ = batch_sample_spatial_masks.shape batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) for prompt_idx in range(num_prompts): - if percent_through > self.self_attn_mask_end_step_percent: + if percent_through > batch_sample_regions.self_attn_adjustment_end_step_percents[prompt_idx]: continue prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # query_seq_len) mask. # TODO(ryand): Is += really the best option here? attn_mask[batch_idx, :, :] += ( - prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score + prompt_query_mask.unsqueeze(0) + * prompt_query_mask.unsqueeze(1) + * batch_sample_regions.positive_self_attn_mask_scores[prompt_idx] ) # attn_mask_min = attn_mask[batch_idx].min() diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 60eeb9fadb..baaf70e65c 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -346,6 +346,8 @@ class InvokeAIDiffuserComponent: masks=torch.ones((1, 1, h, w), dtype=torch.bool), ranges=[Range(start=0, end=c.embeds.shape[1])], positive_cross_attn_mask_scores=[0.0], + positive_self_attn_mask_scores=[0.0], + self_attn_adjustment_end_step_percents=[0.0], ) regions.append(r)