Revert "Revert "Changes to _apply_standard_conditioning_sequentially() and _apply_cross_attention_controlled_conditioning() to reflect changes to T2I-Adapter implementation to allow usage of T2I-Adapter and ControlNet at the same time.""

This reverts commit c04fb451ee.
This commit is contained in:
Kent Keirsey 2023-10-17 11:59:19 -04:00 committed by GitHub
parent 282d36b640
commit a97ec88e06

View File

@ -260,7 +260,6 @@ class InvokeAIDiffuserComponent:
conditioning_data, conditioning_data,
**kwargs, **kwargs,
) )
else: else:
( (
unconditioned_next_x, unconditioned_next_x,
@ -407,6 +406,16 @@ class InvokeAIDiffuserComponent:
uncond_down_block.append(_uncond_down) uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
print("down_intrablock shape: ", down_intrablock.shape)
_uncond_down, _cond_down = down_intrablock.chunk(2)
uncond_down_intrablock.append(_uncond_down)
cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
@ -437,6 +446,7 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block, down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )
@ -465,6 +475,7 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block, down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )
@ -489,6 +500,15 @@ class InvokeAIDiffuserComponent:
uncond_down_block.append(_uncond_down) uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
_uncond_down, _cond_down = down_intrablock.chunk(2)
uncond_down_intrablock.append(_uncond_down)
cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
@ -517,6 +537,7 @@ class InvokeAIDiffuserComponent:
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block, down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )
@ -536,6 +557,7 @@ class InvokeAIDiffuserComponent:
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block, down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )