Remove use of **kwargs in do_unet_step(...), where the full parameter list is known and supported.

This commit is contained in:
Ryan Dick 2024-02-14 15:52:34 -05:00 committed by Kent Keirsey
parent ad96857e0f
commit 9bc4e7a593
2 changed files with 30 additions and 31 deletions

View File

@ -598,7 +598,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter

View File

@ -224,10 +224,12 @@ class InvokeAIDiffuserComponent:
self, self,
sample: torch.Tensor, sample: torch.Tensor,
timestep: torch.Tensor, timestep: torch.Tensor,
conditioning_data, # TODO: type conditioning_data: ConditioningData,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
**kwargs, down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ):
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
@ -236,7 +238,6 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step( cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
percent_through percent_through
) )
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control: if wants_cross_attention_control:
@ -244,31 +245,37 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning( ) = self._apply_cross_attention_controlled_conditioning(
sample, x=sample,
timestep, sigma=timestep,
conditioning_data, conditioning_data=conditioning_data,
cross_attention_control_types_to_do, cross_attention_control_types_to_do=cross_attention_control_types_to_do,
**kwargs, down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
( (
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning_sequentially( ) = self._apply_standard_conditioning_sequentially(
sample, x=sample,
timestep, sigma=timestep,
conditioning_data, conditioning_data=conditioning_data,
**kwargs, down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
) )
else: else:
( (
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning( ) = self._apply_standard_conditioning(
sample, x=sample,
timestep, sigma=timestep,
conditioning_data, conditioning_data=conditioning_data,
**kwargs, down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -335,7 +342,7 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class. # methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs): def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage. the cost of higher memory usage.
""" """
@ -384,7 +391,6 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -394,14 +400,15 @@ class InvokeAIDiffuserComponent:
x: torch.Tensor, x: torch.Tensor,
sigma, sigma,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
**kwargs, down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed. slower execution speed.
""" """
# low-memory sequential path # low-memory sequential path
uncond_down_block, cond_down_block = None, None uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None: if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], [] uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals: for down_block in down_block_additional_residuals:
@ -410,7 +417,6 @@ class InvokeAIDiffuserComponent:
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None: if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], [] uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals: for down_intrablock in down_intrablock_additional_residuals:
@ -419,7 +425,6 @@ class InvokeAIDiffuserComponent:
cond_down_intrablock.append(_cond_down) cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
@ -451,7 +456,6 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock, down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
# Run conditional UNet denoising. # Run conditional UNet denoising.
@ -481,7 +485,6 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock, down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -491,12 +494,13 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
conditioning_data, conditioning_data,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs, down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ):
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None: if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], [] uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals: for down_block in down_block_additional_residuals:
@ -505,7 +509,6 @@ class InvokeAIDiffuserComponent:
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None: if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], [] uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals: for down_intrablock in down_intrablock_additional_residuals:
@ -514,7 +517,6 @@ class InvokeAIDiffuserComponent:
cond_down_intrablock.append(_cond_down) cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
@ -543,7 +545,6 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock, down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
if is_sdxl: if is_sdxl:
@ -563,7 +564,6 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock, down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x