diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 303a2fedd9..b9cd80e2f8 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -291,13 +291,14 @@ class SDXLTextToLatentsInvocation(BaseInvocation): scheduler_name=self.scheduler, ) + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) + num_inference_steps = self.steps - scheduler.set_timesteps(num_inference_steps) + scheduler.set_timesteps(num_inference_steps, device=unet_info.device) timesteps = scheduler.timesteps latents = latents * scheduler.init_noise_sigma - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: @@ -537,9 +538,14 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): scheduler_name=self.scheduler, ) + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict(), + context=context, + ) + # apply denoising_start num_inference_steps = self.steps - scheduler.set_timesteps(num_inference_steps) + scheduler.set_timesteps(num_inference_steps, device=unet_info.device) t_start = int(round(self.denoising_start * num_inference_steps)) timesteps = scheduler.timesteps[t_start * scheduler.order :] @@ -551,10 +557,6 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): latents = scheduler.add_noise(latents, noise, timesteps[:1]) del noise - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - context=context, - ) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: