from __future__ import annotations

import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class StableDiffusionBackend:
    """Run the Stable Diffusion denoising loop, invoking extension callbacks at fixed hook points.

    Per-run state (latents, current step/timestep, noise predictions, UNet kwargs) is stored on the
    ``DenoiseContext`` passed to each method. The ``ExtensionsManager`` callbacks invoked are:
    ``pre_denoise_loop``, ``pre_step``, ``post_step``, ``post_apply_cfg``, ``pre_unet``,
    ``post_unet`` and ``post_denoise_loop``.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
    ):
        self.unet = unet
        # NOTE(review): the methods below use ``ctx.scheduler``, not ``self.scheduler``.
        self.scheduler = scheduler
        config = get_config()
        # When True, ``step`` runs the negative and positive conditionings through the UNet in two
        # separate passes instead of a single batched pass.
        self._sequential_guidance = config.sequential_guidance

    def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
        """Denoise ``ctx.inputs.orig_latents`` over ``ctx.inputs.timesteps`` and return the result.

        Args:
            ctx: Denoising context. This method writes ``ctx.latents``, ``ctx.step_index``,
                ``ctx.timestep`` and ``ctx.step_output`` while it runs.
            ext_manager: Dispatches extension callbacks at each hook point.

        Returns:
            The denoised latents. If ``ctx.inputs.init_timestep`` is empty, ``orig_latents`` is returned
            as-is; if ``ctx.inputs.timesteps`` is empty, the (possibly noised) copy of ``orig_latents``
            is returned without running the loop.
        """
        if ctx.inputs.init_timestep.shape[0] == 0:
            return ctx.inputs.orig_latents

        ctx.latents = ctx.inputs.orig_latents.clone()

        if ctx.inputs.noise is not None:
            batch_size = ctx.latents.shape[0]
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            # Noise the starting latents to the initial timestep (one timestep per batch item).
            ctx.latents = ctx.scheduler.add_noise(
                ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
            )

        # if no work to do, return latents
        if ctx.inputs.timesteps.shape[0] == 0:
            return ctx.latents

        # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
        # ext: preview[pre_denoise_loop, priority=low]
        ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)

        # The loop targets are ctx attributes on purpose, so callbacks can read the current step.
        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)):  # noqa: B020
            # ext: inpaint (apply mask to latents on non-inpaint models)
            ext_manager.callbacks.pre_step(ctx, ext_manager)

            # ext: tiles? [override: step]
            ctx.step_output = self.step(ctx, ext_manager)

            # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
            # ext: preview[post_step, priority=low]
            ext_manager.callbacks.post_step(ctx, ext_manager)

            ctx.latents = ctx.step_output.prev_sample

        # ext: inpaint[post_denoise_loop] (restore unmasked part)
        ext_manager.callbacks.post_denoise_loop(ctx, ext_manager)
        return ctx.latents

    @torch.inference_mode()
    def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
        """Run one denoising step at ``ctx.timestep`` and return the scheduler's output.

        The UNet is run either twice (negative, then positive conditioning) when sequential guidance is
        enabled, or once on a doubled batch otherwise. The two predictions are combined with
        ``apply_cfg``, and ``ctx.scheduler.step`` then produces the previous sample. Transient ctx fields
        set during the step are reset to ``None`` before returning.
        """
        ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)

        # TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready)
        # Note: The current handling of conditioning doesn't feel very future-proof.
        # This might change in the future as new requirements come up, but for now,
        # this is the rough plan.
        if self._sequential_guidance:
            ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
            ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
        else:
            both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
            # NOTE(review): assumes ConditioningMode.Both orders negative before positive in the batch
            # built by conditioning_data.to_unet_kwargs — confirm.
            ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)

        # ext: override apply_cfg
        ctx.noise_pred = self.apply_cfg(ctx)

        # ext: cfg_rescale [modify_noise_prediction]
        # TODO: rename
        ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)

        # compute the previous noisy sample x_t -> x_t-1
        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)

        # clean up locals
        ctx.latent_model_input = None
        ctx.negative_noise_pred = None
        ctx.positive_noise_pred = None
        ctx.noise_pred = None

        return step_output

    @staticmethod
    def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
        """Combine ``ctx.negative_noise_pred`` and ``ctx.positive_noise_pred`` via classifier-free guidance.

        ``ctx.inputs.conditioning_data.guidance_scale`` may be a single value or a list holding one
        value per step; in the latter case the entry at ``ctx.step_index`` is used.
        """
        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[ctx.step_index]

        # Conventional CFG: uncond + scale * (cond - uncond).
        # torch.lerp(neg, pos, scale) is mathematically identical, but for weights >= 0.5 PyTorch
        # evaluates it as `end - (end - start) * (1 - weight)`, which rounds differently (noticeably in
        # fp16) from the standard expression used by diffusers pipelines.
        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

    def run_unet(
        self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode
    ) -> torch.Tensor:
        """Run a single UNet forward pass on ``ctx.latent_model_input`` for ``conditioning_mode``.

        For ``ConditioningMode.Both`` the input sample is duplicated along dim 0 so both conditionings
        are evaluated in one batched call. ``ctx.unet_kwargs`` and ``ctx.conditioning_mode`` are only
        set for the duration of the call (so ``pre_unet``/``post_unet`` callbacks can read or modify
        them) and are reset to ``None`` afterwards.

        Returns:
            The ``.sample`` of the UNet output.
        """
        sample = ctx.latent_model_input
        if conditioning_mode == ConditioningMode.Both:
            sample = torch.cat([sample] * 2)

        ctx.unet_kwargs = UNetKwargs(
            sample=sample,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.inputs.timesteps),
            ),
        )

        ctx.conditioning_mode = conditioning_mode
        ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

        # ext: controlnet/ip/t2i [pre_unet]
        ext_manager.callbacks.pre_unet(ctx, ext_manager)

        # ext: inpaint [pre_unet, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))

        ext_manager.callbacks.post_unet(ctx, ext_manager)

        # clean up locals
        ctx.unet_kwargs = None
        ctx.conditioning_mode = None

        return noise_pred

    def _unet_forward(self, **kwargs) -> torch.Tensor:
        """Call ``self.unet`` with ``kwargs`` and return the ``.sample`` of its output."""
        return self.unet(**kwargs).sample