fix(inpainting model): blank areas to be repainted in the masked image (#2447)

Otherwise the model seems too reluctant to change these areas, even
though the mask channel should allow it to.

This makes the solid infill method proposed by #2441 less necessary,
though I think there's still a place for an infill method that is faster
than patchmatch and more predictable than tiles.

Even with #2441, this PR is still useful because it influences all areas
to be painted, not just the infill area.

Fixes #2417
This commit is contained in:
Lincoln Stein 2023-01-31 18:01:33 -05:00 committed by GitHub
commit 053d11fe30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -544,6 +544,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
init_image = init_image.to(device=device, dtype=latents_dtype)
mask = mask.to(device=device, dtype=latents_dtype)
if init_image.dim() == 3:
init_image = init_image.unsqueeze(0)
@ -562,17 +563,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if mask.dim() == 3:
mask = mask.unsqueeze(0)
mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
.to(device=device, dtype=latents_dtype)
guidance: List[Callable] = []
if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = \
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
AddsMaskLatents(self._unet_forward, latent_mask, masked_latents)
else:
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
try:
result_latents, result_attention_maps = self.latents_from_embeddings(