mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Removed unused function: _prepare_text_embeddings(...)
This commit is contained in:
parent
2966c8de2c
commit
cfba51aed5
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
@ -467,88 +467,6 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return mask
|
||||
|
||||
def _prepare_text_embeddings(
|
||||
self,
|
||||
text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
|
||||
masks: list[Optional[torch.Tensor]],
|
||||
target_height: int,
|
||||
target_width: int,
|
||||
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
||||
"""Prepare text embeddings and associated masks for use in the UNet forward pass.
|
||||
|
||||
- Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
|
||||
SDXLConditioningInfo).
|
||||
- Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
|
||||
- If all masks are None, skips all mask processing.
|
||||
|
||||
Returns:
|
||||
Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
||||
(text_embedding, regional_prompt_data)
|
||||
- text_embedding: The concatenated text embeddings.
|
||||
- regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
|
||||
"""
|
||||
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
|
||||
|
||||
all_masks_are_none = all(mask is None for mask in masks)
|
||||
|
||||
text_embedding = []
|
||||
pooled_embedding = None
|
||||
add_time_ids = None
|
||||
processed_masks = []
|
||||
cur_text_embedding_len = 0
|
||||
embedding_ranges: list[Range] = []
|
||||
|
||||
for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
|
||||
# HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
|
||||
assert (
|
||||
text_embedding_info.extra_conditioning is None
|
||||
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
||||
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
||||
# the conditioning info, then we shouldn't allow it to be passed in.
|
||||
# How does Compel handle this? Options that come to mind:
|
||||
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
|
||||
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
|
||||
# this is likely the global prompt.
|
||||
if pooled_embedding is None:
|
||||
pooled_embedding = text_embedding_info.pooled_embeds
|
||||
if add_time_ids is None:
|
||||
add_time_ids = text_embedding_info.add_time_ids
|
||||
|
||||
text_embedding.append(text_embedding_info.embeds)
|
||||
embedding_ranges.append(
|
||||
Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
|
||||
)
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
|
||||
if not all_masks_are_none:
|
||||
processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
|
||||
|
||||
text_embedding = torch.cat(text_embedding, dim=1)
|
||||
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||
|
||||
regional_prompt_data = None
|
||||
if not all_masks_are_none:
|
||||
# TODO(ryand): Think about at what point a batch dimension should be added to the masks.
|
||||
processed_masks = torch.cat(processed_masks, dim=1)
|
||||
|
||||
regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
|
||||
|
||||
if is_sdxl:
|
||||
return SDXLConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
extra_conditioning=None,
|
||||
pooled_embeds=pooled_embedding,
|
||||
add_time_ids=add_time_ids,
|
||||
), regional_prompt_data
|
||||
return BasicConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
extra_conditioning=None,
|
||||
), regional_prompt_data
|
||||
|
||||
def _apply_standard_conditioning(
|
||||
self,
|
||||
x,
|
||||
|
Loading…
x
Reference in New Issue
Block a user