Cleanup, fix variable name, fix controlnet for sequential and cross attention guidance

This commit is contained in:
Sergey Borisov 2023-07-17 23:53:50 +03:00
parent 1c680a7147
commit 0fce35c54c
2 changed files with 48 additions and 12 deletions

View File

@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise,
run_id=run_id,
callback=callback,
**kwargs,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps,
conditioning_data,
noise=noise,
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
callback=callback,
)
return result.latents, result.attention_map_saver
@ -505,7 +502,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
self._adjust_memory_efficient_attention(latents)
if run_id is None:
@ -546,7 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
)
latents = step_output.prev_sample
@ -588,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@ -634,7 +628,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else:
encoder_hidden_states, encoder_hidden_states = self.invokeai_diffuser._concat_conditionings_for_batch(
encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)

View File

@ -297,8 +297,32 @@ class InvokeAIDiffuserComponent:
**kwargs,
):
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
# TODO: looks unused
@ -332,6 +356,20 @@ class InvokeAIDiffuserComponent:
):
context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
@ -344,6 +382,8 @@ class InvokeAIDiffuserComponent:
sigma,
unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
@ -356,6 +396,8 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x