diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index c79e1ed697..2787074265 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt @@ -465,6 +466,38 @@ class DenoiseLatentsInvocation(BaseInvocation): return controlnet_data + @staticmethod + def parse_controlnet_field( + exit_stack: ExitStack, + context: InvocationContext, + control_input: ControlField | list[ControlField] | None, + ext_manager: ExtensionsManager, + ) -> None: + # Normalize control_input to a list. + control_list: list[ControlField] + if isinstance(control_input, ControlField): + control_list = [control_input] + elif isinstance(control_input, list): + control_list = control_input + elif control_input is None: + control_list = [] + else: + raise ValueError(f"Unexpected control_input type: {type(control_input)}") + + for control_info in control_list: + model = exit_stack.enter_context(context.models.load(control_info.control_model)) + ext_manager.add_extension( + ControlNetExt( + model=model, + image=context.images.get_pil(control_info.image.image_name), + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode, + resize_mode=control_info.resize_mode, + ) + ) + def prep_ip_adapter_image_prompts( self, context: InvocationContext, @@ -800,22 +833,30 @@ class DenoiseLatentsInvocation(BaseInvocation): if self.unet.freeu_config: ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) - # ext: t2i/ip adapter - ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) + # context for loading additional models + with ExitStack() as exit_stack: + # later should be smth like: + # for extension_field in self.extensions: + # ext = extension_field.to_extension(exit_stack, context, ext_manager) + # ext_manager.add_extension(ext) + self.parse_controlnet_field(exit_stack, context, self.control, ext_manager) - unet_info = context.models.load(self.unet.unet) - assert isinstance(unet_info.model, UNet2DConditionModel) - with ( - unet_info.model_on_device() as (cached_weights, unet), - ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), - # ext: controlnet - ext_manager.patch_extensions(unet), - # ext: freeu, seamless, ip adapter, lora - ext_manager.patch_unet(unet, cached_weights), - ): - sd_backend = StableDiffusionBackend(unet, scheduler) - denoise_ctx.unet = unet - result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) + # ext: t2i/ip adapter + ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) + + unet_info = context.models.load(self.unet.unet) + assert isinstance(unet_info.model, UNet2DConditionModel) + with ( + unet_info.model_on_device() as (cached_weights, unet), + ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), + # ext: controlnet + ext_manager.patch_extensions(denoise_ctx), + # ext: freeu, seamless, ip adapter, lora + ext_manager.patch_unet(unet, cached_weights), + ): + sd_backend = StableDiffusionBackend(unet, scheduler) + denoise_ctx.unet = unet + result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.detach().to("cpu") diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 6a85a2e441..820d5d32a3 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -52,7 +52,7 @@ class ExtensionBase: return self._callbacks @contextmanager - def patch_extension(self, context: DenoiseContext): + def patch_extension(self, ctx: DenoiseContext): yield None @contextmanager diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py new file mode 100644 index 0000000000..a48a681af3 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import math +from contextlib import contextmanager +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +from PIL.Image import Image + +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image +from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.util.hotfixes import ControlNetModel + + +class ControlNetExt(ExtensionBase): + def __init__( + self, + model: ControlNetModel, + image: Image, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + control_mode: CONTROLNET_MODE_VALUES, + resize_mode: CONTROLNET_RESIZE_VALUES, + ): + super().__init__() + self._model = model + self._image = image + self._weight = weight + self._begin_step_percent = begin_step_percent + self._end_step_percent = end_step_percent + self._control_mode = control_mode + self._resize_mode = resize_mode + + self._image_tensor: Optional[torch.Tensor] = None + + @contextmanager + def patch_extension(self, ctx: DenoiseContext): + original_processors = self._model.attn_processors + try: + self._model.set_attn_processor(ctx.inputs.attention_processor_cls()) + + yield None + finally: + self._model.set_attn_processor(original_processors) + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def resize_image(self, ctx: DenoiseContext): + _, _, latent_height, latent_width = ctx.latents.shape + image_height = latent_height * LATENT_SCALE_FACTOR + image_width = latent_width * LATENT_SCALE_FACTOR + + self._image_tensor = prepare_control_image( + image=self._image, + do_classifier_free_guidance=False, + width=image_width, + height=image_height, + device=ctx.latents.device, + dtype=ctx.latents.dtype, + control_mode=self._control_mode, + resize_mode=self._resize_mode, + ) + + @callback(ExtensionCallbackType.PRE_UNET) + def pre_unet_step(self, ctx: DenoiseContext): + # skip if model not active in current step + total_steps = len(ctx.inputs.timesteps) + first_step = math.floor(self._begin_step_percent * total_steps) + last_step = math.ceil(self._end_step_percent * total_steps) + if ctx.step_index < first_step or ctx.step_index > last_step: + return + + # convert mode to internal flags + soft_injection = self._control_mode in ["more_prompt", "more_control"] + cfg_injection = self._control_mode in ["more_control", "unbalanced"] + + # no negative conditioning in cfg_injection mode + if cfg_injection: + if ctx.conditioning_mode == ConditioningMode.Negative: + return + down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive) + + if ctx.conditioning_mode == ConditioningMode.Both: + # add zeros as samples for negative conditioning + down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] + mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) + + else: + down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode) + + if ( + ctx.unet_kwargs.down_block_additional_residuals is None + and ctx.unet_kwargs.mid_block_additional_residual is None + ): + ctx.unet_kwargs.down_block_additional_residuals = down_samples + ctx.unet_kwargs.mid_block_additional_residual = mid_sample + else: + # add controlnet outputs together if have multiple controlnets + ctx.unet_kwargs.down_block_additional_residuals = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip( + ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True + ) + ] + ctx.unet_kwargs.mid_block_additional_residual += mid_sample + + def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode): + total_steps = len(ctx.inputs.timesteps) + + model_input = ctx.latent_model_input + image_tensor = self._image_tensor + if conditioning_mode == ConditioningMode.Both: + model_input = torch.cat([model_input] * 2) + image_tensor = torch.cat([image_tensor] * 2) + + cn_unet_kwargs = UNetKwargs( + sample=model_input, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditioning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / total_steps, + ), + ) + + ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) + + # get static weight, or weight corresponding to current step + weight = self._weight + if isinstance(weight, list): + weight = weight[ctx.step_index] + + tmp_kwargs = vars(cn_unet_kwargs) + + # Remove kwargs not related to ControlNet unet + # ControlNet guidance fields + del tmp_kwargs["down_block_additional_residuals"] + del tmp_kwargs["mid_block_additional_residual"] + + # T2i Adapter guidance fields + del tmp_kwargs["down_intrablock_additional_residuals"] + + # controlnet(s) inference + down_samples, mid_sample = self._model( + controlnet_cond=image_tensor, + conditioning_scale=weight, # controlnet specific, NOT the guidance scale + guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel + return_dict=False, + **vars(cn_unet_kwargs), + ) + + return down_samples, mid_sample diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 9c4347a56c..c8d585406a 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -52,13 +52,13 @@ class ExtensionsManager: cb.function(ctx) @contextmanager - def patch_extensions(self, context: DenoiseContext): + def patch_extensions(self, ctx: DenoiseContext): if self._is_canceled and self._is_canceled(): raise CanceledException with ExitStack() as exit_stack: for ext in self._extensions: - exit_stack.enter_context(ext.patch_extension(context)) + exit_stack.enter_context(ext.patch_extension(ctx)) yield None