Ruff format

This commit is contained in:
Sergey Borisov 2024-07-17 04:24:45 +03:00
parent 2c2ec8f0bc
commit 3f79467f7b

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch import torch
@ -242,9 +242,7 @@ class TextConditioningData:
if any(c.shape[1] != max_len for c in conditionings): if any(c.shape[1] != max_len for c in conditionings):
encoder_attention_masks = [None] * len(conditionings) encoder_attention_masks = [None] * len(conditionings)
for i in range(len(conditionings)): for i in range(len(conditionings)):
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning( conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
conditionings[i], max_len
)
encoder_attention_mask = torch.cat(encoder_attention_masks) encoder_attention_mask = torch.cat(encoder_attention_masks)
return torch.cat(conditionings), encoder_attention_mask return torch.cat(conditionings), encoder_attention_mask