diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 2787074265..e9899a8289 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt +from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -498,6 +499,33 @@ class DenoiseLatentsInvocation(BaseInvocation): ) ) + @staticmethod + def parse_t2i_adapter_field( + exit_stack: ExitStack, + context: InvocationContext, + t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], + ext_manager: ExtensionsManager, + ) -> None: + if t2i_adapters is None: + return + + # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField. + if isinstance(t2i_adapters, T2IAdapterField): + t2i_adapters = [t2i_adapters] + + for t2i_adapter_field in t2i_adapters: + ext_manager.add_extension( + T2IAdapterExt( + node_context=context, + model_id=t2i_adapter_field.t2i_adapter_model, + image=context.images.get_pil(t2i_adapter_field.image.image_name), + weight=t2i_adapter_field.weight, + begin_step_percent=t2i_adapter_field.begin_step_percent, + end_step_percent=t2i_adapter_field.end_step_percent, + resize_mode=t2i_adapter_field.resize_mode, + ) + ) + def prep_ip_adapter_image_prompts( self, context: InvocationContext, @@ -840,6 +868,7 @@ class DenoiseLatentsInvocation(BaseInvocation): # ext = extension_field.to_extension(exit_stack, context, ext_manager) # ext_manager.add_extension(ext) self.parse_controlnet_field(exit_stack, context, self.control, ext_manager) + self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager) # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py new file mode 100644 index 0000000000..6c8b4b7504 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +from diffusers import T2IAdapter +from PIL.Image import Image + +from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.model_manager import BaseModelType +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.app.invocations.model import ModelIdentifierField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class T2IAdapterExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + model_id: ModelIdentifierField, + image: Image, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + resize_mode: CONTROLNET_RESIZE_VALUES, + ): + super().__init__() + self._node_context = node_context + self._model_id = model_id + self._image = image + self._weight = weight + self._resize_mode = resize_mode + self._begin_step_percent = begin_step_percent + self._end_step_percent = end_step_percent + + self._adapter_state: Optional[List[torch.Tensor]] = None + + # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. + model_config = self._node_context.models.get_config(self._model_id.key) + if model_config.base == BaseModelType.StableDiffusion1: + self._max_unet_downscale = 8 + elif model_config.base == BaseModelType.StableDiffusionXL: + self._max_unet_downscale = 4 + else: + raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.") + + @callback(ExtensionCallbackType.SETUP) + def setup(self, ctx: DenoiseContext): + t2i_model: T2IAdapter + with self._node_context.models.load(self._model_id) as t2i_model: + _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape + + self._adapter_state = self._run_model( + model=t2i_model, + image=self._image, + latents_height=latents_height, + latents_width=latents_width, + max_unet_downscale=self._max_unet_downscale, + resize_mode=self._resize_mode, + ) + + def _run_model( + self, + model: T2IAdapter, + image: Image, + latents_height: int, + latents_width: int, + max_unet_downscale: int, + resize_mode: CONTROLNET_RESIZE_VALUES, + ): + input_height = latents_height // max_unet_downscale * model.total_downscale_factor + input_width = latents_width // max_unet_downscale * model.total_downscale_factor + + t2i_image = prepare_control_image( + image=image, + do_classifier_free_guidance=False, + width=input_width, + height=input_height, + num_channels=model.config["in_channels"], # mypy treats this as a FrozenDict + device=model.device, + dtype=model.dtype, + resize_mode=resize_mode, + ) + + return model(t2i_image) + + @callback(ExtensionCallbackType.PRE_UNET) + def pre_unet_step(self, ctx: DenoiseContext): + # skip if model not active in current step + total_steps = len(ctx.inputs.timesteps) + first_step = math.floor(self._begin_step_percent * total_steps) + last_step = math.ceil(self._end_step_percent * total_steps) + if ctx.step_index < first_step or ctx.step_index > last_step: + return + + weight = self._weight + if isinstance(weight, list): + weight = weight[ctx.step_index] + + adapter_state = self._adapter_state + if ctx.conditioning_mode == ConditioningMode.Both: + adapter_state = [torch.cat([v] * 2) for v in adapter_state] + + if ctx.unet_kwargs.down_intrablock_additional_residuals is None: + ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state] + else: + for i, value in enumerate(adapter_state): + ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight