diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 6c170323cf..228fbd0585 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
-        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             run_id=run_id,
             callback=callback,
-            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device('cpu')
@@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps,
             conditioning_data,
             noise=noise,
-            additional_guidance=additional_guidance,
             run_id=run_id,
-            callback=callback,
+            additional_guidance=additional_guidance,
             control_data=control_data,
-            **kwargs,
+
+            callback=callback,
         )
         return result.latents, result.attention_map_saver
@@ -505,42 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
-
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat([
-                    conditioning_attention_mask,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-                cond = torch.cat([
-                    cond,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([
-                    encoder_attention_mask,
-                    conditioning_attention_mask,
-                ])
-
-            return cond, encoder_attention_mask
-
-        encoder_attention_mask = None
-        if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
-            max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
-            conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
-            )
-            conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.text_embeddings, max_len, encoder_attention_mask
-            )
-
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -580,8 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 total_step_count=len(timesteps),
                 additional_guidance=additional_guidance,
                 control_data=control_data,
-                encoder_attention_mask=encoder_attention_mask,
-                **kwargs,
             )
             latents = step_output.prev_sample
@@ -623,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         total_step_count: int,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -638,8 +597,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         down_block_res_samples, mid_block_res_sample = None, None
         if control_data is not None:
-            # TODO: rewrite to pass with conditionings
-            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
             # and MultiControlNet (multiple ControlNetData in list)
@@ -669,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                     encoder_hidden_states = conditioning_data.text_embeddings
+                    encoder_attention_mask = None
                 else:
-                    encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                       conditioning_data.text_embeddings])
+                    encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings,
+                        conditioning_data.text_embeddings,
+                    )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
                     controlnet_weight = control_datum.weight[step_index]
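
The pipeline-side changes above drop the `**kwargs` plumbing that smuggled `encoder_attention_mask` through every call layer, and the pad-and-mask helper that previously lived inline in `latents_from_embeddings` moves into `InvokeAIDiffuserComponent._concat_conditionings_for_batch` (added in the second file below), which the ControlNet branch now calls to get both the batched `encoder_hidden_states` and the matching `encoder_attention_mask`. The helper exists because, with long-prompt handling, the unconditioned and conditioned embeddings can have different token lengths (e.g. one 77-token CLIP chunk vs. two chunks, 154 tokens), so a plain `torch.cat` along the batch dimension fails. The sketch below is a minimal standalone illustration of the pad-and-mask idea, not InvokeAI code; names and shapes are made up.

import torch

# Standalone illustration, not InvokeAI code; shapes are made up.
# uncond: one CLIP chunk (77 tokens); cond: two chunks (154 tokens).
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 154, 768)

# torch.cat([uncond, cond]) would fail: non-batch dims must match.
max_len = max(uncond.shape[1], cond.shape[1])

def pad_with_mask(emb: torch.Tensor, target_len: int):
    # Mask is 1 for real tokens, 0 for the zero-padding appended below.
    mask = torch.ones(emb.shape[0], emb.shape[1], dtype=emb.dtype)
    pad = target_len - emb.shape[1]
    if pad > 0:
        emb = torch.cat(
            [emb, torch.zeros(emb.shape[0], pad, emb.shape[2], dtype=emb.dtype)], dim=1
        )
        mask = torch.cat([mask, torch.zeros(emb.shape[0], pad, dtype=emb.dtype)], dim=1)
    return emb, mask

uncond, uncond_mask = pad_with_mask(uncond, max_len)
cond, cond_mask = pad_with_mask(cond, max_len)

encoder_hidden_states = torch.cat([uncond, cond])             # (2, 154, 768)
encoder_attention_mask = torch.cat([uncond_mask, cond_mask])  # (2, 154)
assert encoder_hidden_states.shape == (2, 154, 768)
assert encoder_attention_mask.shape == (2, 154)

Padding with zeros alone would still let the padded positions attend in cross-attention; the mask is what tells the model to ignore them.
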
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 3fb2df8ce1..f44578cd47 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
         )
         return latents

+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < max_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+
     # methods below are called from do_diffusion_step and should be considered private to this class.

     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
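
Why the mask matters, and not just the zero-padding: zero key vectors still produce logits of 0, and softmax of 0 is not 0, so without a mask the padded tokens dilute cross-attention over the real ones. (A review-level quirk worth noting as well: the inner `_pad_conditioning` never reads its `target_len` parameter and instead closes over `max_len` from the enclosing scope, which is safe only because it is called inside the `if` branch that defines `max_len`.) The self-contained check below uses plain `torch` with made-up shapes and standard scaled-dot-product attention rather than the UNet's cross-attention; it shows padding without a mask changing the result, while masking the padded keys restores it:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 64)         # a few latent-pixel queries
k_real = torch.randn(1, 77, 64)   # real conditioning tokens (made-up sizes)
k_padded = torch.cat([k_real, torch.zeros(1, 77, 64)], dim=1)  # padded to 154

ref = F.scaled_dot_product_attention(q, k_real, k_real)

# Without a mask the zero keys still get attention weight: their logits
# are 0, and softmax(0) != 0, so padding shifts the output.
unmasked = F.scaled_dot_product_attention(q, k_padded, k_padded)
assert not torch.allclose(ref, unmasked)

# Boolean mask, True = "may attend": exclude the padded key positions.
keep = torch.cat(
    [torch.ones(1, 1, 77, dtype=torch.bool), torch.zeros(1, 1, 77, dtype=torch.bool)],
    dim=-1,
)
masked = F.scaled_dot_product_attention(q, k_padded, k_padded, attn_mask=keep)
assert torch.allclose(ref, masked, atol=1e-6)
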
@@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
-        both_conditionings = torch.cat([unconditioning, conditioning])
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            unconditioning, conditioning
+        )
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
@@ -260,8 +297,32 @@ class InvokeAIDiffuserComponent:
         **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning,
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
+            **kwargs,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning,
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
+            **kwargs,
+        )
         return unconditioned_next_x, conditioned_next_x

     # TODO: looks unused
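
In the batched CFG path the ControlNet residuals are computed for the doubled batch, so when `_apply_standard_conditioning_sequentially` falls back to two single-sample calls it can no longer forward them untouched; each residual tensor must be split into its unconditioned and conditioned halves first, which is exactly what the `chunk(2)` bookkeeping above does. A minimal sketch of that split with invented shapes (a real SD ControlNet returns one residual per down-block connection plus a single mid-block residual):

import torch

# Invented shapes; batch dim 2 = [uncond, cond].
down_block_additional_residuals = [
    torch.randn(2, 320, 64, 64),
    torch.randn(2, 640, 32, 32),
]
mid_block_additional_residual = torch.randn(2, 1280, 8, 8)

uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
    _uncond_down, _cond_down = down_block.chunk(2)  # split along batch dim 0
    uncond_down_block.append(_uncond_down)
    cond_down_block.append(_cond_down)

uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
assert uncond_mid_block.shape == (1, 1280, 8, 8)
assert all(t.shape[0] == 1 for t in uncond_down_block + cond_down_block)
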
@@ -295,6 +356,20 @@ class InvokeAIDiffuserComponent:
     ):
         context: Context = self.cross_attention_control_context

+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
         cross_attn_processor_context = SwapCrossAttnContext(
             modified_text_embeddings=context.arguments.edited_conditioning,
             index_map=context.cross_attention_index_map,
@@ -307,6 +382,8 @@ class InvokeAIDiffuserComponent:
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
             **kwargs,
         )

@@ -319,6 +396,8 @@ class InvokeAIDiffuserComponent:
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
             **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x
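
The last two hunks apply the same per-half residual split to the cross-attention-control path, so each `model_forward_callback` invocation receives the residuals belonging to its half of the batch. The invariant all of this preserves is that the sequential and batched CFG paths must agree: running the doubled batch once and chunking the output should equal running the two halves separately. A toy self-check of that invariant with a linear stand-in for the UNet (entirely made-up model, not InvokeAI's callback):

import torch

torch.manual_seed(0)
proj = torch.nn.Linear(768, 4)

def toy_forward(x, cond):
    # Made-up stand-in for model_forward_callback: rows are independent,
    # as they are in a UNet with no cross-batch operations.
    return x + proj(cond.mean(dim=1))

x = torch.randn(1, 4)
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 77, 768)

# Batched path: one forward on the doubled batch, then chunk.
both = toy_forward(torch.cat([x, x]), torch.cat([uncond, cond]))
uncond_batched, cond_batched = both.chunk(2)

# Sequential (low-memory) path: two single-sample forwards.
uncond_seq = toy_forward(x, uncond)
cond_seq = toy_forward(x, cond)

assert torch.allclose(uncond_batched, uncond_seq)
assert torch.allclose(cond_batched, cond_seq)
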