diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index b9cd80e2f8..1ea330fe56 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -291,17 +291,17 @@ class SDXLTextToLatentsInvocation(BaseInvocation): scheduler_name=self.scheduler, ) - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) - num_inference_steps = self.steps - scheduler.set_timesteps(num_inference_steps, device=unet_info.device) - timesteps = scheduler.timesteps latents = latents * scheduler.init_noise_sigma + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: + scheduler.set_timesteps(num_inference_steps, device=unet.device) + timesteps = scheduler.timesteps + extra_step_kwargs = dict() if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): extra_step_kwargs.update( @@ -543,23 +543,23 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): context=context, ) - # apply denoising_start - num_inference_steps = self.steps - scheduler.set_timesteps(num_inference_steps, device=unet_info.device) - - t_start = int(round(self.denoising_start * num_inference_steps)) - timesteps = scheduler.timesteps[t_start * scheduler.order :] - num_inference_steps = num_inference_steps - t_start - - # apply noise(if provided) - if self.noise is not None and timesteps.shape[0] > 0: - noise = context.services.latents.get(self.noise.latents_name) - latents = scheduler.add_noise(latents, noise, timesteps[:1]) - del noise - do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: + # apply denoising_start + num_inference_steps = self.steps + scheduler.set_timesteps(num_inference_steps, device=self.scheduler.device) + + t_start = int(round(self.denoising_start * num_inference_steps)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + num_inference_steps = num_inference_steps - t_start + + # apply noise(if provided) + if self.noise is not None and timesteps.shape[0] > 0: + noise = context.services.latents.get(self.noise.latents_name) + latents = scheduler.add_noise(latents, noise, timesteps[:1]) + del noise + # apply scheduler extra args extra_step_kwargs = dict() if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):