diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 303e0a0c84..a5a9701149 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation): text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + truncate_long_prompts=False, ) conjunction = Compel.parse_prompt_string(self.prompt) @@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation): c, options = compel.build_conditioning_tensor_for_prompt_object( prompt) - # TODO: long prompt support - # if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( tokens_count_including_eos_bos=get_max_token_count( tokenizer, conjunction), diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 1175475bba..307e949ef8 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent: def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): # fast batched path + + def _pad_conditioning(cond, target_len, encoder_attention_mask): + conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) + + if cond.shape[1] < max_len: + conditioning_attention_mask = torch.cat([ + conditioning_attention_mask, + torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), + ], dim=1) + + cond = torch.cat([ + cond, + torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype), + ], dim=1) + + if encoder_attention_mask is None: + encoder_attention_mask = conditioning_attention_mask + else: + encoder_attention_mask = torch.cat([ + encoder_attention_mask, + conditioning_attention_mask, + ]) + + return cond, encoder_attention_mask + x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) + + encoder_attention_mask = None + if unconditioning.shape[1] != conditioning.shape[1]: + max_len = max(unconditioning.shape[1], conditioning.shape[1]) + unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) + conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) + both_conditionings = torch.cat([unconditioning, conditioning]) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings, **kwargs, + x_twice, sigma_twice, both_conditionings, + encoder_attention_mask=encoder_attention_mask, + **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) return unconditioned_next_x, conditioned_next_x