Ruff format

This commit is contained in:
Sergey Borisov 2024-07-17 04:24:45 +03:00
parent 2c2ec8f0bc
commit 3f79467f7b

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
@ -242,9 +242,7 @@ class TextConditioningData:
if any(c.shape[1] != max_len for c in conditionings):
encoder_attention_masks = [None] * len(conditionings)
for i in range(len(conditionings)):
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
conditionings[i], max_len
)
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
encoder_attention_mask = torch.cat(encoder_attention_masks)
return torch.cat(conditionings), encoder_attention_mask