Removed unused function: _prepare_text_embeddings(...)

This commit is contained in:
Ryan Dick 2024-02-27 19:23:20 -05:00
parent 2966c8de2c
commit cfba51aed5

@@ -2,7 +2,7 @@ from __future__ import annotations
 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 import torch
 import torchvision
@@ -467,88 +467,6 @@ class InvokeAIDiffuserComponent:
         return mask
-    def _prepare_text_embeddings(
-        self,
-        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
-        masks: list[Optional[torch.Tensor]],
-        target_height: int,
-        target_width: int,
-    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-        """Prepare text embeddings and associated masks for use in the UNet forward pass.
-
-        - Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
-          SDXLConditioningInfo).
-        - Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
-        - If all masks are None, skips all mask processing.
-
-        Returns:
-            Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-                (text_embedding, regional_prompt_data)
-                - text_embedding: The concatenated text embeddings.
-                - regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
-        """
-        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
-
-        all_masks_are_none = all(mask is None for mask in masks)
-
-        text_embedding = []
-        pooled_embedding = None
-        add_time_ids = None
-        processed_masks = []
-        cur_text_embedding_len = 0
-        embedding_ranges: list[Range] = []
-
-        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
-
-            if is_sdxl:
-                # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
-                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
-                # the conditioning info, then we shouldn't allow it to be passed in.
-                # How does Compel handle this? Options that come to mind:
-                # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
-                # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
-                #   this is likely the global prompt.
-                if pooled_embedding is None:
-                    pooled_embedding = text_embedding_info.pooled_embeds
-                if add_time_ids is None:
-                    add_time_ids = text_embedding_info.add_time_ids
-
-            text_embedding.append(text_embedding_info.embeds)
-            embedding_ranges.append(
-                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
-            )
-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-            if not all_masks_are_none:
-                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
-
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        regional_prompt_data = None
-        if not all_masks_are_none:
-            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
-            processed_masks = torch.cat(processed_masks, dim=1)
-            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
-
-        if is_sdxl:
-            return SDXLConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=None,
-                pooled_embeds=pooled_embedding,
-                add_time_ids=add_time_ids,
-            ), regional_prompt_data
-        return BasicConditioningInfo(
-            embeds=text_embedding,
-            extra_conditioning=None,
-        ), regional_prompt_data
-
     def _apply_standard_conditioning(
         self,
         x,
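
For context on what was dropped: the removed helper's core pattern was to concatenate each prompt's embedding along the sequence dimension while recording the token Range each prompt occupies, so masks can later be mapped back to embedding spans. The following is a minimal sketch of that bookkeeping only, not code from this commit; the simplified Range dataclass and the helper name concat_embeddings_with_ranges are stand-ins for illustration, not InvokeAI APIs.

from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class Range:
    # Simplified stand-in for the Range type referenced by the removed helper.
    start: int
    end: int


def concat_embeddings_with_ranges(embeddings: list[torch.Tensor]) -> tuple[torch.Tensor, list[Range]]:
    """Concatenate per-prompt embeddings along the sequence dim, recording each prompt's token range."""
    ranges: list[Range] = []
    cur_len = 0
    for embeds in embeddings:
        # Each embeds tensor has shape (batch_size, seq_len, token_len), matching the assert in the removed code.
        ranges.append(Range(start=cur_len, end=cur_len + embeds.shape[1]))
        cur_len += embeds.shape[1]
    return torch.cat(embeddings, dim=1), ranges


if __name__ == "__main__":
    # Two prompts of 77 and 154 tokens produce ranges [0, 77) and [77, 231).
    a = torch.randn(1, 77, 768)
    b = torch.randn(1, 154, 768)
    combined, ranges = concat_embeddings_with_ranges([a, b])
    print(combined.shape, [(r.start, r.end) for r in ranges])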