diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index c12c86ed92..ef0f3ee261 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -260,7 +260,6 @@ class InvokeAIDiffuserComponent: conditioning_data, **kwargs, ) - else: ( unconditioned_next_x, @@ -407,6 +406,16 @@ class InvokeAIDiffuserComponent: uncond_down_block.append(_uncond_down) cond_down_block.append(_cond_down) + uncond_down_intrablock, cond_down_intrablock = None, None + down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None) + if down_intrablock_additional_residuals is not None: + uncond_down_intrablock, cond_down_intrablock = [], [] + for down_intrablock in down_intrablock_additional_residuals: + print("down_intrablock shape: ", down_intrablock.shape) + _uncond_down, _cond_down = down_intrablock.chunk(2) + uncond_down_intrablock.append(_uncond_down) + cond_down_intrablock.append(_cond_down) + uncond_mid_block, cond_mid_block = None, None mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) if mid_block_additional_residual is not None: @@ -437,6 +446,7 @@ class InvokeAIDiffuserComponent: cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, + down_intrablock_additional_residuals=uncond_down_intrablock, added_cond_kwargs=added_cond_kwargs, **kwargs, ) @@ -465,6 +475,7 @@ class InvokeAIDiffuserComponent: cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, + down_intrablock_additional_residuals=cond_down_intrablock, added_cond_kwargs=added_cond_kwargs, **kwargs, ) @@ -489,6 +500,15 @@ class InvokeAIDiffuserComponent: uncond_down_block.append(_uncond_down) cond_down_block.append(_cond_down) + uncond_down_intrablock, cond_down_intrablock = None, None + down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None) + if down_intrablock_additional_residuals is not None: + uncond_down_intrablock, cond_down_intrablock = [], [] + for down_intrablock in down_intrablock_additional_residuals: + _uncond_down, _cond_down = down_intrablock.chunk(2) + uncond_down_intrablock.append(_uncond_down) + cond_down_intrablock.append(_cond_down) + uncond_mid_block, cond_mid_block = None, None mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) if mid_block_additional_residual is not None: @@ -517,6 +537,7 @@ class InvokeAIDiffuserComponent: {"swap_cross_attn_context": cross_attn_processor_context}, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, + down_intrablock_additional_residuals=uncond_down_intrablock, added_cond_kwargs=added_cond_kwargs, **kwargs, ) @@ -536,6 +557,7 @@ class InvokeAIDiffuserComponent: {"swap_cross_attn_context": cross_attn_processor_context}, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, + down_intrablock_additional_residuals=cond_down_intrablock, added_cond_kwargs=added_cond_kwargs, **kwargs, )