From 42356ec866b1186f2812c0857d17c2ab38ac104c Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 20:01:30 +0300 Subject: [PATCH 1/4] Add ControlNet support to denoise --- invokeai/app/invocations/denoise_latents.py | 72 ++++++-- .../stable_diffusion/extensions/controlnet.py | 155 ++++++++++++++++++ 2 files changed, 212 insertions(+), 15 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/controlnet.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index ccacc3303c..f966d7f2cc 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP @@ -463,6 +464,39 @@ class DenoiseLatentsInvocation(BaseInvocation): return controlnet_data + @staticmethod + def parse_controlnet_field( + exit_stack: ExitStack, + context: InvocationContext, + control_input: ControlField | list[ControlField] | None, + ext_manager: ExtensionsManager, + ) -> None: + # Normalize control_input to a list. + control_list: list[ControlField] + if isinstance(control_input, ControlField): + control_list = [control_input] + elif isinstance(control_input, list): + control_list = control_input + elif control_input is None: + control_list = [] + else: + raise ValueError(f"Unexpected control_input type: {type(control_input)}") + + for control_info in control_list: + model = exit_stack.enter_context(context.models.load(control_info.control_model)) + ext_manager.add_extension( + ControlNetExt( + model=model, + image=context.images.get_pil(control_info.image.image_name), + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode, + resize_mode=control_info.resize_mode, + ) + ) + # MultiControlNetModel has been refactored out, just need list[ControlNetData] + def prep_ip_adapter_image_prompts( self, context: InvocationContext, @@ -790,22 +824,30 @@ class DenoiseLatentsInvocation(BaseInvocation): ext_manager.add_extension(PreviewExt(step_callback)) - # ext: t2i/ip adapter - ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) + # context for loading additional models + with ExitStack() as exit_stack: + # later should be smth like: + # for extension_field in self.extensions: + # ext = extension_field.to_extension(exit_stack, context, ext_manager) + # ext_manager.add_extension(ext) + self.parse_controlnet_field(exit_stack, context, self.control, ext_manager) - unet_info = context.models.load(self.unet.unet) - assert isinstance(unet_info.model, UNet2DConditionModel) - with ( - unet_info.model_on_device() as (model_state_dict, unet), - ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), - # ext: controlnet - ext_manager.patch_extensions(unet), - # ext: 
freeu, seamless, ip adapter, lora - ext_manager.patch_unet(model_state_dict, unet), - ): - sd_backend = StableDiffusionBackend(unet, scheduler) - denoise_ctx.unet = unet - result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) + # ext: t2i/ip adapter + ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) + + unet_info = context.models.load(self.unet.unet) + assert isinstance(unet_info.model, UNet2DConditionModel) + with ( + unet_info.model_on_device() as (model_state_dict, unet), + ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), + # ext: controlnet + ext_manager.patch_extensions(denoise_ctx), + # ext: freeu, seamless, ip adapter, lora + ext_manager.patch_unet(model_state_dict, unet), + ): + sd_backend = StableDiffusionBackend(unet, scheduler) + denoise_ctx.unet = unet + result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.detach().to("cpu") diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py new file mode 100644 index 0000000000..e74d183c2c --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import math +from contextlib import contextmanager +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +from PIL.Image import Image + +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.util.hotfixes import ControlNetModel + + +class ControlNetExt(ExtensionBase): + def __init__( + self, + model: ControlNetModel, + image: Image, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + control_mode: str, + resize_mode: str, + ): + super().__init__() + self.model = model + self.image = image + self.weight = weight + self.begin_step_percent = begin_step_percent + self.end_step_percent = end_step_percent + self.control_mode = control_mode + self.resize_mode = resize_mode + + self.image_tensor: Optional[torch.Tensor] = None + + @contextmanager + def patch_extension(self, ctx: DenoiseContext): + try: + original_processors = self.model.attn_processors + self.model.set_attn_processor(ctx.inputs.attention_processor_cls()) + + yield None + finally: + self.model.set_attn_processor(original_processors) + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def resize_image(self, ctx: DenoiseContext): + _, _, latent_height, latent_width = ctx.latents.shape + image_height = latent_height * LATENT_SCALE_FACTOR + image_width = latent_width * LATENT_SCALE_FACTOR + + self.image_tensor = prepare_control_image( + image=self.image, + do_classifier_free_guidance=False, + width=image_width, + height=image_height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=ctx.latents.device, + 
dtype=ctx.latents.dtype, + control_mode=self.control_mode, + resize_mode=self.resize_mode, + ) + + @callback(ExtensionCallbackType.PRE_UNET) + def pre_unet_step(self, ctx: DenoiseContext): + # skip if model not active in current step + total_steps = len(ctx.inputs.timesteps) + first_step = math.floor(self.begin_step_percent * total_steps) + last_step = math.ceil(self.end_step_percent * total_steps) + if ctx.step_index < first_step or ctx.step_index > last_step: + return + + # convert mode to internal flags + soft_injection = self.control_mode in ["more_prompt", "more_control"] + cfg_injection = self.control_mode in ["more_control", "unbalanced"] + + # no negative conditioning in cfg_injection mode + if cfg_injection: + if ctx.conditioning_mode == ConditioningMode.Negative: + return + down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive) + + if ctx.conditioning_mode == ConditioningMode.Both: + # add zeros as samples for negative conditioning + down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] + mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) + + else: + down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode) + + if ( + ctx.unet_kwargs.down_block_additional_residuals is None + and ctx.unet_kwargs.mid_block_additional_residual is None + ): + ctx.unet_kwargs.down_block_additional_residuals = down_samples + ctx.unet_kwargs.mid_block_additional_residual = mid_sample + else: + # add controlnet outputs together if have multiple controlnets + ctx.unet_kwargs.down_block_additional_residuals = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip( + ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True + ) + ] + ctx.unet_kwargs.mid_block_additional_residual += mid_sample + + def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode): + total_steps = len(ctx.inputs.timesteps) + + model_input = ctx.latent_model_input + image_tensor = self.image_tensor + if conditioning_mode == ConditioningMode.Both: + model_input = torch.cat([model_input] * 2) + image_tensor = torch.cat([image_tensor] * 2) + + cn_unet_kwargs = UNetKwargs( + sample=model_input, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / total_steps, + ), + ) + + ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) + + # get static weight, or weight corresponding to current step + weight = self.weight + if isinstance(weight, list): + weight = weight[ctx.step_index] + + tmp_kwargs = vars(cn_unet_kwargs) + tmp_kwargs.pop("down_block_additional_residuals", None) + tmp_kwargs.pop("mid_block_additional_residual", None) + tmp_kwargs.pop("down_intrablock_additional_residuals", None) + + # controlnet(s) inference + down_samples, mid_sample = self.model( + controlnet_cond=image_tensor, + conditioning_scale=weight, # controlnet specific, NOT the guidance scale + guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel + return_dict=False, + **vars(cn_unet_kwargs), + ) + + return down_samples, mid_sample From 3cb13d6288af9c7c662ed1f03231b71dba992443 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 23 Jul 2024 01:01:18 +0300 Subject: [PATCH 2/4] Rename as suggested in other PRs Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- 
.../stable_diffusion/extensions/controlnet.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py index e74d183c2c..0506a7f1a3 100644 --- a/invokeai/backend/stable_diffusion/extensions/controlnet.py +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -31,25 +31,25 @@ class ControlNetExt(ExtensionBase): resize_mode: str, ): super().__init__() - self.model = model - self.image = image - self.weight = weight - self.begin_step_percent = begin_step_percent - self.end_step_percent = end_step_percent - self.control_mode = control_mode - self.resize_mode = resize_mode + self._model = model + self._image = image + self._weight = weight + self._begin_step_percent = begin_step_percent + self._end_step_percent = end_step_percent + self._control_mode = control_mode + self._resize_mode = resize_mode - self.image_tensor: Optional[torch.Tensor] = None + self._image_tensor: Optional[torch.Tensor] = None @contextmanager def patch_extension(self, ctx: DenoiseContext): try: - original_processors = self.model.attn_processors - self.model.set_attn_processor(ctx.inputs.attention_processor_cls()) + original_processors = self._model.attn_processors + self._model.set_attn_processor(ctx.inputs.attention_processor_cls()) yield None finally: - self.model.set_attn_processor(original_processors) + self._model.set_attn_processor(original_processors) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def resize_image(self, ctx: DenoiseContext): @@ -57,8 +57,8 @@ class ControlNetExt(ExtensionBase): image_height = latent_height * LATENT_SCALE_FACTOR image_width = latent_width * LATENT_SCALE_FACTOR - self.image_tensor = prepare_control_image( - image=self.image, + self._image_tensor = prepare_control_image( + image=self._image, do_classifier_free_guidance=False, width=image_width, height=image_height, @@ -66,22 +66,22 @@ class ControlNetExt(ExtensionBase): # num_images_per_prompt=num_images_per_prompt, device=ctx.latents.device, dtype=ctx.latents.dtype, - control_mode=self.control_mode, - resize_mode=self.resize_mode, + control_mode=self._control_mode, + resize_mode=self._resize_mode, ) @callback(ExtensionCallbackType.PRE_UNET) def pre_unet_step(self, ctx: DenoiseContext): # skip if model not active in current step total_steps = len(ctx.inputs.timesteps) - first_step = math.floor(self.begin_step_percent * total_steps) - last_step = math.ceil(self.end_step_percent * total_steps) + first_step = math.floor(self._begin_step_percent * total_steps) + last_step = math.ceil(self._end_step_percent * total_steps) if ctx.step_index < first_step or ctx.step_index > last_step: return # convert mode to internal flags - soft_injection = self.control_mode in ["more_prompt", "more_control"] - cfg_injection = self.control_mode in ["more_control", "unbalanced"] + soft_injection = self._control_mode in ["more_prompt", "more_control"] + cfg_injection = self._control_mode in ["more_control", "unbalanced"] # no negative conditioning in cfg_injection mode if cfg_injection: @@ -117,7 +117,7 @@ class ControlNetExt(ExtensionBase): total_steps = len(ctx.inputs.timesteps) model_input = ctx.latent_model_input - image_tensor = self.image_tensor + image_tensor = self._image_tensor if conditioning_mode == ConditioningMode.Both: model_input = torch.cat([model_input] * 2) image_tensor = torch.cat([image_tensor] * 2) @@ -134,7 +134,7 @@ class ControlNetExt(ExtensionBase): 
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) # get static weight, or weight corresponding to current step - weight = self.weight + weight = self._weight if isinstance(weight, list): weight = weight[ctx.step_index] @@ -144,7 +144,7 @@ class ControlNetExt(ExtensionBase): tmp_kwargs.pop("down_intrablock_additional_residuals", None) # controlnet(s) inference - down_samples, mid_sample = self.model( + down_samples, mid_sample = self._model( controlnet_cond=image_tensor, conditioning_scale=weight, # controlnet specific, NOT the guidance scale guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel From 4e8dcb7a1ab1bccd7544397dac726872f3383cbf Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 23 Jul 2024 01:46:29 +0300 Subject: [PATCH 3/4] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/app/invocations/denoise_latents.py | 1 - .../stable_diffusion/extensions/controlnet.py | 23 +++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index f966d7f2cc..77cd9e4630 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -495,7 +495,6 @@ class DenoiseLatentsInvocation(BaseInvocation): resize_mode=control_info.resize_mode, ) ) - # MultiControlNetModel has been refactored out, just need list[ControlNetData] def prep_ip_adapter_image_prompts( self, diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py index 0506a7f1a3..a48a681af3 100644 --- a/invokeai/backend/stable_diffusion/extensions/controlnet.py +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -8,7 +8,7 @@ import torch from PIL.Image import Image from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR -from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType @@ -27,8 +27,8 @@ class ControlNetExt(ExtensionBase): weight: Union[float, List[float]], begin_step_percent: float, end_step_percent: float, - control_mode: str, - resize_mode: str, + control_mode: CONTROLNET_MODE_VALUES, + resize_mode: CONTROLNET_RESIZE_VALUES, ): super().__init__() self._model = model @@ -43,8 +43,8 @@ class ControlNetExt(ExtensionBase): @contextmanager def patch_extension(self, ctx: DenoiseContext): + original_processors = self._model.attn_processors try: - original_processors = self._model.attn_processors self._model.set_attn_processor(ctx.inputs.attention_processor_cls()) yield None @@ -62,8 +62,6 @@ class ControlNetExt(ExtensionBase): do_classifier_free_guidance=False, width=image_width, height=image_height, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, device=ctx.latents.device, dtype=ctx.latents.dtype, control_mode=self._control_mode, @@ -125,7 +123,7 @@ class ControlNetExt(ExtensionBase): cn_unet_kwargs = UNetKwargs( sample=model_input, timestep=ctx.timestep, - encoder_hidden_states=None, # set later by conditoning + 
encoder_hidden_states=None, # set later by conditioning cross_attention_kwargs=dict( # noqa: C408 percent_through=ctx.step_index / total_steps, ), @@ -139,9 +137,14 @@ class ControlNetExt(ExtensionBase): weight = weight[ctx.step_index] tmp_kwargs = vars(cn_unet_kwargs) - tmp_kwargs.pop("down_block_additional_residuals", None) - tmp_kwargs.pop("mid_block_additional_residual", None) - tmp_kwargs.pop("down_intrablock_additional_residuals", None) + + # Remove kwargs not related to ControlNet unet + # ControlNet guidance fields + del tmp_kwargs["down_block_additional_residuals"] + del tmp_kwargs["mid_block_additional_residual"] + + # T2i Adapter guidance fields + del tmp_kwargs["down_intrablock_additional_residuals"] # controlnet(s) inference down_samples, mid_sample = self._model( From 39e804d0f8be8d14ef940f3a5ac71b09e3aa15d3 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 23 Jul 2024 09:18:04 -0400 Subject: [PATCH 4/4] Use consistent param names in patch_extension(...) functions: context -> ctx. --- invokeai/backend/stable_diffusion/extensions/base.py | 2 +- invokeai/backend/stable_diffusion/extensions_manager.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 802af86e6d..835fe0aaf9 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -52,7 +52,7 @@ class ExtensionBase: return self._callbacks @contextmanager - def patch_extension(self, context: DenoiseContext): + def patch_extension(self, ctx: DenoiseContext): yield None @contextmanager diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 1cae2e4219..a9e554ad98 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -52,13 +52,13 @@ class ExtensionsManager: cb.function(ctx) @contextmanager - def patch_extensions(self, context: DenoiseContext): + def patch_extensions(self, ctx: DenoiseContext): if self._is_canceled and self._is_canceled(): raise CanceledException with ExitStack() as exit_stack: for ext in self._extensions: - exit_stack.enter_context(ext.patch_extension(context)) + exit_stack.enter_context(ext.patch_extension(ctx)) yield None
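
The per-step gating in ControlNetExt.pre_unet_step() (PATCH 1) is easiest to see with concrete numbers. The following is a minimal, self-contained sketch of that arithmetic; the step count and percent values are illustrative, not taken from the patches:

```python
import math

# Reproduces the activation-window arithmetic from ControlNetExt.pre_unet_step().
# Illustrative values: 30 denoising steps, ControlNet active over the middle
# of the schedule (begin/end step percents of 0.2 and 0.8).
total_steps = 30
begin_step_percent = 0.2
end_step_percent = 0.8

first_step = math.floor(begin_step_percent * total_steps)  # 6
last_step = math.ceil(end_step_percent * total_steps)      # 24

# pre_unet_step() returns early (skipping ControlNet) outside this window.
active_steps = [s for s in range(total_steps) if first_step <= s <= last_step]
print(f"ControlNet runs on steps {first_step}..{last_step} "
      f"({len(active_steps)} of {total_steps} steps)")
```

With these values the callback applies ControlNet residuals on steps 6 through 24 inclusive and skips the UNet-kwargs injection elsewhere, which is how begin_step_percent/end_step_percent trade control strength against the unmodified denoising schedule.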
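For context on the extension API these patches build on: ExtensionBase, the @callback decorator, and ExtensionCallbackType are all imported in PATCH 1, and PATCH 4 shows ExtensionBase's default patch_extension(). The extension below is a hypothetical sketch, not part of the series, showing the minimal shape of another extension hooking the same denoise loop via the attributes ControlNetExt itself reads (ctx.step_index, ctx.timestep, ctx.inputs.timesteps):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class StepLoggerExt(ExtensionBase):
    """Hypothetical extension: prints progress before each UNet call.

    Registered the same way as ControlNetExt:
        ext_manager.add_extension(StepLoggerExt())
    """

    @callback(ExtensionCallbackType.PRE_UNET)
    def log_step(self, ctx: DenoiseContext) -> None:
        total_steps = len(ctx.inputs.timesteps)
        print(f"UNet step {ctx.step_index + 1}/{total_steps}, timestep {ctx.timestep}")
```

An extension like this would be invoked through ext_manager.run_callback(...) alongside ControlNetExt's callbacks; since it overrides neither patch_extension() nor patch_unet(), the ExtensionBase defaults seen in PATCH 4 apply and it needs no model-patching context of its own.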