diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 21fb8d5780..80b671df65 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -172,40 +172,46 @@ class TextConditioningData:
             regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
         )
 
-    def _concat_conditionings_for_batch(self, conditionings):
-        def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
-            return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+    @staticmethod
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
 
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones(
-                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
+    @classmethod
+    def _pad_conditioning(
+        cls,
+        cond: torch.Tensor,
+        target_len: int,
+        encoder_attention_mask: Optional[torch.Tensor],
+    ):
+        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+        if cond.shape[1] < target_len:
+            conditioning_attention_mask = cls._pad_zeros(
+                conditioning_attention_mask,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
+                dim=1,
             )
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = _pad_zeros(
-                    conditioning_attention_mask,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
-                    dim=1,
-                )
+            cond = cls._pad_zeros(
+                cond,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
+                dim=1,
+            )
 
-                cond = _pad_zeros(
-                    cond,
-                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                    dim=1,
-                )
+        if encoder_attention_mask is None:
+            encoder_attention_mask = conditioning_attention_mask
+        else:
+            encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
 
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
-
-            return cond, encoder_attention_mask
+        return cond, encoder_attention_mask
 
+    @classmethod
+    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
         encoder_attention_mask = None
         max_len = max([c.shape[1] for c in conditionings])
         if any(c.shape[1] != max_len for c in conditionings):
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                conditionings[i], encoder_attention_mask = cls._pad_conditioning(
                     conditionings[i], max_len, encoder_attention_mask
                 )
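
For context, a minimal standalone sketch (not part of the diff) of the padding behaviour these helpers implement: each conditioning tensor is zero-padded along the sequence dimension up to the longest length in the batch, and a matching attention mask (1 for real tokens, 0 for padding) is concatenated along the batch dimension, mirroring the `torch.cat([encoder_attention_mask, conditioning_attention_mask])` call above. The tensor shapes and variable names below are hypothetical, chosen only to illustrate the technique.

```python
# Sketch of the zero-padding + attention-mask logic (hypothetical shapes).
import torch

# Two conditionings with different sequence lengths, e.g. one vs. two prompt chunks.
conditionings = [torch.randn(1, 77, 768), torch.randn(1, 154, 768)]
max_len = max(c.shape[1] for c in conditionings)

padded, masks = [], []
for cond in conditionings:
    # Real tokens are marked with 1; padding positions are marked with 0.
    mask = torch.ones(cond.shape[0], cond.shape[1], dtype=cond.dtype)
    if cond.shape[1] < max_len:
        pad = max_len - cond.shape[1]
        cond = torch.cat([cond, torch.zeros(cond.shape[0], pad, cond.shape[2])], dim=1)
        mask = torch.cat([mask, torch.zeros(cond.shape[0], pad)], dim=1)
    padded.append(cond)
    masks.append(mask)

batch = torch.cat(padded, dim=0)           # (2, 154, 768)
encoder_attention_mask = torch.cat(masks)  # (2, 154), concatenated along dim 0 as in the diff
print(batch.shape, encoder_attention_mask.shape)
```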