Invert mask, fix l2l on no mask conntected, remove zeroing latents on zero start

This commit is contained in:
Sergey Borisov 2023-08-08 20:01:49 +03:00
parent 96b7248051
commit da0184a786

View File

@ -462,7 +462,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
mask_tensor = tv_resize( mask_tensor = tv_resize(
mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR
) )
return mask_tensor return 1 - mask_tensor
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -502,10 +502,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader() unet_info.context.model, _lora_loader()
), unet_info as unet: ), unet_info as unet:
latent = latent.to(device=unet.device, dtype=unet.dtype)
if noise is not None: if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype) if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype) mask = mask.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -526,11 +527,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
exit_stack=exit_stack, exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size
initial_latents = latent
if self.denoising_start <= 0.0:
initial_latents = torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
num_inference_steps, timesteps = self.init_scheduler( num_inference_steps, timesteps = self.init_scheduler(
scheduler, scheduler,
device=unet.device, device=unet.device,
@ -540,7 +536,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
) )
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents, latents=latent,
timesteps=timesteps, timesteps=timesteps,
noise=noise, noise=noise,
seed=seed, seed=seed,