from __future__ import annotations import math from contextlib import contextmanager from typing import TYPE_CHECKING, List, Optional, Union import torch from PIL.Image import Image from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.util.hotfixes import ControlNetModel class ControlNetExt(ExtensionBase): def __init__( self, model: ControlNetModel, image: Image, weight: Union[float, List[float]], begin_step_percent: float, end_step_percent: float, control_mode: str, resize_mode: str, ): super().__init__() self.model = model self.image = image self.weight = weight self.begin_step_percent = begin_step_percent self.end_step_percent = end_step_percent self.control_mode = control_mode self.resize_mode = resize_mode self.image_tensor: Optional[torch.Tensor] = None @contextmanager def patch_extension(self, ctx: DenoiseContext): try: original_processors = self.model.attn_processors self.model.set_attn_processor(ctx.inputs.attention_processor_cls()) yield None finally: self.model.set_attn_processor(original_processors) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def resize_image(self, ctx: DenoiseContext): _, _, latent_height, latent_width = ctx.latents.shape image_height = latent_height * LATENT_SCALE_FACTOR image_width = latent_width * LATENT_SCALE_FACTOR self.image_tensor = prepare_control_image( image=self.image, do_classifier_free_guidance=False, width=image_width, height=image_height, # batch_size=batch_size * num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt, device=ctx.latents.device, dtype=ctx.latents.dtype, control_mode=self.control_mode, resize_mode=self.resize_mode, ) @callback(ExtensionCallbackType.PRE_UNET) def pre_unet_step(self, ctx: DenoiseContext): # skip if model not active in current step total_steps = len(ctx.inputs.timesteps) first_step = math.floor(self.begin_step_percent * total_steps) last_step = math.ceil(self.end_step_percent * total_steps) if ctx.step_index < first_step or ctx.step_index > last_step: return # convert mode to internal flags soft_injection = self.control_mode in ["more_prompt", "more_control"] cfg_injection = self.control_mode in ["more_control", "unbalanced"] # no negative conditioning in cfg_injection mode if cfg_injection: if ctx.conditioning_mode == ConditioningMode.Negative: return down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive) if ctx.conditioning_mode == ConditioningMode.Both: # add zeros as samples for negative conditioning down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) else: down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode) if ( ctx.unet_kwargs.down_block_additional_residuals is None and ctx.unet_kwargs.mid_block_additional_residual is None ): ctx.unet_kwargs.down_block_additional_residuals = down_samples ctx.unet_kwargs.mid_block_additional_residual = mid_sample else: # add controlnet outputs together if have multiple controlnets ctx.unet_kwargs.down_block_additional_residuals = [ samples_prev + samples_curr for samples_prev, samples_curr in zip( ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True ) ] ctx.unet_kwargs.mid_block_additional_residual += mid_sample def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode): total_steps = len(ctx.inputs.timesteps) model_input = ctx.latent_model_input image_tensor = self.image_tensor if conditioning_mode == ConditioningMode.Both: model_input = torch.cat([model_input] * 2) image_tensor = torch.cat([image_tensor] * 2) cn_unet_kwargs = UNetKwargs( sample=model_input, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditoning cross_attention_kwargs=dict( # noqa: C408 percent_through=ctx.step_index / total_steps, ), ) ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode) # get static weight, or weight corresponding to current step weight = self.weight if isinstance(weight, list): weight = weight[ctx.step_index] tmp_kwargs = vars(cn_unet_kwargs) tmp_kwargs.pop("down_block_additional_residuals", None) tmp_kwargs.pop("mid_block_additional_residual", None) tmp_kwargs.pop("down_intrablock_additional_residuals", None) # controlnet(s) inference down_samples, mid_sample = self.model( controlnet_cond=image_tensor, conditioning_scale=weight, # controlnet specific, NOT the guidance scale guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel return_dict=False, **vars(cn_unet_kwargs), ) return down_samples, mid_sample