Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Move out _concat_conditionings_for_batch submethods
parent cd1bc1595a
commit ae6d4fbc78
@@ -172,40 +172,46 @@ class TextConditioningData:
                 regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
             )

-    def _concat_conditionings_for_batch(self, conditionings):
-        def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
-            return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
-
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones(
-                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
-            )
-
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = _pad_zeros(
-                    conditioning_attention_mask,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
-                    dim=1,
-                )
-
-                cond = _pad_zeros(
-                    cond,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                    dim=1,
-                )
-
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
-
-            return cond, encoder_attention_mask
-
+    @staticmethod
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+
+    @classmethod
+    def _pad_conditioning(
+        cls,
+        cond: torch.Tensor,
+        target_len: int,
+        encoder_attention_mask: Optional[torch.Tensor],
+    ):
+        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+        if cond.shape[1] < target_len:
+            conditioning_attention_mask = cls._pad_zeros(
+                conditioning_attention_mask,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
+                dim=1,
+            )
+
+            cond = cls._pad_zeros(
+                cond,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
+                dim=1,
+            )
+
+        if encoder_attention_mask is None:
+            encoder_attention_mask = conditioning_attention_mask
+        else:
+            encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
+
+        return cond, encoder_attention_mask
+
+    @classmethod
+    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
         encoder_attention_mask = None
         max_len = max([c.shape[1] for c in conditionings])
         if any(c.shape[1] != max_len for c in conditionings):
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                conditionings[i], encoder_attention_mask = cls._pad_conditioning(
                     conditionings[i], max_len, encoder_attention_mask
                 )

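For readers skimming the diff, here is a minimal, self-contained sketch of the padding behaviour these helpers implement. The free functions pad_zeros and pad_conditioning, along with the example tensor shapes, are illustrative stand-ins written for this note, not the repository's TextConditioningData methods; they mirror the logic on the `+` side of the hunk above:

# Illustrative stand-ins (not the real TextConditioningData methods); example shapes are arbitrary.
import torch


def pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
    # Same idea as the new _pad_zeros staticmethod: append a block of zeros along `dim`.
    return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)


def pad_conditioning(cond: torch.Tensor, target_len: int):
    # Same idea as the new _pad_conditioning classmethod, minus the running mask argument:
    # build a ones-mask over the real tokens, then zero-pad both mask and embedding to target_len.
    mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
    if cond.shape[1] < target_len:
        mask = pad_zeros(mask, pad_shape=(cond.shape[0], target_len - cond.shape[1]), dim=1)
        cond = pad_zeros(cond, pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]), dim=1)
    return cond, mask


# Two example conditionings with different token lengths.
short_cond = torch.randn(1, 77, 768)
long_cond = torch.randn(1, 154, 768)
max_len = max(short_cond.shape[1], long_cond.shape[1])

padded_short, mask_short = pad_conditioning(short_cond, max_len)
padded_long, mask_long = pad_conditioning(long_cond, max_len)

batch = torch.cat([padded_short, padded_long])               # shape (2, 154, 768)
encoder_attention_mask = torch.cat([mask_short, mask_long])  # shape (2, 154)
print(batch.shape, encoder_attention_mask.shape)

In the repository's version the mask accumulation across the batch happens inside _pad_conditioning via the encoder_attention_mask argument, exactly as shown in the hunk; the sketch keeps the two concerns separate only for readability.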