From cfba51aed5011ccc0a7444dafdc656afb30d0506 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 27 Feb 2024 19:23:20 -0500
Subject: [PATCH] Removed unused function: _prepare_text_embeddings(...)

---
 .../diffusion/shared_invokeai_diffusion.py    | 84 +------------------
 1 file changed, 1 insertion(+), 83 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 9b7c0b7fd8..940eafe69c 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torchvision
@@ -467,88 +467,6 @@ class InvokeAIDiffuserComponent:
 
         return mask
 
-    def _prepare_text_embeddings(
-        self,
-        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
-        masks: list[Optional[torch.Tensor]],
-        target_height: int,
-        target_width: int,
-    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-        """Prepare text embeddings and associated masks for use in the UNet forward pass.
-
-        - Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
-          SDXLConditioningInfo).
-        - Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
-        - If all masks are None, skips all mask processing.
-
-        Returns:
-            Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-                (text_embedding, regional_prompt_data)
-                - text_embedding: The concatenated text embeddings.
-                - regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
-        """
-        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
-
-        all_masks_are_none = all(mask is None for mask in masks)
-
-        text_embedding = []
-        pooled_embedding = None
-        add_time_ids = None
-        processed_masks = []
-        cur_text_embedding_len = 0
-        embedding_ranges: list[Range] = []
-
-        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
-
-            if is_sdxl:
-                # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
-                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
-                # the conditioning info, then we shouldn't allow it to be passed in.
-                # How does Compel handle this? Options that come to mind:
-                # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
-                # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
-                #   this is likely the global prompt.
-                if pooled_embedding is None:
-                    pooled_embedding = text_embedding_info.pooled_embeds
-                if add_time_ids is None:
-                    add_time_ids = text_embedding_info.add_time_ids
-
-            text_embedding.append(text_embedding_info.embeds)
-            embedding_ranges.append(
-                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
-            )
-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-            if not all_masks_are_none:
-                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
-
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        regional_prompt_data = None
-        if not all_masks_are_none:
-            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
-            processed_masks = torch.cat(processed_masks, dim=1)
-
-            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
-
-        if is_sdxl:
-            return SDXLConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=None,
-                pooled_embeds=pooled_embedding,
-                add_time_ids=add_time_ids,
-            ), regional_prompt_data
-        return BasicConditioningInfo(
-            embeds=text_embedding,
-            extra_conditioning=None,
-        ), regional_prompt_data
-
     def _apply_standard_conditioning(
         self,
         x,
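
For readers who want the gist of what the deleted helper did: it concatenated per-prompt text
embeddings along the sequence dimension, recorded the token Range each prompt occupies, and
stacked the resized region masks so attention could later be restricted per region. The minimal
sketch below (not part of the patch) illustrates that pattern; the Range dataclass is a
simplified stand-in for InvokeAI's Range type, RegionalPromptData is omitted, and mask
preprocessing is reduced to a plain concatenation under the assumption that masks are either
all None or all provided (the real helper ran every entry, including None, through
_preprocess_regional_prompt_mask to get a full-coverage mask at the target size).

import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class Range:
    # Simplified stand-in: a [start, end) span of token indices within the
    # concatenated embedding sequence.
    start: int
    end: int


def concat_regional_embeddings(
    embeds: list[torch.Tensor],  # each (1, seq_len_i, token_dim)
    masks: list[Optional[torch.Tensor]],  # each (1, 1, H, W) or None
) -> tuple[torch.Tensor, list[Range], Optional[torch.Tensor]]:
    """Concatenate per-prompt embeddings along the sequence dim, recording
    the token range each prompt occupies."""
    ranges: list[Range] = []
    cur_len = 0
    for e in embeds:
        ranges.append(Range(start=cur_len, end=cur_len + e.shape[1]))
        cur_len += e.shape[1]

    text_embedding = torch.cat(embeds, dim=1)  # (1, sum(seq_len_i), token_dim)

    # Stack masks along dim=1 only if at least one region mask was provided;
    # this sketch assumes they are then all non-None (see note above).
    stacked_masks = None
    if not all(m is None for m in masks):
        stacked_masks = torch.cat(masks, dim=1)

    return text_embedding, ranges, stacked_masks


# Example: two prompts of 77 tokens each -> one (1, 154, 768) tensor with
# token ranges [0, 77) and [77, 154).
e1, e2 = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
emb, ranges, _ = concat_regional_embeddings([e1, e2], [None, None])
assert emb.shape == (1, 154, 768)
assert (ranges[0].start, ranges[0].end, ranges[1].end) == (0, 77, 154)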