# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) import inspect from contextlib import ExitStack from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import torch import torchvision import torchvision.transforms as T from diffusers.configuration_utils import ConfigMixin from diffusers.models.adapter import T2IAdapter from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler from diffusers.schedulers.scheduling_tcd import TCDScheduler from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPVisionModelWithProjection from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, FieldDescriptions, Input, InputField, LatentsField, UIType, ) from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.model import ModelIdentifierField, UNetField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, StableDiffusionGeneratorPipeline, T2IAdapterData, ) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, IPAdapterConditioningInfo, IPAdapterData, Range, SDXLConditioningInfo, TextConditioningData, TextConditioningRegions, ) from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.silence_warnings import SilenceWarnings def get_scheduler( context: InvocationContext, scheduler_info: ModelIdentifierField, scheduler_name: str, seed: int, ) -> Scheduler: """Load a scheduler and apply some scheduler-specific overrides.""" # TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if # possible. scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) orig_scheduler_info = context.models.load(scheduler_info) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] scheduler_config = { **scheduler_config, **scheduler_extra_config, # FIXME "_backup": scheduler_config, } # make dpmpp_sde reproducable(seed can be passed only in initializer) if scheduler_class is DPMSolverSDEScheduler: scheduler_config["noise_sampler_seed"] = seed scheduler = scheduler_class.from_config(scheduler_config) # hack copied over from generate.py if not hasattr(scheduler, "uses_inpainting_model"): scheduler.uses_inpainting_model = lambda: False assert isinstance(scheduler, Scheduler) return scheduler @invocation( "denoise_latents", title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", version="1.5.3", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 ) negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1 ) noise: Optional[LatentsField] = InputField( default=None, description=FieldDescriptions.noise, input=Input.Connection, ui_order=3, ) steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) cfg_scale: Union[float, List[float]] = InputField( default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale" ) denoising_start: float = InputField( default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, ) unet: UNetField = InputField( description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2, ) control: Optional[Union[ControlField, list[ControlField]]] = InputField( default=None, input=Input.Connection, ui_order=5, ) ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField( description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6, ) t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField( description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7, ) cfg_rescale_multiplier: float = InputField( title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier ) latents: Optional[LatentsField] = InputField( default=None, description=FieldDescriptions.latents, input=Input.Connection, ui_order=4, ) denoise_mask: Optional[DenoiseMaskField] = InputField( default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8, ) @field_validator("cfg_scale") def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: """validate that all cfg_scale values are >= 1""" if isinstance(v, list): for i in v: if i < 1: raise ValueError("cfg_scale must be greater than 1") else: if v < 1: raise ValueError("cfg_scale must be greater than 1") return v @staticmethod def _get_text_embeddings_and_masks( cond_list: list[ConditioningField], context: InvocationContext, device: torch.device, dtype: torch.dtype, ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]: """Get the text embeddings and masks from the input conditioning fields.""" text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = [] for cond in cond_list: cond_data = context.conditioning.load(cond.conditioning_name) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) mask = cond.mask if mask is not None: mask = context.tensors.load(mask.tensor_name) text_embeddings_masks.append(mask) return text_embeddings, text_embeddings_masks @staticmethod def _preprocess_regional_prompt_mask( mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype ) -> torch.Tensor: """Preprocess a regional prompt mask to match the target height and width. If mask is None, returns a mask of all ones with the target height and width. If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. Returns: torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). """ if mask is None: return torch.ones((1, 1, target_height, target_width), dtype=dtype) mask = to_standard_float_mask(mask, out_dtype=dtype) tf = torchvision.transforms.Resize( (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST ) # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) return resized_mask @staticmethod def _concat_regional_text_embeddings( text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], masks: Optional[list[Optional[torch.Tensor]]], latent_height: int, latent_width: int, dtype: torch.dtype, ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" if masks is None: masks = [None] * len(text_conditionings) assert len(text_conditionings) == len(masks) is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo all_masks_are_none = all(mask is None for mask in masks) text_embedding = [] pooled_embedding = None add_time_ids = None cur_text_embedding_len = 0 processed_masks = [] embedding_ranges = [] for prompt_idx, text_embedding_info in enumerate(text_conditionings): mask = masks[prompt_idx] if is_sdxl: # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain # global prompt information. In an ideal case, there should be exactly one global prompt without a # mask, but we don't enforce this. # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a # pretty major breaking change to a popular node, so for now we use this hack. if pooled_embedding is None or mask is None: pooled_embedding = text_embedding_info.pooled_embeds if add_time_ids is None or mask is None: add_time_ids = text_embedding_info.add_time_ids text_embedding.append(text_embedding_info.embeds) if not all_masks_are_none: embedding_ranges.append( Range( start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] ) ) processed_masks.append( DenoiseLatentsInvocation._preprocess_regional_prompt_mask( mask, latent_height, latent_width, dtype=dtype ) ) cur_text_embedding_len += text_embedding_info.embeds.shape[1] text_embedding = torch.cat(text_embedding, dim=1) assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len regions = None if not all_masks_are_none: regions = TextConditioningRegions( masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges, ) if is_sdxl: return ( SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids), regions, ) return BasicConditioningInfo(embeds=text_embedding), regions @staticmethod def get_conditioning_data( context: InvocationContext, positive_conditioning_field: Union[ConditioningField, list[ConditioningField]], negative_conditioning_field: Union[ConditioningField, list[ConditioningField]], unet: UNet2DConditionModel, latent_height: int, latent_width: int, cfg_scale: float | list[float], steps: int, cfg_rescale_multiplier: float, ) -> TextConditioningData: # Normalize positive_conditioning_field and negative_conditioning_field to lists. cond_list = positive_conditioning_field if not isinstance(cond_list, list): cond_list = [cond_list] uncond_list = negative_conditioning_field if not isinstance(uncond_list, list): uncond_list = [uncond_list] cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks( cond_list, context, unet.device, unet.dtype ) uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks( uncond_list, context, unet.device, unet.dtype ) cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings( text_conditionings=cond_text_embeddings, masks=cond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, dtype=unet.dtype, ) uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings( text_conditionings=uncond_text_embeddings, masks=uncond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, dtype=unet.dtype, ) if isinstance(cfg_scale, list): assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps" conditioning_data = TextConditioningData( uncond_text=uncond_text_embedding, cond_text=cond_text_embedding, uncond_regions=uncond_regions, cond_regions=cond_regions, guidance_scale=cfg_scale, guidance_rescale_multiplier=cfg_rescale_multiplier, ) return conditioning_data @staticmethod def create_pipeline( unet: UNet2DConditionModel, scheduler: Scheduler, ) -> StableDiffusionGeneratorPipeline: class FakeVae: class FakeVaeConfig: def __init__(self) -> None: self.block_out_channels = [0] def __init__(self) -> None: self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( vae=FakeVae(), # TODO: oh... text_encoder=None, tokenizer=None, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, ) @staticmethod def prep_control_data( context: InvocationContext, control_input: ControlField | list[ControlField] | None, latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> list[ControlNetData] | None: # Normalize control_input to a list. control_list: list[ControlField] if isinstance(control_input, ControlField): control_list = [control_input] elif isinstance(control_input, list): control_list = control_input elif control_input is None: control_list = [] else: raise ValueError(f"Unexpected control_input type: {type(control_input)}") if len(control_list) == 0: return None # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. _, _, latent_height, latent_width = latents_shape control_height_resize = latent_height * LATENT_SCALE_FACTOR control_width_resize = latent_width * LATENT_SCALE_FACTOR controlnet_data: list[ControlNetData] = [] for control_info in control_list: control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) assert isinstance(control_model, ControlNetModel) control_image_field = control_info.image input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? # and do real check for classifier_free_guidance? # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) control_image = prepare_control_image( image=input_image, do_classifier_free_guidance=do_classifier_free_guidance, width=control_width_resize, height=control_height_resize, # batch_size=batch_size * num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt, device=control_model.device, dtype=control_model.dtype, control_mode=control_info.control_mode, resize_mode=control_info.resize_mode, ) control_item = ControlNetData( model=control_model, image_tensor=control_image, weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, end_step_percent=control_info.end_step_percent, control_mode=control_info.control_mode, # any resizing needed should currently be happening in prepare_control_image(), # but adding resize_mode to ControlNetData in case needed in the future resize_mode=control_info.resize_mode, ) controlnet_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return controlnet_data def prep_ip_adapter_image_prompts( self, context: InvocationContext, ip_adapters: List[IPAdapterField], ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings.""" image_prompts = [] for single_ip_adapter in ip_adapters: with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model: assert isinstance(ip_adapter_model, IPAdapter) image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. single_ipa_image_fields = single_ip_adapter.image if not isinstance(single_ipa_image_fields, list): single_ipa_image_fields = [single_ipa_image_fields] single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields] with image_encoder_model_info as image_encoder_model: assert isinstance(image_encoder_model, CLIPVisionModelWithProjection) # Get image embeddings from CLIP and ImageProjModel. image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( single_ipa_images, image_encoder_model ) image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds)) return image_prompts def prep_ip_adapter_data( self, context: InvocationContext, ip_adapters: List[IPAdapterField], image_prompts: List[Tuple[torch.Tensor, torch.Tensor]], exit_stack: ExitStack, latent_height: int, latent_width: int, dtype: torch.dtype, ) -> Optional[List[IPAdapterData]]: """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data.""" ip_adapter_data_list = [] for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip( ip_adapters, image_prompts, strict=True ): ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model)) mask_field = single_ip_adapter.mask mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) ip_adapter_data_list.append( IPAdapterData( ip_adapter_model=ip_adapter_model, weight=single_ip_adapter.weight, target_blocks=single_ip_adapter.target_blocks, begin_step_percent=single_ip_adapter.begin_step_percent, end_step_percent=single_ip_adapter.end_step_percent, ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), mask=mask, ) ) return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None def run_t2i_adapters( self, context: InvocationContext, t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, ) -> Optional[list[T2IAdapterData]]: if t2i_adapter is None: return None # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField. if isinstance(t2i_adapter, T2IAdapterField): t2i_adapter = [t2i_adapter] if len(t2i_adapter) == 0: return None t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key) t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model) image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1: max_unet_downscale = 8 elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL: max_unet_downscale = 4 else: raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.") t2i_adapter_model: T2IAdapter with t2i_adapter_loaded_model as t2i_adapter_model: total_downscale_factor = t2i_adapter_model.total_downscale_factor # Resize the T2I-Adapter input image. # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the # result will match the latent image's dimensions after max_unet_downscale is applied. t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the # T2I-Adapter model. # # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many # of the same requirements (e.g. preserving binary masks during resize). t2i_image = prepare_control_image( image=image, do_classifier_free_guidance=False, width=t2i_input_width, height=t2i_input_height, num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict device=t2i_adapter_model.device, dtype=t2i_adapter_model.dtype, resize_mode=t2i_adapter_field.resize_mode, ) adapter_state = t2i_adapter_model(t2i_image) if do_classifier_free_guidance: for idx, value in enumerate(adapter_state): adapter_state[idx] = torch.cat([value] * 2, dim=0) t2i_adapter_data.append( T2IAdapterData( adapter_state=adapter_state, weight=t2i_adapter_field.weight, begin_step_percent=t2i_adapter_field.begin_step_percent, end_step_percent=t2i_adapter_field.end_step_percent, ) ) return t2i_adapter_data # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps @staticmethod def init_scheduler( scheduler: Union[Scheduler, ConfigMixin], device: torch.device, steps: int, denoising_start: float, denoising_end: float, seed: int, ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): scheduler.set_timesteps(steps, device="cpu") timesteps = scheduler.timesteps.to(device=device) else: scheduler.set_timesteps(steps, device=device) timesteps = scheduler.timesteps # skip greater order timesteps _timesteps = timesteps[:: scheduler.order] # get start timestep index t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) # get end timestep index t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) # apply order to indexes t_start_idx *= scheduler.order t_end_idx *= scheduler.order init_timestep = timesteps[t_start_idx : t_start_idx + 1] timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] scheduler_step_kwargs: Dict[str, Any] = {} scheduler_step_signature = inspect.signature(scheduler.step) if "generator" in scheduler_step_signature.parameters: # At some point, someone decided that schedulers that accept a generator should use the original seed with # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for # reproducibility. # # These Invoke-supported schedulers accept a generator as of 2024-06-04: # - DDIMScheduler # - DDPMScheduler # - DPMSolverMultistepScheduler # - EulerAncestralDiscreteScheduler # - EulerDiscreteScheduler # - KDPM2AncestralDiscreteScheduler # - LCMScheduler # - TCDScheduler scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}) if isinstance(scheduler, TCDScheduler): scheduler_step_kwargs.update({"eta": 1.0}) return timesteps, init_timestep, scheduler_step_kwargs def prep_inpaint_mask( self, context: InvocationContext, latents: torch.Tensor ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]: if self.denoise_mask is None: return None, None, False mask = context.tensors.load(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) else: masked_latents = torch.where(mask < 0.5, 0.0, latents) return 1 - mask, masked_latents, self.denoise_mask.gradient @staticmethod def prepare_noise_and_latents( context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None ) -> Tuple[int, torch.Tensor | None, torch.Tensor]: """Depending on the workflow, we expect different combinations of noise and latents to be provided. This function handles preparing these values accordingly. Expected workflows: - Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros. - Image-to-Image Denoising: `noise` and `latents` are both provided. - Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not. - Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not. NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows I haven't considered. """ noise = None if noise_field is not None: noise = context.tensors.load(noise_field.latents_name) if latents_field is not None: latents = context.tensors.load(latents_field.latents_name) elif noise is not None: latents = torch.zeros_like(noise) else: raise ValueError("'latents' or 'noise' must be provided!") if noise is not None and noise.shape[1:] != latents.shape[1:]: raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}") # The seed comes from (in order of priority): the noise field, the latents field, or 0. seed = 0 if noise_field is not None and noise_field.seed is not None: seed = noise_field.seed elif latents_field is not None and latents_field.seed is not None: seed = latents_field.seed else: seed = 0 return seed, noise, latents @torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. def invoke(self, context: InvocationContext) -> LatentsOutput: seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, # below. Investigate whether this is appropriate. t2i_adapter_data = self.run_t2i_adapters( context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True, ) ip_adapters: List[IPAdapterField] = [] if self.ip_adapter is not None: # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here. if isinstance(self.ip_adapter, list): ip_adapters = self.ip_adapter else: ip_adapters = [self.ip_adapter] # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return # a series of image conditioning embeddings. This is being done here rather than in the # big model context below in order to use less VRAM on low-VRAM systems. # The image prompts are then passed to prep_ip_adapter_data(). image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters) # get the unet's config so that we can pass the base to sd_step_callback() unet_config = context.models.get_config(self.unet.unet.key) def step_callback(state: PipelineIntermediateState) -> None: context.util.sd_step_callback(state, unet_config.base) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: lora_info = context.models.load(lora.lora) assert isinstance(lora_info.model, LoRAModelRaw) yield (lora_info.model, lora.weight) del lora_info return unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, unet_info.model_on_device() as (model_state_dict, unet), ModelPatcher.apply_freeu(unet, self.unet.freeu_config), set_seamless(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet( unet, loras=_lora_loader(), model_state_dict=model_state_dict, ), ): assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) if mask is not None: mask = mask.to(device=unet.device, dtype=unet.dtype) if masked_latents is not None: masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype) scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, seed=seed, ) pipeline = self.create_pipeline(unet, scheduler) _, _, latent_height, latent_width = latents.shape conditioning_data = self.get_conditioning_data( context=context, positive_conditioning_field=self.positive_conditioning, negative_conditioning_field=self.negative_conditioning, unet=unet, latent_height=latent_height, latent_width=latent_width, cfg_scale=self.cfg_scale, steps=self.steps, cfg_rescale_multiplier=self.cfg_rescale_multiplier, ) controlnet_data = self.prep_control_data( context=context, control_input=self.control, latents_shape=latents.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, exit_stack=exit_stack, ) ip_adapter_data = self.prep_ip_adapter_data( context=context, ip_adapters=ip_adapters, image_prompts=image_prompts, exit_stack=exit_stack, latent_height=latent_height, latent_width=latent_width, dtype=unet.dtype, ) timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( scheduler, device=unet.device, steps=self.steps, denoising_start=self.denoising_start, denoising_end=self.denoising_end, seed=seed, ) result_latents = pipeline.latents_from_embeddings( latents=latents, timesteps=timesteps, init_timestep=init_timestep, noise=noise, seed=seed, mask=mask, masked_latents=masked_latents, is_gradient_mask=gradient_mask, scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, control_data=controlnet_data, ip_adapter_data=ip_adapter_data, t2i_adapter_data=t2i_adapter_data, callback=step_callback, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.to("cpu") TorchDevice.empty_cache() name = context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)