Pass device to set_timestep to avoid float64 error

This commit is contained in:
ZachNagengast 2023-07-27 12:48:49 -07:00
parent e191f6d4b2
commit 6edeb4e072

View File

@ -291,13 +291,14 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler,
)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
scheduler.set_timesteps(num_inference_steps, device=unet_info.device)
timesteps = scheduler.timesteps
latents = latents * scheduler.init_noise_sigma
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
@ -537,9 +538,14 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler,
)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
scheduler.set_timesteps(num_inference_steps, device=unet_info.device)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
@ -551,10 +557,6 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet: