Comments, a bit of refactoring

Sergey Borisov
2024-07-17 04:20:31 +03:00
parent 79e35bd0d3
commit 2c2ec8f0bc
2 changed files with 98 additions and 71 deletions

@@ -137,6 +137,12 @@ class TextConditioningData:
return isinstance(self.cond_text, SDXLConditioningInfo)
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
"""Fills unet arguments with data from provided conditionings.
Args:
unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
conditioning_mode (ConditioningMode): Describes which conditionings should be used.
"""
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype
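A minimal sketch of the pattern shown in the hunk above: to_unet_kwargs reads the spatial size, device, and dtype directly from the latent sample. FakeUNetKwargs is a hypothetical stand-in with only a sample attribute; the real UNetKwargs is not shown in this diff.

import torch

class FakeUNetKwargs:  # hypothetical stand-in; the real UNetKwargs is not shown in this diff
    def __init__(self, sample: torch.Tensor):
        self.sample = sample

unet_kwargs = FakeUNetKwargs(sample=torch.randn(2, 4, 64, 64))  # (batch, channels, h, w) latent
_, _, h, w = unet_kwargs.sample.shape  # spatial size of the latent
device = unet_kwargs.sample.device     # device the sample lives on
dtype = unet_kwargs.sample.dtype       # dtype of the sample
print(h, w, device, dtype)             # 64 64 cpu torch.float32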
@@ -187,7 +193,7 @@ class TextConditioningData:
)
@staticmethod
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
@classmethod
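The _pad_zeros helper above simply appends a block of zeros along the chosen dimension. A minimal self-contained sketch of the same idea; the module-level name pad_zeros and the shapes are illustrative, not from the commit.

import torch

def pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
    # Append zeros of pad_shape to t along dim, matching its device and dtype.
    return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

cond = torch.randn(1, 77, 768)                        # (batch, tokens, embedding dim)
padded = pad_zeros(cond, pad_shape=(1, 23, 768), dim=1)
assert padded.shape == (1, 100, 768)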
@@ -195,8 +201,13 @@ class TextConditioningData:
cls,
cond: torch.Tensor,
target_len: int,
encoder_attention_mask: Optional[torch.Tensor],
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
Args:
cond (torch.Tensor): Conditioning tensor which to pads by zeros.
target_len (int): To which length(tokens count) pad tensor.
"""
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < target_len:
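The padding behaviour described by this docstring (and completed in the next hunk) can be sketched as a standalone function. The module-level name pad_conditioning and the shapes are illustrative only.

import torch

def pad_conditioning(cond: torch.Tensor, target_len: int):
    # Mask of ones for the real tokens; extended with zeros for padded positions.
    mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
    pad = target_len - cond.shape[1]
    if pad > 0:
        cond = torch.cat([cond, torch.zeros((cond.shape[0], pad, cond.shape[2]), device=cond.device, dtype=cond.dtype)], dim=1)
        mask = torch.cat([mask, torch.zeros((mask.shape[0], pad), device=mask.device, dtype=mask.dtype)], dim=1)
    return cond, mask

cond, mask = pad_conditioning(torch.randn(1, 77, 768), target_len=100)
assert cond.shape == (1, 100, 768) and mask.shape == (1, 100)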
@@ -212,21 +223,28 @@ class TextConditioningData:
dim=1,
)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
return cond, encoder_attention_mask
return cond, conditioning_attention_mask
@classmethod
def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
def _concat_conditionings_for_batch(
cls,
conditionings: List[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Concatenate provided conditioning tensors to one batched tensor.
If tensors have different sizes then pad them by zeros and creates
encoder_attention_mask to exclude padding from attention.
Args:
conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
"""
encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
encoder_attention_masks = [None] * len(conditionings)
for i in range(len(conditionings)):
conditionings[i], encoder_attention_mask = cls._pad_conditioning(
conditionings[i], max_len, encoder_attention_mask
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
conditionings[i], max_len
)
encoder_attention_mask = torch.cat(encoder_attention_masks)
return torch.cat(conditionings), encoder_attention_mask
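Taken together, the refactored batching logic pads every conditioning to the longest one, collects a per-tensor mask, and concatenates both along the batch dimension. A minimal self-contained sketch of that flow; the module-level names and shapes are illustrative, not the project's API.

import torch
from typing import List, Optional, Tuple

def pad_conditioning(cond: torch.Tensor, target_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same sketch as above, repeated so this snippet runs on its own.
    mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
    pad = target_len - cond.shape[1]
    if pad > 0:
        cond = torch.cat([cond, torch.zeros((cond.shape[0], pad, cond.shape[2]), device=cond.device, dtype=cond.dtype)], dim=1)
        mask = torch.cat([mask, torch.zeros((mask.shape[0], pad), device=mask.device, dtype=mask.dtype)], dim=1)
    return cond, mask

def concat_conditionings_for_batch(conditionings: List[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    encoder_attention_mask = None
    max_len = max(c.shape[1] for c in conditionings)
    if any(c.shape[1] != max_len for c in conditionings):
        masks = [None] * len(conditionings)
        for i in range(len(conditionings)):
            conditionings[i], masks[i] = pad_conditioning(conditionings[i], max_len)
        encoder_attention_mask = torch.cat(masks)  # one mask row per conditioning
    return torch.cat(conditionings), encoder_attention_mask

batch, mask = concat_conditionings_for_batch([torch.randn(1, 77, 768), torch.randn(1, 100, 768)])
assert batch.shape == (2, 100, 768) and mask.shape == (2, 100)

Returning the mask per tensor, instead of threading a single accumulator through _pad_conditioning, is what the last hunk changes; the mask stays None when every conditioning already has the same length.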