From 6edeb4e07247b866a27703ebe729ef05032c2d51 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Thu, 27 Jul 2023 12:48:49 -0700 Subject: [PATCH] Pass device to set_timestep to avoid float64 error --- invokeai/app/invocations/sdxl.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 303a2fedd9..b9cd80e2f8 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -291,13 +291,14 @@ class SDXLTextToLatentsInvocation(BaseInvocation): scheduler_name=self.scheduler, ) + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) + num_inference_steps = self.steps - scheduler.set_timesteps(num_inference_steps) + scheduler.set_timesteps(num_inference_steps, device=unet_info.device) timesteps = scheduler.timesteps latents = latents * scheduler.init_noise_sigma - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: @@ -537,9 +538,14 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): scheduler_name=self.scheduler, ) + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict(), + context=context, + ) + # apply denoising_start num_inference_steps = self.steps - scheduler.set_timesteps(num_inference_steps) + scheduler.set_timesteps(num_inference_steps, device=unet_info.device) t_start = int(round(self.denoising_start * num_inference_steps)) timesteps = scheduler.timesteps[t_start * scheduler.order :] @@ -551,10 +557,6 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): latents = scheduler.add_noise(latents, noise, timesteps[:1]) del noise - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - context=context, - ) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: