Fix unet_info location, can have no device prop

This commit is contained in:
ZachNagengast 2023-07-27 14:47:09 -07:00
parent 6edeb4e072
commit aa1f827271

View File

@ -291,17 +291,17 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler, scheduler_name=self.scheduler,
) )
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
num_inference_steps = self.steps num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet_info.device)
timesteps = scheduler.timesteps
latents = latents * scheduler.init_noise_sigma latents = latents * scheduler.init_noise_sigma
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None
with unet_info as unet: with unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps
extra_step_kwargs = dict() extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update( extra_step_kwargs.update(
@ -543,23 +543,23 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
context=context, context=context,
) )
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet_info.device)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None
with unet_info as unet: with unet_info as unet:
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=self.scheduler.device)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
# apply scheduler extra args # apply scheduler extra args
extra_step_kwargs = dict() extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):