From 42356ec866b1186f2812c0857d17c2ab38ac104c Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Sun, 21 Jul 2024 20:01:30 +0300
Subject: [PATCH] Add ControlNet support to denoise

---
 invokeai/app/invocations/denoise_latents.py   |  72 ++++++--
 .../stable_diffusion/extensions/controlnet.py | 155 ++++++++++++++++++
 2 files changed, 212 insertions(+), 15 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/extensions/controlnet.py

diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index ccacc3303c..f966d7f2cc 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -463,6 +464,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return controlnet_data
 
+    @staticmethod
+    def parse_controlnet_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        control_input: ControlField | list[ControlField] | None,
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        # Normalize control_input to a list.
+        control_list: list[ControlField]
+        if isinstance(control_input, ControlField):
+            control_list = [control_input]
+        elif isinstance(control_input, list):
+            control_list = control_input
+        elif control_input is None:
+            control_list = []
+        else:
+            raise ValueError(f"Unexpected control_input type: {type(control_input)}")
+
+        for control_info in control_list:
+            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            ext_manager.add_extension(
+                ControlNetExt(
+                    model=model,
+                    image=context.images.get_pil(control_info.image.image_name),
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,
+                    resize_mode=control_info.resize_mode,
+                )
+            )
+        # MultiControlNetModel has been refactored out, just need list[ControlNetData]
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
@@ -790,22 +824,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         ext_manager.add_extension(PreviewExt(step_callback))
 
-        # ext: t2i/ip adapter
-        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+        # context for loading additional models
+        with ExitStack() as exit_stack:
+            # later this should be something like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context, ext_manager)
+            #     ext_manager.add_extension(ext)
+            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
 
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
-        with (
-            unet_info.model_on_device() as (model_state_dict, unet),
-            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-            # ext: controlnet
-            ext_manager.patch_extensions(unet),
-            # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
-        ):
-            sd_backend = StableDiffusionBackend(unet, scheduler)
-            denoise_ctx.unet = unet
-            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+            # ext: t2i/ip adapter
+            ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (model_state_dict, unet),
+                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+                # ext: controlnet
+                ext_manager.patch_extensions(denoise_ctx),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(model_state_dict, unet),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py
new file mode 100644
index 0000000000..e74d183c2c
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import math
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+from PIL.Image import Image
+
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+    from invokeai.backend.util.hotfixes import ControlNetModel
+
+
+class ControlNetExt(ExtensionBase):
+    def __init__(
+        self,
+        model: ControlNetModel,
+        image: Image,
+        weight: Union[float, List[float]],
+        begin_step_percent: float,
+        end_step_percent: float,
+        control_mode: str,
+        resize_mode: str,
+    ):
+        super().__init__()
+        self.model = model
+        self.image = image
+        self.weight = weight
+        self.begin_step_percent = begin_step_percent
+        self.end_step_percent = end_step_percent
+        self.control_mode = control_mode
+        self.resize_mode = resize_mode
+
+        self.image_tensor: Optional[torch.Tensor] = None
+
+    @contextmanager
+    def patch_extension(self, ctx: DenoiseContext):
+        try:
+            original_processors = self.model.attn_processors
+            self.model.set_attn_processor(ctx.inputs.attention_processor_cls())
+
+            yield None
+        finally:
+            self.model.set_attn_processor(original_processors)
+
+    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+    def resize_image(self, ctx: DenoiseContext):
+        _, _, latent_height, latent_width = ctx.latents.shape
+        image_height = latent_height * LATENT_SCALE_FACTOR
+        image_width = latent_width * LATENT_SCALE_FACTOR
+
+        self.image_tensor = prepare_control_image(
+            image=self.image,
+            do_classifier_free_guidance=False,
+            width=image_width,
+            height=image_height,
+            # batch_size=batch_size * num_images_per_prompt,
+            # num_images_per_prompt=num_images_per_prompt,
+            device=ctx.latents.device,
+            dtype=ctx.latents.dtype,
+            control_mode=self.control_mode,
+            resize_mode=self.resize_mode,
+        )
+
+    @callback(ExtensionCallbackType.PRE_UNET)
+    def pre_unet_step(self, ctx: DenoiseContext):
+        # skip if the model is not active in the current step
+        total_steps = len(ctx.inputs.timesteps)
+        first_step = math.floor(self.begin_step_percent * total_steps)
+        last_step = math.ceil(self.end_step_percent * total_steps)
+        if ctx.step_index < first_step or ctx.step_index > last_step:
+            return
+
+        # convert mode to internal flags
+        soft_injection = self.control_mode in ["more_prompt", "more_control"]
+        cfg_injection = self.control_mode in ["more_control", "unbalanced"]
+
+        # no negative conditioning in cfg_injection mode
+        if cfg_injection:
+            if ctx.conditioning_mode == ConditioningMode.Negative:
+                return
+            down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)
+
+            if ctx.conditioning_mode == ConditioningMode.Both:
+                # add zeros as samples for negative conditioning
+                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
+
+        else:
+            down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)
+
+        if (
+            ctx.unet_kwargs.down_block_additional_residuals is None
+            and ctx.unet_kwargs.mid_block_additional_residual is None
+        ):
+            ctx.unet_kwargs.down_block_additional_residuals = down_samples
+            ctx.unet_kwargs.mid_block_additional_residual = mid_sample
+        else:
+            # add controlnet outputs together if there are multiple controlnets
+            ctx.unet_kwargs.down_block_additional_residuals = [
+                samples_prev + samples_curr
+                for samples_prev, samples_curr in zip(
+                    ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
+                )
+            ]
+            ctx.unet_kwargs.mid_block_additional_residual += mid_sample
+
+    def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
+        total_steps = len(ctx.inputs.timesteps)
+
+        model_input = ctx.latent_model_input
+        image_tensor = self.image_tensor
+        if conditioning_mode == ConditioningMode.Both:
+            model_input = torch.cat([model_input] * 2)
+            image_tensor = torch.cat([image_tensor] * 2)
+
+        cn_unet_kwargs = UNetKwargs(
+            sample=model_input,
+            timestep=ctx.timestep,
+            encoder_hidden_states=None,  # set later by conditioning data
+            cross_attention_kwargs=dict(  # noqa: C408
+                percent_through=ctx.step_index / total_steps,
+            ),
+        )
+
+        ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
+
+        # get the static weight, or the weight corresponding to the current step
+        weight = self.weight
+        if isinstance(weight, list):
+            weight = weight[ctx.step_index]
+
+        tmp_kwargs = vars(cn_unet_kwargs)  # vars() returns the live __dict__, so these pops also apply to **vars() below
+        tmp_kwargs.pop("down_block_additional_residuals", None)
+        tmp_kwargs.pop("mid_block_additional_residual", None)
+        tmp_kwargs.pop("down_intrablock_additional_residuals", None)
+
+        # controlnet(s) inference
+        down_samples, mid_sample = self.model(
+            controlnet_cond=image_tensor,
+            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+            return_dict=False,
+            **vars(cn_unet_kwargs),
+        )
+
+        return down_samples, mid_sample
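
A standalone sketch (not part of the diff) of the step-range gating that
ControlNetExt.pre_unet_step applies. The begin/end percentages are mapped onto
discrete step indices with floor/ceil, and the controlnet runs on every step
in the closed interval [first_step, last_step]. The helper name and step
counts below are made up for illustration:

    import math

    def active_steps(begin_step_percent: float, end_step_percent: float, total_steps: int) -> list[int]:
        # Same arithmetic as ControlNetExt.pre_unet_step: floor the start,
        # ceil the end, keep the step indices inside the closed interval.
        first_step = math.floor(begin_step_percent * total_steps)
        last_step = math.ceil(end_step_percent * total_steps)
        return [i for i in range(total_steps) if first_step <= i <= last_step]

    # Example: 30 denoising steps, controlnet active for the first half.
    print(active_steps(0.0, 0.5, 30))  # steps 0 through 15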