Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Removed unused function: _prepare_text_embeddings(...)
This commit is contained in:
parent 2966c8de2c
commit cfba51aed5
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torchvision
@@ -467,88 +467,6 @@ class InvokeAIDiffuserComponent:
 
         return mask
 
-    def _prepare_text_embeddings(
-        self,
-        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
-        masks: list[Optional[torch.Tensor]],
-        target_height: int,
-        target_width: int,
-    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-        """Prepare text embeddings and associated masks for use in the UNet forward pass.
-
-        - Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
-          SDXLConditioningInfo).
-        - Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
-        - If all masks are None, skips all mask processing.
-
-        Returns:
-            Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
-                (text_embedding, regional_prompt_data)
-                - text_embedding: The concatenated text embeddings.
-                - regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
-        """
-        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
-
-        all_masks_are_none = all(mask is None for mask in masks)
-
-        text_embedding = []
-        pooled_embedding = None
-        add_time_ids = None
-        processed_masks = []
-        cur_text_embedding_len = 0
-        embedding_ranges: list[Range] = []
-
-        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
-
-            if is_sdxl:
-                # We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
-                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
-                # the conditioning info, then we shouldn't allow it to be passed in.
-                # How does Compel handle this? Options that come to mind:
-                # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
-                # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
-                #   this is likely the global prompt.
-                if pooled_embedding is None:
-                    pooled_embedding = text_embedding_info.pooled_embeds
-                if add_time_ids is None:
-                    add_time_ids = text_embedding_info.add_time_ids
-
-            text_embedding.append(text_embedding_info.embeds)
-            embedding_ranges.append(
-                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
-            )
-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-            if not all_masks_are_none:
-                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
-
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        regional_prompt_data = None
-        if not all_masks_are_none:
-            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
-            processed_masks = torch.cat(processed_masks, dim=1)
-
-            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
-
-        if is_sdxl:
-            return SDXLConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=None,
-                pooled_embeds=pooled_embedding,
-                add_time_ids=add_time_ids,
-            ), regional_prompt_data
-        return BasicConditioningInfo(
-            embeds=text_embedding,
-            extra_conditioning=None,
-        ), regional_prompt_data
 
     def _apply_standard_conditioning(
         self,
         x,
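For reference, the core pattern of the removed helper was to concatenate per-prompt embeddings along the sequence dimension while recording each prompt's token range for later regional masking. The sketch below illustrates that pattern in isolation; the names SimpleRange and concat_embeddings are simplified stand-ins for this illustration, not InvokeAI's Range/RegionalPromptData API.

# Standalone sketch of the concatenate-and-track-ranges pattern (illustrative names only).
from dataclasses import dataclass

import torch


@dataclass
class SimpleRange:
    start: int
    end: int


def concat_embeddings(embeds: list[torch.Tensor]) -> tuple[torch.Tensor, list[SimpleRange]]:
    """Concatenate [batch, seq_len_i, dim] embeddings along dim=1, recording each input's token range."""
    ranges: list[SimpleRange] = []
    cur_len = 0
    for e in embeds:
        ranges.append(SimpleRange(start=cur_len, end=cur_len + e.shape[1]))
        cur_len += e.shape[1]
    return torch.cat(embeds, dim=1), ranges


if __name__ == "__main__":
    global_prompt = torch.randn(1, 77, 768)    # e.g. a global prompt embedding
    regional_prompt = torch.randn(1, 77, 768)  # e.g. a regional prompt embedding
    combined, ranges = concat_embeddings([global_prompt, regional_prompt])
    print(combined.shape)                       # torch.Size([1, 154, 768])
    print([(r.start, r.end) for r in ranges])   # [(0, 77), (77, 154)]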