From f3570d8344f84f8f1567c3ff63fc33a46dcfee59 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Sun, 4 Dec 2022 23:36:12 -0800
Subject: [PATCH] inpainting for the normal model [WIP]

This seems to be performing well until the LAST STEP, at which point it
dissolves to confetti.
---
 ldm/invoke/generator/diffusers_pipeline.py | 49 ++++++++++++++++++----
 1 file changed, 40 insertions(+), 9 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 430e051ca6..67a2f2fba6 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -22,6 +22,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -61,10 +62,32 @@ class AddsMaskLatents:
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
-        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
         return self.forward(model_input, t, text_embeddings)
 
 
+@dataclass
+class AddsMaskGuidance:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+    _scheduler: SchedulerMixin
+    _noise_func: Callable
+    _debug: Optional[Callable] = None
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        noise = self._noise_func(self.mask_latents)
+        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0])  # .to(dtype=mask_latents.dtype)
+        mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # if self._debug:
+        #     self._debug(latents, f"t={t[0]} latents")
+        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+        if self._debug:
+            self._debug(masked_input, f"t={t[0]} lerped")
+        return self.forward(masked_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
 
@@ -382,17 +405,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
 
-        if is_inpainting_model(self.unet):
-            if mask.dim() == 3:
-                mask = mask.unsqueeze(0)
-            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
-                .to(device=device, dtype=latents_dtype)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(0)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
+            .to(device=device, dtype=latents_dtype)
 
+        if is_inpainting_model(self.unet):
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            # FIXME: need to add guidance that applies mask
-            pass
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
+                                 self.scheduler, noise_func)  # self.debug_latents)
 
         result = None
@@ -417,7 +441,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
             init_latent_dist = self.vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
 
         init_latents = 0.18215 * init_latents
         noise = noise_func(init_latents)
@@ -456,3 +480,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
+
+    def debug_latents(self, latents, msg):
+        with torch.inference_mode():
+            from ldm.util import debug_image
+            decoded = self.numpy_to_pil(self.decode_latents(latents))
+            for i, img in enumerate(decoded):
+                debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
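
Note (outside the patch): the per-step blend that AddsMaskGuidance performs
reduces to the standalone sketch below. This is an illustration, not code from
the diff; the function name and argument list are hypothetical, though
scheduler.add_noise and torch.lerp are the same calls the patch uses.

    import torch

    def masked_guidance_blend(latents, t, mask, init_latents, scheduler, noise):
        # Noise the original image's latents up to the current timestep so
        # they match the noise level of the latents being denoised.
        noised_init = scheduler.add_noise(init_latents, noise, t)
        # Where mask == 1, keep the generated latents; where mask == 0,
        # restore the (noised) original image: lerp(a, b, w) = a + w * (b - a).
        return torch.lerp(noised_init.to(latents.dtype), latents,
                          mask.to(latents.dtype))

Each denoising step then runs the UNet on the blended latents, so only the
masked region is free to change while the rest of the image tracks the
original at the matching noise level.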