From 0fce35c54cbc8ad5b2cae811c36a783037edeff7 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 17 Jul 2023 23:53:50 +0300 Subject: [PATCH] Cleanup, fix variable name, fix controlnet for sequential and cross attention guidance --- .../stable_diffusion/diffusers_pipeline.py | 14 ++---- .../diffusion/shared_invokeai_diffusion.py | 46 ++++++++++++++++++- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 120bfb9663..228fbd0585 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, callback: Callable[[PipelineIntermediateState], None] = None, run_id=None, - **kwargs, ) -> InvokeAIStableDiffusionPipelineOutput: r""" Function invoked when calling the pipeline for generation. @@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise=noise, run_id=run_id, callback=callback, - **kwargs, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): run_id=None, callback: Callable[[PipelineIntermediateState], None] = None, control_data: List[ControlNetData] = None, - **kwargs, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device('cpu') @@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): timesteps, conditioning_data, noise=noise, - additional_guidance=additional_guidance, run_id=run_id, - callback=callback, + additional_guidance=additional_guidance, control_data=control_data, - **kwargs, + + callback=callback, ) return result.latents, result.attention_map_saver @@ -505,7 +502,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): run_id: str = None, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, - **kwargs, ): self._adjust_memory_efficient_attention(latents) if run_id is None: @@ -546,7 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): total_step_count=len(timesteps), additional_guidance=additional_guidance, control_data=control_data, - **kwargs, ) latents = step_output.prev_sample @@ -588,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): total_step_count: int, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, - **kwargs, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] @@ -634,7 +628,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): encoder_hidden_states = conditioning_data.text_embeddings encoder_attention_mask = None else: - encoder_hidden_states, encoder_hidden_states = self.invokeai_diffuser._concat_conditionings_for_batch( + encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch( conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings, ) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index b637ceb815..f44578cd47 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -297,8 +297,32 @@ class InvokeAIDiffuserComponent: **kwargs, ): # low-memory sequential path - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) - conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) + uncond_down_block, cond_down_block = None, None + down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None) + if down_block_additional_residuals is not None: + uncond_down_block, cond_down_block = [], [] + for down_block in down_block_additional_residuals: + _uncond_down, _cond_down = down_block.chunk(2) + uncond_down_block.append(_uncond_down) + cond_down_block.append(_cond_down) + + uncond_mid_block, cond_mid_block = None, None + mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) + if mid_block_additional_residual is not None: + uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) + + unconditioned_next_x = self.model_forward_callback( + x, sigma, unconditioning, + down_block_additional_residuals=uncond_down_block, + mid_block_additional_residual=uncond_mid_block, + **kwargs, + ) + conditioned_next_x = self.model_forward_callback( + x, sigma, conditioning, + down_block_additional_residuals=cond_down_block, + mid_block_additional_residual=cond_mid_block, + **kwargs, + ) return unconditioned_next_x, conditioned_next_x # TODO: looks unused @@ -332,6 +356,20 @@ class InvokeAIDiffuserComponent: ): context: Context = self.cross_attention_control_context + uncond_down_block, cond_down_block = None, None + down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None) + if down_block_additional_residuals is not None: + uncond_down_block, cond_down_block = [], [] + for down_block in down_block_additional_residuals: + _uncond_down, _cond_down = down_block.chunk(2) + uncond_down_block.append(_uncond_down) + cond_down_block.append(_cond_down) + + uncond_mid_block, cond_mid_block = None, None + mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) + if mid_block_additional_residual is not None: + uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) + cross_attn_processor_context = SwapCrossAttnContext( modified_text_embeddings=context.arguments.edited_conditioning, index_map=context.cross_attention_index_map, @@ -344,6 +382,8 @@ class InvokeAIDiffuserComponent: sigma, unconditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + down_block_additional_residuals=uncond_down_block, + mid_block_additional_residual=uncond_mid_block, **kwargs, ) @@ -356,6 +396,8 @@ class InvokeAIDiffuserComponent: sigma, conditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + down_block_additional_residuals=cond_down_block, + mid_block_additional_residual=cond_mid_block, **kwargs, ) return unconditioned_next_x, conditioned_next_x