mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported.
This commit is contained in:
parent
ad96857e0f
commit
9bc4e7a593
@ -598,7 +598,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
# extra:
|
|
||||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||||
|
@ -224,10 +224,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
conditioning_data, # TODO: type
|
conditioning_data: ConditioningData,
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
**kwargs,
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
):
|
):
|
||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
@ -236,7 +238,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
|
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
|
||||||
percent_through
|
percent_through
|
||||||
)
|
)
|
||||||
|
|
||||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||||
|
|
||||||
if wants_cross_attention_control:
|
if wants_cross_attention_control:
|
||||||
@ -244,31 +245,37 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_cross_attention_controlled_conditioning(
|
) = self._apply_cross_attention_controlled_conditioning(
|
||||||
sample,
|
x=sample,
|
||||||
timestep,
|
sigma=timestep,
|
||||||
conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||||
**kwargs,
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
)
|
)
|
||||||
elif self.sequential_guidance:
|
elif self.sequential_guidance:
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning_sequentially(
|
) = self._apply_standard_conditioning_sequentially(
|
||||||
sample,
|
x=sample,
|
||||||
timestep,
|
sigma=timestep,
|
||||||
conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
**kwargs,
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning(
|
) = self._apply_standard_conditioning(
|
||||||
sample,
|
x=sample,
|
||||||
timestep,
|
sigma=timestep,
|
||||||
conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
**kwargs,
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
)
|
)
|
||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
@ -335,7 +342,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData):
|
||||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||||
the cost of higher memory usage.
|
the cost of higher memory usage.
|
||||||
"""
|
"""
|
||||||
@ -384,7 +391,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
@ -394,14 +400,15 @@ class InvokeAIDiffuserComponent:
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
**kwargs,
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
):
|
):
|
||||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
||||||
slower execution speed.
|
slower execution speed.
|
||||||
"""
|
"""
|
||||||
# low-memory sequential path
|
# low-memory sequential path
|
||||||
uncond_down_block, cond_down_block = None, None
|
uncond_down_block, cond_down_block = None, None
|
||||||
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
|
||||||
if down_block_additional_residuals is not None:
|
if down_block_additional_residuals is not None:
|
||||||
uncond_down_block, cond_down_block = [], []
|
uncond_down_block, cond_down_block = [], []
|
||||||
for down_block in down_block_additional_residuals:
|
for down_block in down_block_additional_residuals:
|
||||||
@ -410,7 +417,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
cond_down_block.append(_cond_down)
|
cond_down_block.append(_cond_down)
|
||||||
|
|
||||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
uncond_down_intrablock, cond_down_intrablock = None, None
|
||||||
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
|
|
||||||
if down_intrablock_additional_residuals is not None:
|
if down_intrablock_additional_residuals is not None:
|
||||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
uncond_down_intrablock, cond_down_intrablock = [], []
|
||||||
for down_intrablock in down_intrablock_additional_residuals:
|
for down_intrablock in down_intrablock_additional_residuals:
|
||||||
@ -419,7 +425,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
cond_down_intrablock.append(_cond_down)
|
cond_down_intrablock.append(_cond_down)
|
||||||
|
|
||||||
uncond_mid_block, cond_mid_block = None, None
|
uncond_mid_block, cond_mid_block = None, None
|
||||||
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
|
||||||
if mid_block_additional_residual is not None:
|
if mid_block_additional_residual is not None:
|
||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
@ -451,7 +456,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
mid_block_additional_residual=uncond_mid_block,
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
down_intrablock_additional_residuals=uncond_down_intrablock,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run conditional UNet denoising.
|
# Run conditional UNet denoising.
|
||||||
@ -481,7 +485,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
down_intrablock_additional_residuals=cond_down_intrablock,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
@ -491,12 +494,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
**kwargs,
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
):
|
):
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
uncond_down_block, cond_down_block = None, None
|
uncond_down_block, cond_down_block = None, None
|
||||||
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
|
||||||
if down_block_additional_residuals is not None:
|
if down_block_additional_residuals is not None:
|
||||||
uncond_down_block, cond_down_block = [], []
|
uncond_down_block, cond_down_block = [], []
|
||||||
for down_block in down_block_additional_residuals:
|
for down_block in down_block_additional_residuals:
|
||||||
@ -505,7 +509,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
cond_down_block.append(_cond_down)
|
cond_down_block.append(_cond_down)
|
||||||
|
|
||||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
uncond_down_intrablock, cond_down_intrablock = None, None
|
||||||
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
|
|
||||||
if down_intrablock_additional_residuals is not None:
|
if down_intrablock_additional_residuals is not None:
|
||||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
uncond_down_intrablock, cond_down_intrablock = [], []
|
||||||
for down_intrablock in down_intrablock_additional_residuals:
|
for down_intrablock in down_intrablock_additional_residuals:
|
||||||
@ -514,7 +517,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
cond_down_intrablock.append(_cond_down)
|
cond_down_intrablock.append(_cond_down)
|
||||||
|
|
||||||
uncond_mid_block, cond_mid_block = None, None
|
uncond_mid_block, cond_mid_block = None, None
|
||||||
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
|
||||||
if mid_block_additional_residual is not None:
|
if mid_block_additional_residual is not None:
|
||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
@ -543,7 +545,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
mid_block_additional_residual=uncond_mid_block,
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
down_intrablock_additional_residuals=uncond_down_intrablock,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
@ -563,7 +564,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
down_intrablock_additional_residuals=cond_down_intrablock,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user