diffusers support for the inpainting model

commit 875312080d
parent ff42027a00
Author: Kevin Turner
Date:   2022-12-04 20:12:04 -08:00

2 changed files with 69 additions and 54 deletions

ldm/invoke/generator/diffusers_pipeline.py

@@ -9,10 +9,10 @@ import PIL.Image
 import einops
 import torch
 import torchvision.transforms as T
-from diffusers.models import attention
 from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention
+from diffusers.models import attention
 
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
 attention.CrossAttention = InvokeAICrossAttention
@@ -23,6 +23,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -49,6 +50,21 @@ _default_personalization_config_params = dict(
 )
 
+@dataclass
+class AddsMaskLatents:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
+        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
+        return self.forward(model_input, t, text_embeddings)
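
The packed layout above is what a 9-channel inpainting UNet expects: 4 noisy-latent channels, 1 mask channel, and 4 masked-image latent channels, concatenated along the channel axis. A minimal sketch of that layout, using torch.cat instead of einops.pack for clarity (shapes are illustrative, not from this commit):

import torch

latents = torch.randn(1, 4, 64, 64)       # noisy latents being denoised
mask = torch.ones(1, 1, 64, 64)           # single-channel mask at latent resolution
mask_latents = torch.randn(1, 4, 64, 64)  # VAE-encoded masked init image

model_input = torch.cat([latents, mask, mask_latents], dim=1)
assert model_input.shape[1] == 9          # the channel count is_inpainting_model() checks for
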
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
@@ -57,7 +73,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
     w, h = image.size
-    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
+    w, h = map(lambda x: x - x % multiple_of, (w, h))  # resize to integer multiple of multiple_of
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
@@ -68,6 +84,10 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     return tensor
 
+def is_inpainting_model(unet: UNet2DConditionModel):
+    return unet.conv_in.in_channels == 9
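
The channel count of the UNet's first convolution is a reliable fingerprint: inpainting checkpoints take mask and masked-image latents alongside the 4 noisy-latent channels, standard text-to-image UNets take only the 4. A hedged usage sketch (the model name is illustrative, not taken from this commit):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
print(pipe.unet.conv_in.in_channels)  # 9 for inpainting checkpoints, 4 for txt2img
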
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -314,7 +334,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
 
         result = None
         for result in self.generate_from_embeddings(
@@ -331,7 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def inpaint_from_embeddings(
             self,
             init_image: torch.FloatTensor,
-            mask_image: torch.FloatTensor,
+            mask: torch.FloatTensor,
             strength: float,
             num_inference_steps: int,
             text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
@@ -349,8 +369,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
         init_image = init_image.to(device=device, dtype=latents_dtype)
-        init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
+        if init_image.dim() == 3:
+            init_image = init_image.unsqueeze(0)
 
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -358,22 +380,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
+        if is_inpainting_model(self.unet):
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(0)
+            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
+                .to(device=device, dtype=latents_dtype)
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskLatents(self._unet_forward, mask, init_image_latents)
+        else:
+            # FIXME: need to add guidance that applies mask
+            pass
 
         result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+        try:
+            for result in self.generate_from_embeddings(
+                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                    extra_conditioning_info=extra_conditioning_info,
+                    timesteps=timesteps,
+                    run_id=run_id, **extra_step_kwargs):
+                if callback is not None and isinstance(result, PipelineIntermediateState):
+                    callback(result)
+            if result is None:
+                raise AssertionError("why was that an empty generator?")
+            return result
+        finally:
+            self.invokeai_diffuser.model_forward_callback = self._unet_forward
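
For the non-inpainting branch, the FIXME presumably points at the standard latent-masking workaround: a 4-channel UNet cannot see the mask, so the unmasked regions are re-imposed after each scheduler step by blending with a copy of the init latents noised to the same timestep. A sketch of that idea under those assumptions (none of these names appear in the commit):

def apply_mask_guidance(step_latents, init_image_latents, mask, scheduler, noise, t):
    # Bring the untouched source latents to the same noise level as this step.
    noised_init = scheduler.add_noise(init_image_latents, noise, t)
    # Keep source content where mask == 1, generated content where mask == 0.
    return mask * noised_init + (1 - mask) * step_latents
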
-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
         # because we have our own noise function
         init_image = init_image.to(device=device, dtype=dtype)
@@ -383,8 +421,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_latents = 0.18215 * init_latents
 
         noise = noise_func(init_latents)
-        return self.scheduler.add_noise(init_latents, noise, timestep)
+        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        return noised_latents, init_latents
 
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
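
prepare_latents_from_image now returns a pair because it has two consumers: img2img keeps only the noised latents, while inpainting also needs the clean init latents for AddsMaskLatents. Only the tail of the helper appears in the hunk above; a self-contained sketch of what the whole routine plausibly does, with dependencies passed explicitly for illustration (0.18215 is Stable Diffusion's VAE latent scaling factor):

def prepare_latents_from_image(vae, scheduler, init_image, timestep, noise_func):
    # Encode to latent space and apply SD's latent scaling factor.
    init_latents = 0.18215 * vae.encode(init_image).latent_dist.sample()
    # Custom noise source is why upstream prepare_latents can't be reused.
    noise = noise_func(init_latents)
    # Noise the clean latents to the starting timestep for img2img.
    noised_latents = scheduler.add_noise(init_latents, noise, timestep)
    return noised_latents, init_latents
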

ldm/invoke/generator/inpaint.py

@@ -10,9 +10,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image, ImageFilter, ImageOps, ImageChops
-from einops import repeat
-from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.globals import Globals
@@ -154,7 +152,7 @@ class Inpaint(Img2Img):
             ddim_eta,
             conditioning,
             init_image = im.copy().convert('RGBA'),
-            mask_image = mask.convert('RGB'),  # Code currently requires an RGB mask
+            mask_image = mask,
             strength = strength,
             mask_blur_radius = 0,
             seam_size = 0,
@@ -228,7 +226,11 @@ class Inpaint(Img2Img):
             self.pil_mask = mask_image.copy()
             debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
 
-            mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
+            init_alpha = self.pil_image.getchannel("A")
+            if mask_image.mode != "L":
+                # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+                mask_image = mask_image.convert("L")
+            mask_image = ImageChops.multiply(mask_image, init_alpha)
             self.pil_mask = mask_image
 
             # Resize if requested for inpainting
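
Composing the user mask with the init image's alpha channel means fully transparent source pixels are always treated as paintable, whatever the user mask says. The operation in isolation (file names illustrative; in this file white appears to mean "keep", with the inversion to the diffusers "white = paint here" convention happening at the pipeline call below):

from PIL import Image, ImageChops

init_image = Image.open("init.png").convert("RGBA")
user_mask = Image.open("mask.png").convert("L")   # single-channel mask
init_alpha = init_image.getchannel("A")           # transparent pixels have alpha 0
effective_mask = ImageChops.multiply(user_mask, init_alpha)  # 0 wherever either input is 0
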
@@ -236,57 +238,32 @@ class Inpaint(Img2Img):
                 mask_image = mask_image.resize((inpaint_width, inpaint_height))
 
             debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
-            mask_image = mask_image.resize(
-                (
-                    mask_image.width // downsampling,
-                    mask_image.height // downsampling
-                ),
-                resample=Image.Resampling.NEAREST
-            )
-            mask_image = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+            mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        else:
+            mask: torch.FloatTensor = mask_image
 
         self.mask_blur_radius = mask_blur_radius
 
-        # klms samplers not supported yet, so ignore previous sampler
-        # if isinstance(sampler,KSampler):
-        #     print(
-        #         ">> Using recommended DDIM sampler for inpainting."
-        #     )
-        #     sampler = DDIMSampler(self.model, device=self.model.device)
-
-        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
-        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
-
         t_enc = int(strength * steps)
         # todo: support cross-attention control
         uc, c, _ = conditioning
 
         print(f">> target t_enc is {t_enc} steps")
 
         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
         pipeline.scheduler = sampler
 
         def make_image(x_T):
+            # FIXME: some of this z_enc and inpaint_replace stuff was probably important
             # encode (scaled latent)
             # z_enc = sampler.stochastic_encode(
             #     self.init_latent,
             #     torch.tensor([t_enc]).to(self.model.device),
             #     noise=x_T
             # )
             #
             # # to replace masked area with latent noise, weighted by inpaint_replace strength
             # if inpaint_replace > 0.0:
             #     print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
             #     l_noise = self.get_noise(kwargs['width'],kwargs['height'])
-            #     inverted_mask = 1.0-mask_image  # there will be 1s where the mask is
+            #     inverted_mask = 1.0-mask  # there will be 1s where the mask is
             #     masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
-            #     z_enc = z_enc * mask_image + masked_region
+            #     z_enc = z_enc * mask + masked_region
 
             pipeline_output = pipeline.inpaint_from_embeddings(
                 init_image=init_image,
-                mask_image=mask_image,
+                mask=1 - mask,  # expects white means "paint here."
                 strength=strength,
                 num_inference_steps=steps,
                 text_embeddings=c,