Move nested helper functions out of _concat_conditionings_for_batch

This commit is contained in:
Sergey Borisov 2024-07-17 03:31:26 +03:00
parent cd1bc1595a
commit ae6d4fbc78

View File

@ -172,40 +172,46 @@ class TextConditioningData:
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype), regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
) )
def _concat_conditionings_for_batch(self, conditionings): @staticmethod
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int): def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim) return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
@classmethod
def _pad_conditioning(
    cls,
    cond: torch.Tensor,
    target_len: int,
    encoder_attention_mask: Optional[torch.Tensor],
):
    """Zero-pad ``cond`` along dim 1 up to ``target_len`` and extend the attention mask.

    A mask of ones with shape ``(cond.shape[0], cond.shape[1])`` is built for
    the existing positions; any positions added by padding are marked with
    zeros. The mask uses the dtype and device of ``cond``.

    Args:
        cond: Conditioning tensor. The padding branch reads ``cond.shape[2]``,
            so it must be rank 3 whenever padding is needed
            (presumably (batch, tokens, embedding) - TODO confirm).
        target_len: Desired size of dim 1. If ``cond`` is already at least
            this long, it is returned unchanged.
        encoder_attention_mask: Mask accumulated from earlier calls, or
            ``None`` on the first call.

    Returns:
        ``(cond, encoder_attention_mask)``: the (possibly padded) conditioning
        and the accumulated mask with this conditioning's mask appended
        along dim 0.
    """
    conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)

    pad_len = target_len - cond.shape[1]
    if pad_len > 0:
        # Pad the mask first: both pad shapes are derived from the still
        # unpadded ``cond``.
        conditioning_attention_mask = cls._pad_zeros(
            conditioning_attention_mask,
            pad_shape=(cond.shape[0], pad_len),
            dim=1,
        )
        cond = cls._pad_zeros(
            cond,
            pad_shape=(cond.shape[0], pad_len, cond.shape[2]),
            dim=1,
        )

    if encoder_attention_mask is None:
        encoder_attention_mask = conditioning_attention_mask
    else:
        # Masks are stacked along dim 0 (torch.cat's default, spelled out here).
        encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask], dim=0)

    return cond, encoder_attention_mask
@classmethod
def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
encoder_attention_mask = None encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings]) max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings): if any(c.shape[1] != max_len for c in conditionings):
for i in range(len(conditionings)): for i in range(len(conditionings)):
conditionings[i], encoder_attention_mask = _pad_conditioning( conditionings[i], encoder_attention_mask = cls._pad_conditioning(
conditionings[i], max_len, encoder_attention_mask conditionings[i], max_len, encoder_attention_mask
) )