Mirror of https://github.com/invoke-ai/InvokeAI
Comments, a bit of refactoring
@@ -137,6 +137,12 @@ class TextConditioningData:
         return isinstance(self.cond_text, SDXLConditioningInfo)
 
     def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
+        """Fills unet arguments with data from provided conditionings.
+
+        Args:
+            unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
+            conditioning_mode (ConditioningMode): Describes which conditionings should be used.
+        """
         _, _, h, w = unet_kwargs.sample.shape
         device = unet_kwargs.sample.device
         dtype = unet_kwargs.sample.dtype
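For context on the three unchanged lines at the bottom of this hunk: they read the latent height/width plus the device and dtype off the noised sample. A minimal sketch of that access pattern, using a hypothetical stand-in dataclass (the real UNetKwargs is defined elsewhere in InvokeAI and has more fields):

from dataclasses import dataclass

import torch


@dataclass
class FakeUNetKwargs:
    # Hypothetical stand-in; only the `sample` field used by to_unet_kwargs is modeled here.
    sample: torch.Tensor  # noised latents, shape (batch, channels, height, width)


unet_kwargs = FakeUNetKwargs(sample=torch.randn(2, 4, 64, 64))

# Same unpacking as in the context lines above: latent size plus device/dtype for new tensors.
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype
print(h, w, device, dtype)  # 64 64 cpu torch.float32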
@@ -187,7 +193,7 @@ class TextConditioningData:
         )
 
     @staticmethod
-    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
         return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
 
     @classmethod
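_pad_zeros itself is unchanged apart from the new return annotation. A quick usage sketch of the torch.cat pattern it wraps; the 77/154-token and 768-dim shapes below are illustrative assumptions, not values taken from this diff:

import torch


def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
    # Append a block of zeros of `pad_shape` to `t` along `dim`, matching device and dtype.
    return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)


cond = torch.randn(2, 77, 768)  # (batch, tokens, embedding dim) -- illustrative shape
padded = _pad_zeros(cond, pad_shape=(2, 154 - 77, 768), dim=1)
print(padded.shape)  # torch.Size([2, 154, 768])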
@@ -195,8 +201,13 @@ class TextConditioningData:
         cls,
         cond: torch.Tensor,
         target_len: int,
-        encoder_attention_mask: Optional[torch.Tensor],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
+
+        Args:
+            cond (torch.Tensor): Conditioning tensor which to pads by zeros.
+            target_len (int): To which length(tokens count) pad tensor.
+        """
         conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
 
         if cond.shape[1] < target_len:
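The new signature drops the threaded encoder_attention_mask parameter: _pad_conditioning now returns the padded tensor together with its own mask of real tokens. A standalone sketch of that contract, under the assumption that padding is zero-fill along the token dimension (the full padding body is not visible in this hunk):

from typing import Tuple

import torch


def pad_conditioning(cond: torch.Tensor, target_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Ones over the real (unpadded) tokens of this conditioning.
    mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
    if cond.shape[1] < target_len:
        pad = target_len - cond.shape[1]
        # Zero-fill the embeddings and the mask along the token dimension (assumed behavior).
        cond = torch.cat(
            [cond, torch.zeros((cond.shape[0], pad, cond.shape[2]), device=cond.device, dtype=cond.dtype)],
            dim=1,
        )
        mask = torch.cat(
            [mask, torch.zeros((cond.shape[0], pad), device=cond.device, dtype=cond.dtype)],
            dim=1,
        )
    return cond, mask


cond, mask = pad_conditioning(torch.randn(1, 77, 768), target_len=154)
print(cond.shape, mask.shape)  # torch.Size([1, 154, 768]) torch.Size([1, 154])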
@@ -212,21 +223,28 @@ class TextConditioningData:
                 dim=1,
             )
 
-        if encoder_attention_mask is None:
-            encoder_attention_mask = conditioning_attention_mask
-        else:
-            encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
-
-        return cond, encoder_attention_mask
+        return cond, conditioning_attention_mask
 
     @classmethod
-    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
+    def _concat_conditionings_for_batch(
+        cls,
+        conditionings: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Concatenate provided conditioning tensors to one batched tensor.
+        If tensors have different sizes then pad them by zeros and creates
+        encoder_attention_mask to exclude padding from attention.
+
+        Args:
+            conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
+        """
         encoder_attention_mask = None
         max_len = max([c.shape[1] for c in conditionings])
         if any(c.shape[1] != max_len for c in conditionings):
+            encoder_attention_masks = [None] * len(conditionings)
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_mask = cls._pad_conditioning(
-                    conditionings[i], max_len, encoder_attention_mask
+                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
+                    conditionings[i], max_len
                 )
+            encoder_attention_mask = torch.cat(encoder_attention_masks)
 
         return torch.cat(conditionings), encoder_attention_mask
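After the refactor, each conditioning is padded independently, the per-conditioning masks are collected into encoder_attention_masks, and a batch mask is built with torch.cat only when lengths actually differ. A self-contained sketch that mirrors this new flow, with made-up shapes for illustration:

from typing import List, Optional, Tuple

import torch


def concat_conditionings_for_batch(
    conditionings: List[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Pad every conditioning to the longest token length and build a batch-wide
    # attention mask only when the lengths actually differ (mirrors the diff above).
    encoder_attention_mask = None
    max_len = max(c.shape[1] for c in conditionings)
    if any(c.shape[1] != max_len for c in conditionings):
        masks = []
        for i, c in enumerate(conditionings):
            pad = max_len - c.shape[1]
            mask = torch.ones((c.shape[0], c.shape[1]), device=c.device, dtype=c.dtype)
            if pad > 0:
                c = torch.cat([c, torch.zeros((c.shape[0], pad, c.shape[2]), device=c.device, dtype=c.dtype)], dim=1)
                mask = torch.cat([mask, torch.zeros((c.shape[0], pad), device=c.device, dtype=c.dtype)], dim=1)
            conditionings[i] = c
            masks.append(mask)
        encoder_attention_mask = torch.cat(masks)  # concatenated along the batch dimension
    return torch.cat(conditionings), encoder_attention_mask


# Example: a 77-token prompt batched with a 154-token prompt (illustrative shapes).
short = torch.randn(1, 77, 768)
long = torch.randn(1, 154, 768)
batched, mask = concat_conditionings_for_batch([short, long])
print(batched.shape, mask.shape)  # torch.Size([2, 154, 768]) torch.Size([2, 154])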