diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 174a57fd85..430e051ca6 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -9,10 +9,10 @@ import PIL.Image
 import einops
 import torch
 import torchvision.transforms as T
+from diffusers.models import attention
 
 from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention
-from diffusers.models import attention
 
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
 attention.CrossAttention = InvokeAICrossAttention
@@ -23,6 +23,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import Stabl
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -49,6 +50,21 @@ _default_personalization_config_params = dict(
 )
 
 
+@dataclass
+class AddsMaskLatents:
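+    """
+    Wraps the UNet forward callback so that, on each denoising step, the noised latents
+    are concatenated with the inpainting mask and the latents of the initial image along
+    the channel dimension, giving the 4 + 1 + 4 = 9 channel input that inpainting UNets expect.
+    """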
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
+        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
+        return self.forward(model_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
 
@@ -57,7 +73,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
     w, h = image.size
-    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
+    w, h = map(lambda x: x - x % multiple_of, (w, h))  # resize to integer multiple of 8
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
@@ -68,6 +84,10 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     return tensor
 
 
+def is_inpainting_model(unet: UNet2DConditionModel):
+    return unet.conv_in.in_channels == 9
+
+
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -314,7 +334,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
 
         result = None
         for result in self.generate_from_embeddings(
@@ -331,7 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
     def inpaint_from_embeddings(
             self, init_image: torch.FloatTensor,
-            mask_image: torch.FloatTensor,
+            mask: torch.FloatTensor,
             strength: float,
             num_inference_steps: int,
             text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
@@ -349,8 +369,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
+        init_image = init_image.to(device=device, dtype=latents_dtype)
+
         if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
+            init_image = init_image.unsqueeze(0)
 
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -358,22 +380,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
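+        # Inpainting-trained UNets consume the mask and the initial-image latents as extra
+        # input channels on every step (via AddsMaskLatents); for other models, guidance
+        # that applies the mask still needs to be added (see the FIXME below).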
+        if is_inpainting_model(self.unet):
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(0)
+            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
+                .to(device=device, dtype=latents_dtype)
+
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskLatents(self._unet_forward, mask, init_image_latents)
+        else:
+            # FIXME: need to add guidance that applies mask
+            pass
 
         result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+
+        try:
+            for result in self.generate_from_embeddings(
+                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                    extra_conditioning_info=extra_conditioning_info,
+                    timesteps=timesteps,
+                    run_id=run_id, **extra_step_kwargs):
+                if callback is not None and isinstance(result, PipelineIntermediateState):
+                    callback(result)
+            if result is None:
+                raise AssertionError("why was that an empty generator?")
+            return result
+        finally:
+            self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
         # because we have our own noise function
         init_image = init_image.to(device=device, dtype=dtype)
@@ -383,8 +421,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_latents = 0.18215 * init_latents
 
         noise = noise_func(init_latents)
-
-        return self.scheduler.add_noise(init_latents, noise, timestep)
+        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        return noised_latents, init_latents
 
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py
index 25c776de55..976121d720 100644
--- a/ldm/invoke/generator/inpaint.py
+++ b/ldm/invoke/generator/inpaint.py
@@ -10,9 +10,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image, ImageFilter, ImageOps, ImageChops
-from einops import repeat
 
-from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.globals import Globals
@@ -154,7 +152,7 @@ class Inpaint(Img2Img):
            ddim_eta,
            conditioning,
            init_image = im.copy().convert('RGBA'),
-           mask_image = mask.convert('RGB'),  # Code currently requires an RGB mask
+           mask_image = mask,
            strength = strength,
            mask_blur_radius = 0,
            seam_size = 0,
@@ -228,7 +226,11 @@ class Inpaint(Img2Img):
            self.pil_mask = mask_image.copy()
            debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
 
-           mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
+           init_alpha = self.pil_image.getchannel("A")
+           if mask_image.mode != "L":
+               # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+               mask_image = mask_image.convert("L")
+           mask_image = ImageChops.multiply(mask_image, init_alpha)
            self.pil_mask = mask_image
 
            # Resize if requested for inpainting
@@ -236,57 +238,32 @@ class Inpaint(Img2Img):
               mask_image = mask_image.resize((inpaint_width, inpaint_height))
 
            debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
-           mask_image = mask_image.resize(
-               (
-                   mask_image.width // downsampling,
-                   mask_image.height // downsampling
-               ),
-               resample=Image.Resampling.NEAREST
-           )
-           mask_image = image_resized_to_grid_as_tensor(mask_image, normalize=False)
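+           # keep the mask as a single-channel float tensor; the diffusers pipeline now
+           # resizes it to latent resolution itself, so no manual downsampling is needed here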
+           mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+       else:
+           mask: torch.FloatTensor = mask_image
 
        self.mask_blur_radius = mask_blur_radius
 
-       # klms samplers not supported yet, so ignore previous sampler
-       # if isinstance(sampler,KSampler):
-       #     print(
-       #         ">> Using recommended DDIM sampler for inpainting."
-       #     )
-       #     sampler = DDIMSampler(self.model, device=self.model.device)
-
-       mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
-       mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
-
-       t_enc = int(strength * steps)
        # todo: support cross-attention control
        uc, c, _ = conditioning
 
-       print(f">> target t_enc is {t_enc} steps")
-
        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler
 
        def make_image(x_T):
            # FIXME: some of this z_enc and inpaint_replace stuff was probably important
-           # encode (scaled latent)
-           # z_enc = sampler.stochastic_encode(
-           #     self.init_latent,
-           #     torch.tensor([t_enc]).to(self.model.device),
-           #     noise=x_T
-           # )
-           #
            # # to replace masked area with latent noise, weighted by inpaint_replace strength
            # if inpaint_replace > 0.0:
            #     print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
            #     l_noise = self.get_noise(kwargs['width'],kwargs['height'])
-           #     inverted_mask = 1.0-mask_image  # there will be 1s where the mask is
+           #     inverted_mask = 1.0-mask  # there will be 1s where the mask is
            #     masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
-           #     z_enc = z_enc * mask_image + masked_region
+           #     z_enc = z_enc * mask + masked_region
 
            pipeline_output = pipeline.inpaint_from_embeddings(
                init_image=init_image,
-               mask_image=mask_image,
+               mask=1 - mask,  # expects white means "paint here."
                strength=strength,
                num_inference_steps=steps,
                text_embeddings=c,