Removed unused function: _prepare_text_embeddings(...)

Ryan Dick 2024-02-27 19:23:20 -05:00
parent 2966c8de2c
commit cfba51aed5


@@ -2,7 +2,7 @@ from __future__ import annotations
 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 import torch
 import torchvision
@@ -467,88 +467,6 @@ class InvokeAIDiffuserComponent:
         return mask

-    def _prepare_text_embeddings(
-        self,
-        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
-        masks: list[Optional[torch.Tensor]],
-        target_height: int,
-        target_width: int,
-    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-        """Prepare text embeddings and associated masks for use in the UNet forward pass.
-
-        - Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
-          SDXLConditioningInfo).
-        - Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
-        - If all masks are None, skips all mask processing.
-
-        Returns:
-            Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-                (text_embedding, regional_prompt_data)
-                - text_embedding: The concatenated text embeddings.
-                - regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
-        """
-        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
-
-        all_masks_are_none = all(mask is None for mask in masks)
-
-        text_embedding = []
-        pooled_embedding = None
-        add_time_ids = None
-        processed_masks = []
-        cur_text_embedding_len = 0
-        embedding_ranges: list[Range] = []
-
-        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
-            if is_sdxl:
-                # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
-                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from
-                # all the conditioning info, then we shouldn't allow it to be passed in.
-                # How does Compel handle this? Options that come to mind:
-                # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
-                # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
-                #   this is likely the global prompt.
-                if pooled_embedding is None:
-                    pooled_embedding = text_embedding_info.pooled_embeds
-                if add_time_ids is None:
-                    add_time_ids = text_embedding_info.add_time_ids
-            text_embedding.append(text_embedding_info.embeds)
-            embedding_ranges.append(
-                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
-            )
-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-            if not all_masks_are_none:
-                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        regional_prompt_data = None
-        if not all_masks_are_none:
-            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
-            processed_masks = torch.cat(processed_masks, dim=1)
-            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
-
-        if is_sdxl:
-            return SDXLConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=None,
-                pooled_embeds=pooled_embedding,
-                add_time_ids=add_time_ids,
-            ), regional_prompt_data
-
-        return BasicConditioningInfo(
-            embeds=text_embedding,
-            extra_conditioning=None,
-        ), regional_prompt_data
     def _apply_standard_conditioning(
         self,
         x,
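
For readers who want the removed logic in isolation: the core of _prepare_text_embeddings was bookkeeping, concatenating each prompt's embedding along the sequence dimension and recording the token range each prompt occupies, which RegionalPromptData then pairs with the per-prompt masks. The sketch below illustrates only that bookkeeping; the Range dataclass, function name, and tensor shapes are stand-ins for this example, not the classes from the InvokeAI codebase.

# Illustrative sketch only: concatenation and token-range bookkeeping similar to
# what the removed _prepare_text_embeddings performed. `Range` is a stand-in
# dataclass, not the one used in the repository.
from dataclasses import dataclass

import torch


@dataclass
class Range:
    start: int
    end: int


def concat_embeddings(embeds: list[torch.Tensor]) -> tuple[torch.Tensor, list[Range]]:
    """Concatenate per-prompt embeddings along the sequence dim and record each prompt's token range."""
    ranges: list[Range] = []
    cur_len = 0
    for e in embeds:
        ranges.append(Range(start=cur_len, end=cur_len + e.shape[1]))
        cur_len += e.shape[1]
    return torch.cat(embeds, dim=1), ranges


# Example: a global prompt plus two regional prompts, each 77 tokens of dim 768.
embeds = [torch.randn(1, 77, 768) for _ in range(3)]
combined, ranges = concat_embeddings(embeds)
print(combined.shape)  # torch.Size([1, 231, 768])
print(ranges)          # [Range(start=0, end=77), Range(start=77, end=154), Range(start=154, end=231)]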