Fix - encoder_attention_mask not passed before to unet, even if passed it will broke sequential guidance run, so rewrite logic

This commit is contained in:
Sergey Borisov
2023-07-17 23:13:37 +03:00
parent 1d3fda80aa
commit 1c680a7147
2 changed files with 44 additions and 41 deletions

View File

@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
)
return latents
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x