From 9bc4e7a5934cd1bd27be2a4036e4297487faee3b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 14 Feb 2024 15:52:34 -0500 Subject: [PATCH] Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported. --- .../stable_diffusion/diffusers_pipeline.py | 1 - .../diffusion/shared_invokeai_diffusion.py | 60 +++++++++---------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 918ca538a3..d7f17fb744 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -598,7 +598,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index=step_index, total_step_count=total_step_count, conditioning_data=conditioning_data, - # extra: down_block_additional_residuals=down_block_additional_residuals, # for ControlNet mid_block_additional_residual=mid_block_additional_residual, # for ControlNet down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 455e5e1096..353256006a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -224,10 +224,12 @@ class InvokeAIDiffuserComponent: self, sample: torch.Tensor, timestep: torch.Tensor, - conditioning_data, # TODO: type + conditioning_data: ConditioningData, step_index: int, total_step_count: int, - **kwargs, + down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet + mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet + down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter ): cross_attention_control_types_to_do = [] context: Context = self.cross_attention_control_context @@ -236,7 +238,6 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step( percent_through ) - wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 if wants_cross_attention_control: @@ -244,31 +245,37 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_cross_attention_controlled_conditioning( - sample, - timestep, - conditioning_data, - cross_attention_control_types_to_do, - **kwargs, + x=sample, + sigma=timestep, + conditioning_data=conditioning_data, + cross_attention_control_types_to_do=cross_attention_control_types_to_do, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, ) elif self.sequential_guidance: ( unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning_sequentially( - sample, - timestep, - conditioning_data, - **kwargs, + x=sample, + sigma=timestep, + conditioning_data=conditioning_data, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, ) else: ( unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning( - sample, - timestep, - conditioning_data, - **kwargs, + x=sample, + sigma=timestep, + conditioning_data=conditioning_data, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, ) return unconditioned_next_x, conditioned_next_x @@ -335,7 +342,7 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs): + def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData): """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at the cost of higher memory usage. """ @@ -384,7 +391,6 @@ class InvokeAIDiffuserComponent: cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, added_cond_kwargs=added_cond_kwargs, - **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) return unconditioned_next_x, conditioned_next_x @@ -394,14 +400,15 @@ class InvokeAIDiffuserComponent: x: torch.Tensor, sigma, conditioning_data: ConditioningData, - **kwargs, + down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet + mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet + down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter ): """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of slower execution speed. """ # low-memory sequential path uncond_down_block, cond_down_block = None, None - down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None) if down_block_additional_residuals is not None: uncond_down_block, cond_down_block = [], [] for down_block in down_block_additional_residuals: @@ -410,7 +417,6 @@ class InvokeAIDiffuserComponent: cond_down_block.append(_cond_down) uncond_down_intrablock, cond_down_intrablock = None, None - down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None) if down_intrablock_additional_residuals is not None: uncond_down_intrablock, cond_down_intrablock = [], [] for down_intrablock in down_intrablock_additional_residuals: @@ -419,7 +425,6 @@ class InvokeAIDiffuserComponent: cond_down_intrablock.append(_cond_down) uncond_mid_block, cond_mid_block = None, None - mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) if mid_block_additional_residual is not None: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) @@ -451,7 +456,6 @@ class InvokeAIDiffuserComponent: mid_block_additional_residual=uncond_mid_block, down_intrablock_additional_residuals=uncond_down_intrablock, added_cond_kwargs=added_cond_kwargs, - **kwargs, ) # Run conditional UNet denoising. @@ -481,7 +485,6 @@ class InvokeAIDiffuserComponent: mid_block_additional_residual=cond_mid_block, down_intrablock_additional_residuals=cond_down_intrablock, added_cond_kwargs=added_cond_kwargs, - **kwargs, ) return unconditioned_next_x, conditioned_next_x @@ -491,12 +494,13 @@ class InvokeAIDiffuserComponent: sigma, conditioning_data, cross_attention_control_types_to_do, - **kwargs, + down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet + mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet + down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter ): context: Context = self.cross_attention_control_context uncond_down_block, cond_down_block = None, None - down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None) if down_block_additional_residuals is not None: uncond_down_block, cond_down_block = [], [] for down_block in down_block_additional_residuals: @@ -505,7 +509,6 @@ class InvokeAIDiffuserComponent: cond_down_block.append(_cond_down) uncond_down_intrablock, cond_down_intrablock = None, None - down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None) if down_intrablock_additional_residuals is not None: uncond_down_intrablock, cond_down_intrablock = [], [] for down_intrablock in down_intrablock_additional_residuals: @@ -514,7 +517,6 @@ class InvokeAIDiffuserComponent: cond_down_intrablock.append(_cond_down) uncond_mid_block, cond_mid_block = None, None - mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) if mid_block_additional_residual is not None: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) @@ -543,7 +545,6 @@ class InvokeAIDiffuserComponent: mid_block_additional_residual=uncond_mid_block, down_intrablock_additional_residuals=uncond_down_intrablock, added_cond_kwargs=added_cond_kwargs, - **kwargs, ) if is_sdxl: @@ -563,7 +564,6 @@ class InvokeAIDiffuserComponent: mid_block_additional_residual=cond_mid_block, down_intrablock_additional_residuals=cond_down_intrablock, added_cond_kwargs=added_cond_kwargs, - **kwargs, ) return unconditioned_next_x, conditioned_next_x