Mirror of https://github.com/invoke-ai/InvokeAI

Pad conditionings using zeros and encoder_attention_mask

parent 565299c7a1
commit 7093e5d033
@@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=True,  # TODO:
+            truncate_long_prompts=False,
         )
 
         conjunction = Compel.parse_prompt_string(self.prompt)
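Note on the change above: with truncate_long_prompts=False, compel stops clipping prompts to the text encoder's single 77-token window and (as I understand compel's behavior) returns conditioning tensors that can span multiple windows. Positive and negative prompts are encoded independently, so their sequence lengths can now differ; the remaining hunks, and the sketch after the last one, handle that mismatch.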
@@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation):
         c, options = compel.build_conditioning_tensor_for_prompt_object(
             prompt)
 
-        # TODO: long prompt support
-        # if not self.truncate_long_prompts:
-        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
             tokens_count_including_eos_bos=get_max_token_count(
                 tokenizer, conjunction),
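The commented-out fallback removed here would have equalized lengths with compel's pad_conditioning_tensors_to_same_length. The commit drops that idea and instead pads at diffusion time (next hunk), which also yields the encoder_attention_mask needed to tell the UNet to ignore the padded positions.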
@@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent:
 
     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         # fast batched path
+
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < max_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
         both_conditionings = torch.cat([unconditioning, conditioning])
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
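Below is a minimal, self-contained sketch of the padding scheme the last hunk introduces, using plain PyTorch and illustrative shapes (the 77/154/768 sizes and the pad_conditioning name are assumptions for the demo, not InvokeAI's API). One detail worth noting: in the commit's helper, the target_len parameter goes unused because the body closes over the enclosing max_len; the sketch uses its parameter directly.

import torch

def pad_conditioning(cond: torch.Tensor, max_len: int):
    # Zero-pad cond ([batch, seq, dim]) along the sequence axis to max_len
    # and build a matching mask (1.0 = real token, 0.0 = padding).
    batch, seq_len, dim = cond.shape
    mask = torch.ones((batch, seq_len), device=cond.device, dtype=cond.dtype)
    if seq_len < max_len:
        pad = max_len - seq_len
        mask = torch.cat([mask, torch.zeros((batch, pad), device=cond.device, dtype=cond.dtype)], dim=1)
        cond = torch.cat([cond, torch.zeros((batch, pad, dim), device=cond.device, dtype=cond.dtype)], dim=1)
    return cond, mask

# Illustrative shapes: a one-window (77-token) negative prompt and a
# two-window (154-token) positive prompt, with 768-dim embeddings.
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 154, 768)
max_len = max(uncond.shape[1], cond.shape[1])

uncond, uncond_mask = pad_conditioning(uncond, max_len)
cond, cond_mask = pad_conditioning(cond, max_len)

# As in the hunk, both halves are stacked along the batch axis for one
# batched UNet call; the masks are stacked the same way (dim=0).
both_conditionings = torch.cat([uncond, cond])                # [2, 154, 768]
encoder_attention_mask = torch.cat([uncond_mask, cond_mask])  # [2, 154]
print(both_conditionings.shape, encoder_attention_mask.shape)

Downstream, a mask like this is what diffusers-style UNets accept via the encoder_attention_mask argument to ignore padded cross-attention positions; the commit threads it through model_forward_callback for exactly that purpose.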
|