From 1c680a71473b6b7ca939c28e24c26955be94d3c9 Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Mon, 17 Jul 2023 23:13:37 +0300
Subject: [PATCH] Fix: encoder_attention_mask was not passed to the unet
 before, and even when passed it would break sequential guidance runs, so
 rewrite the logic

---
 .../stable_diffusion/diffusers_pipeline.py   | 44 +++----------------
 .../diffusion/shared_invokeai_diffusion.py   | 41 ++++++++++++++++-
 2 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 6c170323cf..120bfb9663 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -507,40 +507,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         **kwargs,
     ):
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
-
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat([
-                    conditioning_attention_mask,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-                cond = torch.cat([
-                    cond,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([
-                    encoder_attention_mask,
-                    conditioning_attention_mask,
-                ])
-
-            return cond, encoder_attention_mask
-
-        encoder_attention_mask = None
-        if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
-            max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
-            conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
-            )
-            conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.text_embeddings, max_len, encoder_attention_mask
-            )
-
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -580,7 +546,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     total_step_count=len(timesteps),
                     additional_guidance=additional_guidance,
                     control_data=control_data,
-                    encoder_attention_mask=encoder_attention_mask,
                     **kwargs,
                 )
                 latents = step_output.prev_sample
@@ -638,8 +603,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         down_block_res_samples, mid_block_res_sample = None, None
         if control_data is not None:
-            # TODO: rewrite to pass with conditionings
-            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
             # and MultiControlNet (multiple ControlNetData in list)
@@ -669,9 +632,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                     encoder_hidden_states = conditioning_data.text_embeddings
+                    encoder_attention_mask = None
                 else:
-                    encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                       conditioning_data.text_embeddings])
+                    encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings,
+                        conditioning_data.text_embeddings,
+                    )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
                     controlnet_weight = control_datum.weight[step_index]
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 3fb2df8ce1..b637ceb815 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
         )
         return latents
 
+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < target_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], target_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
@@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
-        both_conditionings = torch.cat([unconditioning, conditioning])
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            unconditioning, conditioning
+        )
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
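
Note for reviewers: below is a minimal standalone sketch of the padding-and-mask
logic that the new _concat_conditionings_for_batch helper implements. This is an
illustration, not the patched code itself; the function name concat_conditionings
and the 77/154-token shapes are assumptions chosen for the example. The idea is to
zero-pad the shorter conditioning tensor to the longer length, build a matching
encoder_attention_mask (1 = real token, 0 = padding), and concatenate both
conditionings along the batch axis for a single batched unet call.

import torch

def concat_conditionings(uncond: torch.Tensor, cond: torch.Tensor):
    # Tensors are (batch, seq_len, embed_dim), as produced by the text encoder.
    mask = None
    if uncond.shape[1] != cond.shape[1]:
        max_len = max(uncond.shape[1], cond.shape[1])
        padded, masks = [], []
        for c in (uncond, cond):
            # Start with a mask of ones covering the real tokens.
            m = torch.ones(c.shape[0], c.shape[1], device=c.device, dtype=c.dtype)
            if c.shape[1] < max_len:
                pad = max_len - c.shape[1]
                # Extend the mask with zeros and the embeddings with zero vectors.
                m = torch.cat([m, torch.zeros(c.shape[0], pad, device=c.device, dtype=c.dtype)], dim=1)
                c = torch.cat([c, torch.zeros(c.shape[0], pad, c.shape[2], device=c.device, dtype=c.dtype)], dim=1)
            padded.append(c)
            masks.append(m)
        uncond, cond = padded
        # Concatenate along the batch axis, matching the batched embeddings below.
        mask = torch.cat(masks)
    return torch.cat([uncond, cond]), mask

# Example: a 77-token negative prompt batched with a 154-token (two-chunk) prompt.
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 154, 768)
batched, mask = concat_conditionings(uncond, cond)
assert batched.shape == (2, 154, 768) and mask.shape == (2, 154)

Moving this logic into InvokeAIDiffuserComponent is what fixes the sequential
guidance path: the mask is now built only where conditionings are actually
batched (_apply_standard_conditioning and the ControlNet branch above), so the
sequential path, which runs the unconditioned and conditioned passes separately
and never mixes sequence lengths, is no longer handed a mask that was shaped
for a batched call.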