Move out _concat_conditionings_for_batch submethods

This commit is contained in:
Sergey Borisov 2024-07-17 03:31:26 +03:00
parent cd1bc1595a
commit ae6d4fbc78

View File

@ -172,40 +172,46 @@ class TextConditioningData:
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
)
def _concat_conditionings_for_batch(self, conditionings):
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
@staticmethod
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
@classmethod
def _pad_conditioning(
cls,
cond: torch.Tensor,
target_len: int,
encoder_attention_mask: Optional[torch.Tensor],
):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < target_len:
conditioning_attention_mask = cls._pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], target_len - cond.shape[1]),
dim=1,
)
if cond.shape[1] < max_len:
conditioning_attention_mask = _pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], max_len - cond.shape[1]),
dim=1,
)
cond = cls._pad_zeros(
cond,
pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
dim=1,
)
cond = _pad_zeros(
cond,
pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
dim=1,
)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
return cond, encoder_attention_mask
return cond, encoder_attention_mask
@classmethod
def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
for i in range(len(conditionings)):
conditionings[i], encoder_attention_mask = _pad_conditioning(
conditionings[i], encoder_attention_mask = cls._pad_conditioning(
conditionings[i], max_len, encoder_attention_mask
)