inpainting for the normal model [WIP]

This seems to be performing well until the LAST STEP, at which point it dissolves to confetti.
This commit is contained in:
Kevin Turner 2022-12-04 23:36:12 -08:00
parent b2664e807e
commit f3570d8344

View File

@ -22,6 +22,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@ -61,10 +62,32 @@ class AddsMaskLatents:
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
# model_input = torch.cat([latents, mask, mask_latents], dim=1)
return self.forward(model_input, t, text_embeddings)
@dataclass
class AddsMaskGuidance:
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
_scheduler: SchedulerMixin
_noise_func: Callable
_debug: Optional[Callable] = None
def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
noise = self._noise_func(self.mask_latents)
mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0]) # .to(dtype=mask_latents.dtype)
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
# if self._debug:
# self._debug(latents, f"t={t[0]} latents")
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
if self._debug:
self._debug(masked_input, f"t={t[0]} lerped")
return self.forward(masked_input, t, text_embeddings)
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
"""
@ -382,17 +405,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
if is_inpainting_model(self.unet):
if mask.dim() == 3:
mask = mask.unsqueeze(0)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
.to(device=device, dtype=latents_dtype)
if mask.dim() == 3:
mask = mask.unsqueeze(0)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
.to(device=device, dtype=latents_dtype)
if is_inpainting_model(self.unet):
self.invokeai_diffuser.model_forward_callback = \
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
else:
# FIXME: need to add guidance that applies mask
pass
self.invokeai_diffuser.model_forward_callback = \
AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
self.scheduler, noise_func) # self.debug_latents)
result = None
@ -417,7 +441,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample() # FIXME: uses torch.randn. make reproducible!
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
noise = noise_func(init_latents)
@ -456,3 +480,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def channels(self) -> int:
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels
def debug_latents(self, latents, msg):
with torch.inference_mode():
from ldm.util import debug_image
decoded = self.numpy_to_pil(self.decode_latents(latents))
for i, img in enumerate(decoded):
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)