mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Pass device to set_timestep to avoid float64 error
This commit is contained in:
parent
e191f6d4b2
commit
6edeb4e072
@ -291,13 +291,14 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||||
|
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
scheduler.set_timesteps(num_inference_steps)
|
scheduler.set_timesteps(num_inference_steps, device=unet_info.device)
|
||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
|
|
||||||
latents = latents * scheduler.init_noise_sigma
|
latents = latents * scheduler.init_noise_sigma
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with unet_info as unet:
|
||||||
@ -537,9 +538,14 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(
|
||||||
|
**self.unet.unet.dict(),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
# apply denoising_start
|
# apply denoising_start
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
scheduler.set_timesteps(num_inference_steps)
|
scheduler.set_timesteps(num_inference_steps, device=unet_info.device)
|
||||||
|
|
||||||
t_start = int(round(self.denoising_start * num_inference_steps))
|
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||||
@ -551,10 +557,6 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||||
del noise
|
del noise
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
|
||||||
**self.unet.unet.dict(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with unet_info as unet:
|
||||||
|
Loading…
Reference in New Issue
Block a user