Ruff format

This commit is contained in:
Sergey Borisov 2024-07-17 04:24:45 +03:00
parent 2c2ec8f0bc
commit 3f79467f7b

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 import torch
@@ -231,7 +231,7 @@ class TextConditioningData:
         conditionings: List[torch.Tensor],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Concatenate provided conditioning tensors to one batched tensor.
         If tensors have different sizes then pad them by zeros and creates
         encoder_attention_mask to exclude padding from attention.
         Args:
@@ -242,9 +242,7 @@ class TextConditioningData:
         if any(c.shape[1] != max_len for c in conditionings):
             encoder_attention_masks = [None] * len(conditionings)
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
-                    conditionings[i], max_len
-                )
+                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
             encoder_attention_mask = torch.cat(encoder_attention_masks)
         return torch.cat(conditionings), encoder_attention_mask