Move out _concat_conditionings_for_batch submethods

Sergey Borisov 2024-07-17 03:31:26 +03:00
parent cd1bc1595a
commit ae6d4fbc78


@@ -172,25 +172,29 @@ class TextConditioningData:
             regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
         )
 
-    def _concat_conditionings_for_batch(self, conditionings):
-        def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
-            return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+    @staticmethod
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
 
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones(
-                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
-            )
+    @classmethod
+    def _pad_conditioning(
+        cls,
+        cond: torch.Tensor,
+        target_len: int,
+        encoder_attention_mask: Optional[torch.Tensor],
+    ):
+        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
 
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = _pad_zeros(
-                    conditioning_attention_mask,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
-                    dim=1,
-                )
+        if cond.shape[1] < target_len:
+            conditioning_attention_mask = cls._pad_zeros(
+                conditioning_attention_mask,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
+                dim=1,
+            )
 
-                cond = _pad_zeros(
-                    cond,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                    dim=1,
-                )
+            cond = cls._pad_zeros(
+                cond,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
+                dim=1,
+            )
@@ -201,11 +205,13 @@ class TextConditioningData:
 
-            return cond, encoder_attention_mask
+        return cond, encoder_attention_mask
 
+    @classmethod
+    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
         encoder_attention_mask = None
         max_len = max([c.shape[1] for c in conditionings])
         if any(c.shape[1] != max_len for c in conditionings):
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                conditionings[i], encoder_attention_mask = cls._pad_conditioning(
                     conditionings[i], max_len, encoder_attention_mask
                 )
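
For context, here is a minimal standalone sketch of the padding behaviour these helpers implement: conditionings with different sequence lengths are zero-padded to the batch maximum so they can be concatenated, with an attention mask marking the padded positions. This is plain PyTorch, not the InvokeAI TextConditioningData class; the function name pad_zeros and the example shapes are illustrative only.

import torch

def pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
    # Append a block of zeros of shape `pad_shape` along dimension `dim`.
    return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

# Two conditioning tensors with different sequence lengths: (batch, seq_len, embed_dim).
a = torch.ones(1, 5, 4)
b = torch.ones(1, 8, 4)
max_len = max(c.shape[1] for c in (a, b))  # 8

# Build the attention mask first (1 = real token, 0 = padding),
# then pad the shorter tensor up to max_len along the sequence dimension.
mask = pad_zeros(torch.ones(a.shape[0], a.shape[1]), pad_shape=(a.shape[0], max_len - a.shape[1]), dim=1)
a = pad_zeros(a, pad_shape=(a.shape[0], max_len - a.shape[1], a.shape[2]), dim=1)

print(a.shape)  # torch.Size([1, 8, 4])
print(mask)     # tensor([[1., 1., 1., 1., 1., 0., 0., 0.]])
batch = torch.cat([a, b])  # equal lengths, so concatenation succeeds: torch.Size([2, 8, 4])

The refactor itself does not change this logic; it only lifts _pad_zeros and _pad_conditioning out of _concat_conditionings_for_batch into staticmethod/classmethod helpers, passing the target length explicitly instead of closing over max_len.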