diffusers: make masked img2img behave better with multi-step schedulers

Re-randomizing the noise at each step was confusing them: multi-step schedulers extrapolate from the history of previous steps, so the masked region needs to follow one consistent noising trajectory.
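For context, here is a minimal sketch (not part of this commit; it assumes diffusers' `DPMSolverMultistepScheduler`) of the invariant the fix restores: the masked region is re-noised along one fixed trajectory, so the noise tensor is sampled once and reused at every timestep.

```python
import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(20)

mask_latents = torch.randn(1, 4, 64, 64)  # stand-in for the encoded init image
noise = torch.randn_like(mask_latents)    # sampled once, reused at every step

for t in scheduler.timesteps:
    # With the same x_0 and the same noise, the masked region follows one
    # smooth trajectory as t decreases, which the scheduler's multistep
    # history can track. Re-sampling `noise` inside this loop was the bug.
    noised = scheduler.add_noise(mask_latents, noise, t)
```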
Kevin Turner 2022-12-10 08:27:46 -08:00
parent 1a67836012
commit 50c48cffc7


@@ -91,7 +91,7 @@ class AddsMaskGuidance:
     mask: torch.FloatTensor
     mask_latents: torch.FloatTensor
     _scheduler: SchedulerMixin
-    _noise_func: Callable
+    noise: torch.Tensor
     _debug: Optional[Callable] = None

     def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
@@ -117,8 +117,9 @@ class AddsMaskGuidance:
     def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
         batch_size = latents.size(0)
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        noise = self._noise_func(self.mask_latents)
-        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t)
+        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
+        # get very confused about what is happening from step to step when we do that.
+        mask_latents = self._scheduler.add_noise(self.mask_latents, self.noise, t)
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
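The blend on the last line relies on `torch.lerp(input, end, weight)` computing `input + weight * (end - input)`, so the mask acts as a per-element blend weight: where it is 0 the scheduler-noised `mask_latents` are kept, and where it is 1 the freshly denoised `latents` pass through. A standalone illustration (shapes and mask convention chosen for the example):

```python
import torch

mask_latents = torch.zeros(1, 4, 8, 8)  # stands in for the noised init-image latents
latents = torch.ones(1, 4, 8, 8)        # stands in for the denoised latents
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4] = 1.0                     # left half of the image is being generated

blended = torch.lerp(mask_latents, latents, mask)  # weight broadcasts over channels
assert blended[..., :4].eq(1.0).all()   # generated region: denoised latents win
assert blended[..., 4:].eq(0.0).all()   # masked-off region: init image preserved
```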
@@ -413,7 +414,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
         latent_timestep = timesteps[:1].repeat(batch_size)
-        latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
+        noise = noise_func(initial_latents)
+        noised_latents = self.scheduler.add_noise(initial_latents, noise, latent_timestep)
+        latents = noised_latents

         result_latents = self.latents_from_embeddings(
             latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
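`scheduler.add_noise` implements the standard DDPM forward process q(x_t | x_0); a sketch of the computation for schedulers with a cumulative alpha schedule:

```python
import torch

def add_noise_sketch(x0: torch.Tensor, noise: torch.Tensor,
                     alphas_cumprod: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    alpha_bar = alphas_cumprod[t].reshape(-1, 1, 1, 1)
    return alpha_bar.sqrt() * x0 + (1.0 - alpha_bar).sqrt() * noise
```

This is also why reusing one noise tensor yields a consistent trajectory: with `x0` and `noise` fixed, `x_t` varies only through `alpha_bar_t`.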
@@ -467,7 +470,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
         # because we have our own noise function
         init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
-        latents = self.noise_latents_for_time(init_image_latents, latent_timestep, noise_func=noise_func)
+        noise = noise_func(init_image_latents)
+        latents = self.scheduler.add_noise(init_image_latents, noise, latent_timestep)

         if mask.dim() == 3:
             mask = mask.unsqueeze(0)
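The `noise_func` mentioned in the comment above is supplied by the caller. The helper below is hypothetical (not from this codebase), but a seeded callable of this shape is the kind of thing being threaded through, so the same seed reproduces the same image:

```python
import torch

def make_noise_func(seed: int, device: str = "cpu"):
    # Hypothetical helper: captures a seeded generator so every call
    # draws from one reproducible noise stream.
    generator = torch.Generator(device=device).manual_seed(seed)

    def noise_func(like: torch.Tensor) -> torch.Tensor:
        return torch.randn(like.shape, generator=generator,
                           device=device, dtype=like.dtype)

    return noise_func
```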
@@ -481,7 +485,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
+            guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))

         try:
             result_latents = self.latents_from_embeddings(
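`AddsMaskGuidance.__call__` (signature shown in the first hunk) is applied to each scheduler step's output. A simplified sketch of where such a callback sits in a denoising loop; the real wiring goes through `latents_from_embeddings`, so names here are illustrative:

```python
# Simplified sketch, not the pipeline's actual loop:
for t in scheduler.timesteps:
    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings).sample
    step_output = scheduler.step(noise_pred, t, latents)
    # The guidance callback overwrites the masked-off region of prev_sample
    # with the init-image latents noised to match the current timestep t.
    step_output = mask_guidance(step_output, t, conditioning=text_embeddings)
    latents = step_output.prev_sample
```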
@@ -510,11 +514,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_latents = 0.18215 * init_latents
         return init_latents

-    def noise_latents_for_time(self, latents, timestep, *, noise_func):
-        noise = noise_func(latents)
-        noised_latents = self.scheduler.add_noise(latents, noise, timestep)
-        return noised_latents
-
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
             screened_images, has_nsfw_concept = self.run_safety_checker(