diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index 4942bcc0c3..e12f5e5e79 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -51,7 +51,8 @@ class Img2Img(Generator):
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
                 init_latent = self.init_latent, # changes how noising is performed in ksampler
-                extra_conditioning_info = extra_conditioning_info
+                extra_conditioning_info = extra_conditioning_info,
+                all_timesteps_count = steps
             )
 
             return self.sample_to_image(samples)
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 6e873f1c6d..1e5b073a3d 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -29,6 +29,10 @@ class CrossAttentionControl:
 
     class Context:
         def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
+            """
+            :param arguments: Arguments for the cross-attention control process
+            :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
+            """
             self.arguments = arguments
             self.step_count = step_count
 
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 5b5dfaf4af..b11e8578e7 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -16,9 +16,10 @@ class DDIMSampler(Sampler):
         super().prepare_to_sample(t_enc, **kwargs)
 
         extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
+        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
 
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 40c2631bcd..6bd519b63b 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -21,9 +21,10 @@ class PLMSSampler(Sampler):
         super().prepare_to_sample(t_enc, **kwargs)
 
         extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
+        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
 
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index 79c15717fe..853702ef68 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -235,7 +235,7 @@ class Sampler(object):
             dynamic_ncols=True,
         )
         old_eps = []
-        self.prepare_to_sample(t_enc=total_steps,**kwargs)
+        self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
         img = self.get_initial_image(x_T,shape,total_steps)
 
         # probably don't need this at all
@@ -310,6 +310,7 @@ class Sampler(object):
             use_original_steps=False,
             init_latent = None,
             mask = None,
+            all_timesteps_count = None,
             **kwargs
     ):
 
@@ -327,7 +328,10 @@ class Sampler(object):
         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
         x_dec = x_latent
         x0 = init_latent
-        self.prepare_to_sample(t_enc=total_steps,**kwargs)
+        # An explicit None would defeat the samplers' kwargs.get('all_timesteps_count', t_enc) fallback, so resolve it here.
+        if all_timesteps_count is None:
+            all_timesteps_count = total_steps
+        self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
 
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
@@ -359,7 +363,7 @@ class Sampler(object):
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
                 t_next = ts_next,
-                step_count=total_steps
+                step_count=len(self.ddim_timesteps)
             )
 
             x_dec, pred_x0, e_t = outs