mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(inpainting model): blank areas to be repainted in the masked image
Otherwise the model seems too reluctant to change these areas, even though the mask channel should allow it to.
This commit is contained in:
parent
c18db4e47b
commit
71b6ddf5fb
@ -531,6 +531,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||
|
||||
init_image = init_image.to(device=device, dtype=latents_dtype)
|
||||
mask = mask.to(device=device, dtype=latents_dtype)
|
||||
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
@ -549,17 +550,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
|
||||
latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
|
||||
.to(device=device, dtype=latents_dtype)
|
||||
|
||||
guidance: List[Callable] = []
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
||||
|
||||
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
|
||||
# (that's why there's a mask!) but it seems to really want that blanked out.
|
||||
masked_latents = self.non_noised_latents_from_image(
|
||||
init_image * (1 - mask), device=device, dtype=latents_dtype)
|
||||
self.invokeai_diffuser.model_forward_callback = \
|
||||
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
|
||||
AddsMaskLatents(self._unet_forward, latent_mask, masked_latents)
|
||||
else:
|
||||
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
|
||||
guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
|
||||
|
||||
try:
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
|
Loading…
Reference in New Issue
Block a user