"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" import k_diffusion as K import torch from torch import nn from .cross_attention_map_saving import AttentionMapSaver from .sampler import Sampler from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # at this threshold, the scheduler will stop using the Karras # noise schedule and start using the model's schedule STEP_THRESHOLD = 30 def cfg_apply_threshold(result, threshold=0.0, scale=0.7): if threshold <= 0.0: return result maxval = 0.0 + torch.max(result).cpu().numpy() minval = 0.0 + torch.min(result).cpu().numpy() if maxval < threshold and minval > -threshold: return result if maxval > threshold: maxval = min(max(1, scale * maxval), threshold) if minval < -threshold: minval = max(min(-1, scale * minval), -threshold) return torch.clamp(result, min=minval, max=maxval) class CFGDenoiser(nn.Module): def __init__(self, model, threshold=0, warmup=0): super().__init__() self.inner_model = model self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) self.invokeai_diffuser = InvokeAIDiffuserComponent( model, model_forward_callback=lambda x, sigma, cond: self.inner_model( x, sigma, cond=cond ), ) def prepare_to_sample(self, t_enc, **kwargs): extra_conditioning_info = kwargs.get("extra_conditioning_info", None) if ( extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control ): self.invokeai_diffuser.override_cross_attention( extra_conditioning_info, step_count=t_enc ) else: self.invokeai_diffuser.restore_default_cross_attention() def forward(self, x, sigma, uncond, cond, cond_scale): next_x = self.invokeai_diffuser.do_diffusion_step( x, sigma, uncond, cond, cond_scale ) if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 else: thresh = self.threshold if thresh > self.threshold: thresh = self.threshold return cfg_apply_threshold(next_x, thresh) class KSampler(Sampler): def __init__(self, model, schedule="lms", device=None, **kwargs): denoiser = K.external.CompVisDenoiser(model) super().__init__( denoiser, schedule, steps=model.num_timesteps, ) self.sigmas = None self.ds = None self.s_in = None self.karras_max = kwargs.get("karras_max", STEP_THRESHOLD) if self.karras_max is None: self.karras_max = STEP_THRESHOLD def make_schedule( self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=False, ): outer_model = self.model self.model = outer_model.inner_model super().make_schedule( ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=False, ) self.model = outer_model self.ddim_num_steps = ddim_num_steps # we don't need both of these sigmas, but storing them here to make # comparison easier later on self.model_sigmas = self.model.get_sigmas(ddim_num_steps) self.karras_sigmas = K.sampling.get_sigmas_karras( n=ddim_num_steps, sigma_min=self.model.sigmas[0].item(), sigma_max=self.model.sigmas[-1].item(), rho=7.0, device=self.device, ) if ddim_num_steps >= self.karras_max: print( f">> Ksampler using model noise schedule (steps >= {self.karras_max})" ) self.sigmas = self.model_sigmas else: print( f">> Ksampler using karras noise schedule (steps < {self.karras_max})" ) self.sigmas = self.karras_sigmas # ALERT: We are completely overriding the sample() method in the base class, which # means that inpainting will not work. 
    # To get this to work we need to be able to modify the inner loop of
    # k_heun, k_lms, etc., as is done in an ugly way in the lstein/k-diffusion
    # branch.
    @torch.no_grad()
    def decode(
        self,
        z_enc,
        cond,
        t_enc,
        img_callback=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        init_latent=None,
        mask=None,
        **kwargs,
    ):
        samples, _ = self.sample(
            batch_size=1,
            S=t_enc,
            x_T=z_enc,
            shape=z_enc.shape[1:],
            conditioning=cond,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            img_callback=img_callback,
            x0=init_latent,
            mask=mask,
            **kwargs,
        )
        return samples

    # this is a no-op, provided here for compatibility with ddim and plms samplers
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        return x0

    # Most of these arguments are ignored and are only present for compatibility
    # with other samplers
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        attention_maps_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
        threshold=0,
        perlin=0,
        # this has to come in the same format as the conditioning,
        # e.g. as encoded tokens, ...
        **kwargs,
    ):
        def route_callback(k_callback_values):
            if img_callback is not None:
                img_callback(k_callback_values["x"], k_callback_values["i"])

        # if make_schedule() hasn't been called, we do it now
        if self.sigmas is None:
            self.make_schedule(
                ddim_num_steps=S,
                ddim_eta=eta,
                verbose=False,
            )

        # sigmas are set up in make_schedule - we take the last S+1 entries
        sigmas = self.sigmas[-S - 1 :]

        # x_T is variation noise. When an init image is provided (in x0) we
        # need to add more randomness to the starting image.
        if x_T is not None:
            if x0 is not None:
                x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
            else:
                x = x_T * sigmas[0]
        else:
            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

        model_wrap_cfg = CFGDenoiser(
            self.model, threshold=threshold, warmup=max(0.8 * S, S - 10)
        )
        model_wrap_cfg.prepare_to_sample(
            S, extra_conditioning_info=extra_conditioning_info
        )

        # set up attention map saving. The None checks are needed because there
        # are multiple code paths to get here.
        attention_map_saver = None
        if attention_maps_callback is not None and extra_conditioning_info is not None:
            eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
            attention_map_token_ids = range(1, eos_token_index)
            attention_map_saver = AttentionMapSaver(
                token_ids=attention_map_token_ids, latents_shape=x.shape[-2:]
            )
            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(
                attention_map_saver
            )

        extra_args = {
            "cond": conditioning,
            "uncond": unconditional_conditioning,
            "cond_scale": unconditional_guidance_scale,
        }
        print(
            f">> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)"
        )
        sampling_result = (
            K.sampling.__dict__[f"sample_{self.schedule}"](
                model_wrap_cfg,
                x,
                sigmas,
                extra_args=extra_args,
                callback=route_callback,
            ),
            None,
        )
        if attention_map_saver is not None:
            attention_maps_callback(attention_map_saver)
        return sampling_result

    # this code will support inpainting if and when the ksampler API is
    # modified or a workaround is found.
    @torch.no_grad()
    def p_sample(
        self,
        img,
        cond,
        ts,
        index,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info=None,
        **kwargs,
    ):
        if self.model_wrap is None:
            self.model_wrap = CFGDenoiser(self.model)
        extra_args = {
            "cond": cond,
            "uncond": unconditional_conditioning,
            "cond_scale": unconditional_guidance_scale,
        }
        if self.s_in is None:
            self.s_in = img.new_ones([img.shape[0]])
        if self.ds is None:
            self.ds = []

        # terrible, confusing names here
        steps = self.ddim_num_steps
        t_enc = self.t_enc

        # sigmas is `steps` entries long, but t_enc might be less. We start
        # partway into the sigma array and work our way to the end after
        # t_enc steps. index starts at t_enc and works its way to zero, so
        # the formula used here for indexing into sigmas is:
        #     s_index = t_enc - index - 1
        s_index = t_enc - index - 1
        self.model_wrap.prepare_to_sample(
            s_index, extra_conditioning_info=extra_conditioning_info
        )
        img = K.sampling.__dict__[f"_{self.schedule}"](
            self.model_wrap,
            img,
            self.sigmas,
            s_index,
            s_in=self.s_in,
            ds=self.ds,
            extra_args=extra_args,
        )
        return img, None, None

    # REVIEW THIS METHOD: it has never been tested. In particular,
    # we should not be multiplying by self.sigmas[0] if we
    # are at an intermediate step in img2img. See similar in
    # sample() which does work.
    def get_initial_image(self, x_T, shape, steps):
        print("WARNING: ksampler.get_initial_image(): get_initial_image needs testing")
        x = torch.randn(shape, device=self.device) * self.sigmas[0]
        if x_T is not None:
            return x_T + x
        else:
            return x

    def prepare_to_sample(self, t_enc, **kwargs):
        self.t_enc = t_enc
        self.model_wrap = None
        self.ds = None
        self.s_in = None

    def q_sample(self, x0, ts):
        """
        Overrides parent method to return the q_sample of the inner model.
        """
        return self.model.inner_model.q_sample(x0, ts)

    def conditioning_key(self) -> str:
        return self.model.inner_model.model.conditioning_key
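
# ----------------------------------------------------------------------------
# Usage sketch (illustration only; nothing in this module calls it). It assumes
# `model` is an already-loaded CompVis LatentDiffusion checkpoint and that the
# prompt embeddings come from model.get_learned_conditioning(); the exact model
# loading and image decoding steps depend on the surrounding code base.
#
#   sampler = KSampler(model, schedule="lms")
#   cond = model.get_learned_conditioning(["a photograph of an astronaut"])
#   uncond = model.get_learned_conditioning([""])
#   samples, _ = sampler.sample(
#       S=30,                # steps >= STEP_THRESHOLD use the model's own schedule
#       batch_size=1,
#       shape=[4, 64, 64],   # latent C, H/8, W/8 for a 512x512 image
#       conditioning=cond,
#       unconditional_conditioning=uncond,
#       unconditional_guidance_scale=7.5,
#   )
#   images = model.decode_first_stage(samples)
# ----------------------------------------------------------------------------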