diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
index fe9bdbc951..15da2c3151 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
@@ -32,15 +32,16 @@ class RegionalPromptData:
 
         Args:
             masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
-                The shape of masks[i] is (num_prompts, height, width), and dtype=bool. The mask is set to True in
-                regions where the prompt should be applied, and 0.0 elsewhere.
+                The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
+                prompt should be applied, and 0.0 elsewhere.
 
             embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
                 prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
                 encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
 
             key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
-                cross-attention layers).
+                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
+                explicitly to be sure.
         """
 
         attn_masks_by_seq_len = {}
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f22cf1375e..327161b736 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -121,13 +121,14 @@ class RegionalTextConditioningInfo:
             )
 
             if is_sdxl:
-                # We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
-                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
-                # the conditioning info, then we shouldn't allow it to be passed in.
-                # How does Compel handle this? Options that come to mind:
-                # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
-                # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
-                #   this is likely the global prompt.
+                # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
+                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
+                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+                # pretty major breaking change to a popular node, so for now we use this hack.
+                #
+                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
+                # most likely to be a global prompt.
                 if pooled_embedding is None:
                     pooled_embedding = text_embedding_info.pooled_embeds
                 if add_time_ids is None:
@@ -433,28 +434,6 @@ class InvokeAIDiffuserComponent:
 
         return torch.cat([unconditioning, conditioning]), encoder_attention_mask
 
-    def _preprocess_regional_prompt_mask(
-        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
-    ) -> torch.Tensor:
-        """Preprocess a regional prompt mask to match the target height and width.
-
-        If mask is None, returns a mask of all ones with the target height and width.
-        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
-
-        Returns:
-            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
-        """
-        if mask is None:
-            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
-
-        tf = torchvision.transforms.Resize(
-            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-        )
-        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
-        mask = tf(mask)
-
-        return mask
-
     def _apply_standard_conditioning(
         self,
         x,
@@ -470,8 +449,12 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        # HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
-        # denoising step.
+        # TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
+        # denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
+        # to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
+        # the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
+        # is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
+        # stack.
         cross_attention_kwargs = None
         _, _, h, w = x.shape
         cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
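As context for the RegionalPromptData docstring changes above, the sketch below illustrates how per-prompt region masks and embedding ranges can be combined into a cross-attention mask of shape (query_seq_len, key_seq_len). It is a minimal, illustrative sketch only: the function name build_cross_attn_mask, the local Range dataclass, and the nearest-neighbor downsampling to the attention layer's resolution are assumptions for exposition, not the repository's actual implementation.

# Illustrative sketch (not InvokeAI's implementation) of building a cross-attention
# mask from per-prompt region masks and their embedding ranges.
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class Range:
    start: int
    end: int


def build_cross_attn_mask(
    masks: torch.Tensor,  # (num_prompts, height, width), 1.0 inside each prompt's region, 0.0 elsewhere
    ranges: list[Range],  # ranges[j] indexes prompt j's tokens in the key (prompt embedding) sequence
    query_height: int,  # spatial height of the attention layer's query grid
    query_width: int,  # spatial width of the attention layer's query grid
    key_seq_len: int,  # total prompt embedding sequence length
) -> torch.Tensor:
    """Return a (query_seq_len, key_seq_len) bool mask: True where attention is allowed."""
    # Downsample each prompt's mask to the attention layer's spatial resolution.
    resized = F.interpolate(masks.unsqueeze(0), size=(query_height, query_width), mode="nearest")[0]
    query_seq_len = query_height * query_width
    attn_mask = torch.zeros((query_seq_len, key_seq_len), dtype=torch.bool)
    for j, r in enumerate(ranges):
        # Flattened query positions covered by prompt j may attend to that prompt's key tokens.
        region = resized[j].flatten() > 0.5  # (query_seq_len,)
        attn_mask[region, r.start : r.end] = True
    return attn_mask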