Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Move out _concat_conditionings_for_batch submethods
parent cd1bc1595a
commit ae6d4fbc78
@@ -172,25 +172,29 @@ class TextConditioningData:
             regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
         )
 
-    def _concat_conditionings_for_batch(self, conditionings):
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones(
-                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
-            )
+    @staticmethod
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+
+    @classmethod
+    def _pad_conditioning(
+        cls,
+        cond: torch.Tensor,
+        target_len: int,
+        encoder_attention_mask: Optional[torch.Tensor],
+    ):
+        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
 
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = _pad_zeros(
-                    conditioning_attention_mask,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
-                    dim=1,
-                )
+        if cond.shape[1] < target_len:
+            conditioning_attention_mask = cls._pad_zeros(
+                conditioning_attention_mask,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
+                dim=1,
+            )
 
-                cond = _pad_zeros(
-                    cond,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                    dim=1,
-                )
+            cond = cls._pad_zeros(
+                cond,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
+                dim=1,
+            )
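The first hunk extracts the zero-padding logic into a `_pad_zeros` staticmethod and promotes the nested `_pad_conditioning` closure to a typed classmethod; note the condition now uses the `target_len` parameter rather than `max_len` from the enclosing scope. A minimal standalone sketch (not InvokeAI code) of what the extracted helper does, with assumed example shapes:

# Minimal sketch (not InvokeAI code) of the extracted helper: right-pad a tensor
# with zeros along one dimension, preserving device and dtype.
import torch


def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
    return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)


# Pad a (batch=2, tokens=5, features=8) conditioning from 5 to 8 tokens.
cond = torch.ones(2, 5, 8)
padded = _pad_zeros(cond, pad_shape=(2, 3, 8), dim=1)
assert padded.shape == (2, 8, 8)
assert torch.all(padded[:, 5:, :] == 0)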
@@ -201,11 +205,13 @@ class TextConditioningData:
 
         return cond, encoder_attention_mask
 
+    @classmethod
+    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
         encoder_attention_mask = None
         max_len = max([c.shape[1] for c in conditionings])
         if any(c.shape[1] != max_len for c in conditionings):
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                conditionings[i], encoder_attention_mask = cls._pad_conditioning(
                     conditionings[i], max_len, encoder_attention_mask
                 )
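Taken together, the refactor leaves `_concat_conditionings_for_batch` as a thin classmethod over the two helpers. A self-contained sketch of the overall flow follows; the mask accumulation and the final batch concatenation are elided by the diff above, so those parts are labeled assumptions, not the InvokeAI implementation:

# Self-contained sketch of the refactored flow; shapes and the parts the diff
# elides (mask accumulation, final concatenation) are assumptions, not InvokeAI code.
from typing import List, Optional, Tuple

import torch


def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
    # Right-pad `t` with a zero block of `pad_shape` along `dim`.
    return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)


def _pad_conditioning(
    cond: torch.Tensor, target_len: int, encoder_attention_mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Ones mark real tokens; zeros are appended below for padded positions.
    mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
    if cond.shape[1] < target_len:
        mask = _pad_zeros(mask, pad_shape=(cond.shape[0], target_len - cond.shape[1]), dim=1)
        cond = _pad_zeros(cond, pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]), dim=1)
    # Assumed accumulation: stack each prompt's mask along the batch dimension.
    mask = mask if encoder_attention_mask is None else torch.cat([encoder_attention_mask, mask])
    return cond, mask


def _concat_conditionings_for_batch(conditionings: List[torch.Tensor]):
    encoder_attention_mask = None
    max_len = max(c.shape[1] for c in conditionings)
    if any(c.shape[1] != max_len for c in conditionings):
        for i in range(len(conditionings)):
            conditionings[i], encoder_attention_mask = _pad_conditioning(
                conditionings[i], max_len, encoder_attention_mask
            )
    # Assumed final step: concatenate the padded conditionings along the batch dim.
    return torch.cat(conditionings), encoder_attention_mask


# Two prompts of 5 and 8 tokens are padded to 8 and batched; the mask zeros out padding.
cond, mask = _concat_conditionings_for_batch([torch.ones(1, 5, 8), torch.ones(1, 8, 8)])
assert cond.shape == (2, 8, 8) and mask.shape == (2, 8)
assert mask[0, 5:].sum() == 0 and mask[0, :5].sum() == 5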