Ruff format

This commit is contained in:
Sergey Borisov 2024-07-17 04:24:45 +03:00
parent 2c2ec8f0bc
commit 3f79467f7b

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
@ -231,7 +231,7 @@ class TextConditioningData:
conditionings: List[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Concatenate provided conditioning tensors to one batched tensor.
If tensors have different sizes then pad them by zeros and creates
If tensors have different sizes then pad them by zeros and creates
encoder_attention_mask to exclude padding from attention.
Args:
@ -242,9 +242,7 @@ class TextConditioningData:
if any(c.shape[1] != max_len for c in conditionings):
encoder_attention_masks = [None] * len(conditionings)
for i in range(len(conditionings)):
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
conditionings[i], max_len
)
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
encoder_attention_mask = torch.cat(encoder_attention_masks)
return torch.cat(conditionings), encoder_attention_mask