diffusers support for the inpainting model

Kevin Turner 2022-12-04 20:12:04 -08:00
parent ff42027a00
commit 875312080d
2 changed files with 69 additions and 54 deletions

ldm/invoke/generator/diffusers_pipeline.py

@@ -9,10 +9,10 @@ import PIL.Image
 import einops
 import torch
 import torchvision.transforms as T
-from diffusers.models import attention
 from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention
+from diffusers.models import attention

 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
 attention.CrossAttention = InvokeAICrossAttention
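
A note on what this hunk sets up: assigning to attention.CrossAttention rebinds the name inside the diffusers.models.attention module, so diffusers code that looks the class up through the module afterwards resolves to InvokeAI's implementation. A minimal, self-contained sketch of the pattern (all names below are hypothetical, not InvokeAI's):

    import types

    # stand-in for a third-party module and the class it exposes
    thirdparty = types.ModuleType("thirdparty")
    thirdparty.Widget = type("Widget", (), {"run": lambda self: "original"})

    class PatchedWidget:
        def run(self):
            return "patched"

    # later lookups through the module now resolve to the replacement
    thirdparty.Widget = PatchedWidget
    assert thirdparty.Widget().run() == "patched"

The caveat the 🙈 acknowledges: anything that imported or subclassed the original class before the patch ran keeps the old reference.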
@@ -23,6 +23,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -49,6 +50,21 @@ _default_personalization_config_params = dict(
 )


+@dataclass
+class AddsMaskLatents:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # equivalent to torch.cat([latents, mask, mask_latents], dim=1)
+        model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
+        return self.forward(model_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
@@ -57,7 +73,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
     w, h = image.size
-    w, h = map(lambda x: x - x % 8, (w, h))    # resize to integer multiple of 8
+    w, h = map(lambda x: x - x % multiple_of, (w, h))    # round each dimension down to a multiple of multiple_of
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
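
The fix above makes the rounding respect the multiple_of parameter instead of a hard-coded 8. The arithmetic, spelled out:

    def round_down(x: int, multiple_of: int) -> int:
        return x - x % multiple_of

    assert round_down(513, 8) == 512    # previous hard-coded behavior
    assert round_down(513, 64) == 512   # now honors the caller's multiple_of
    assert round_down(512, 8) == 512    # already-aligned sizes are untouched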
@@ -68,6 +84,10 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     return tensor


+def is_inpainting_model(unet: UNet2DConditionModel):
+    # inpainting UNets take 9 input channels: 4 noisy latents + 1 mask + 4 masked-image latents
+    return unet.conv_in.in_channels == 9
+
+
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
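
A hedged usage sketch of the 9-channel check (the model ID is illustrative and downloading it requires network access):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-inpainting", subfolder="unet")
    print(unet.conv_in.in_channels)  # expected: 9 for an inpainting checkpoint, 4 otherwise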
@@ -314,7 +334,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)

         result = None
         for result in self.generate_from_embeddings(
@@ -331,7 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def inpaint_from_embeddings(
             self,
             init_image: torch.FloatTensor,
-            mask_image: torch.FloatTensor,
+            mask: torch.FloatTensor,
             strength: float,
             num_inference_steps: int,
             text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
@@ -349,8 +369,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

+        init_image = init_image.to(device=device, dtype=latents_dtype)
+
         if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
+            init_image = init_image.unsqueeze(0)

         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -358,22 +380,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
+        if is_inpainting_model(self.unet):
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(0)
+            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
+                .to(device=device, dtype=latents_dtype)
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskLatents(self._unet_forward, mask, init_image_latents)
+        else:
+            # FIXME: need to add guidance that applies mask
+            pass

         result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+
+        try:
+            for result in self.generate_from_embeddings(
+                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                    extra_conditioning_info=extra_conditioning_info,
+                    timesteps=timesteps,
+                    run_id=run_id, **extra_step_kwargs):
+                if callback is not None and isinstance(result, PipelineIntermediateState):
+                    callback(result)
+            if result is None:
+                raise AssertionError("why was that an empty generator?")
+            return result
+        finally:
+            self.invokeai_diffuser.model_forward_callback = self._unet_forward

-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
         # because we have our own noise function
         init_image = init_image.to(device=device, dtype=dtype)
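
Two details worth noting in this hunk: the pixel-space mask is downscaled to the latent grid (1/8 resolution for Stable Diffusion's VAE) before AddsMaskLatents packs it with the latents, and the try/finally guarantees the monkeypatched model_forward_callback is restored even if generation raises. A sketch of the downscaling step, with illustrative sizes:

    import torch
    import torchvision.transforms as T
    from torchvision.transforms.functional import resize as tv_resize

    mask = torch.zeros(1, 1, 512, 512)
    mask[..., 128:384, 128:384] = 1.0   # square region to repaint
    latent_size = [64, 64]              # latents.shape[-2:] for a 512x512 image

    small_mask = tv_resize(mask, latent_size, T.InterpolationMode.BILINEAR)
    assert small_mask.shape == (1, 1, 64, 64)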
@@ -383,8 +421,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_latents = 0.18215 * init_latents

         noise = noise_func(init_latents)
-
-        return self.scheduler.add_noise(init_latents, noise, timestep)
+        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        return noised_latents, init_latents

     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
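
Returning init_latents alongside the noised latents is what lets inpaint_from_embeddings reuse the clean image latents as mask_latents. For orientation, scheduler.add_noise applies the standard forward-diffusion formula; a sketch of the computation (an outline of the math, not diffusers' exact code):

    import torch

    def add_noise_sketch(x0, noise, alphas_cumprod, t):
        # q(x_t | x_0): scale the clean sample and mix in scaled Gaussian noise
        sqrt_alpha = alphas_cumprod[t] ** 0.5
        sqrt_one_minus = (1 - alphas_cumprod[t]) ** 0.5
        while sqrt_alpha.dim() < x0.dim():   # broadcast per-timestep scalars over (b, c, h, w)
            sqrt_alpha = sqrt_alpha.unsqueeze(-1)
            sqrt_one_minus = sqrt_one_minus.unsqueeze(-1)
        return sqrt_alpha * x0 + sqrt_one_minus * noise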

ldm/invoke/generator/inpaint.py

@@ -10,9 +10,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image, ImageFilter, ImageOps, ImageChops
-from einops import repeat

-from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.globals import Globals
@@ -154,7 +152,7 @@ class Inpaint(Img2Img):
             ddim_eta,
             conditioning,
             init_image = im.copy().convert('RGBA'),
-            mask_image = mask.convert('RGB'),  # Code currently requires an RGB mask
+            mask_image = mask,
             strength = strength,
             mask_blur_radius = 0,
             seam_size = 0,
@@ -228,7 +226,11 @@ class Inpaint(Img2Img):
             self.pil_mask = mask_image.copy()
             debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)

-            mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
+            init_alpha = self.pil_image.getchannel("A")
+            if mask_image.mode != "L":
+                # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+                mask_image = mask_image.convert("L")
+            mask_image = ImageChops.multiply(mask_image, init_alpha)
             self.pil_mask = mask_image

             # Resize if requested for inpainting
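
The rewritten block combines the user's mask with the init image's alpha channel: ImageChops.multiply maps transparent pixels (alpha 0) to black, and in this PIL-space mask black means "repaint" while white means "keep", so transparent regions of the init image are always regenerated. A self-contained sketch:

    from PIL import Image, ImageChops

    init_image = Image.new("RGBA", (64, 64), (255, 0, 0, 255))
    init_image.paste((0, 0, 0, 0), (0, 0, 32, 64))   # make the left half transparent
    user_mask = Image.new("L", (64, 64), 255)        # user chose to keep everything

    combined = ImageChops.multiply(user_mask, init_image.getchannel("A"))
    assert combined.getpixel((0, 0)) == 0      # transparent -> must repaint
    assert combined.getpixel((48, 32)) == 255  # opaque and unmasked -> keep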
@@ -236,57 +238,32 @@ class Inpaint(Img2Img):
                 mask_image = mask_image.resize((inpaint_width, inpaint_height))

             debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
-            mask_image = mask_image.resize(
-                (
-                    mask_image.width // downsampling,
-                    mask_image.height // downsampling
-                ),
-                resample=Image.Resampling.NEAREST
-            )
-            mask_image = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+            mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        else:
+            mask: torch.FloatTensor = mask_image

         self.mask_blur_radius = mask_blur_radius

-        # klms samplers not supported yet, so ignore previous sampler
-        # if isinstance(sampler,KSampler):
-        #     print(
-        #         ">> Using recommended DDIM sampler for inpainting."
-        #     )
-        #     sampler = DDIMSampler(self.model, device=self.model.device)
-
-        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
-        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
-
-        t_enc = int(strength * steps)
         # todo: support cross-attention control
         uc, c, _ = conditioning

-        print(f">> target t_enc is {t_enc} steps")
-
         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
         pipeline.scheduler = sampler

         def make_image(x_T):
             # FIXME: some of this z_enc and inpaint_replace stuff was probably important
-            # encode (scaled latent)
-            # z_enc = sampler.stochastic_encode(
-            #     self.init_latent,
-            #     torch.tensor([t_enc]).to(self.model.device),
-            #     noise=x_T
-            # )
-            #
             # # to replace masked area with latent noise, weighted by inpaint_replace strength
             # if inpaint_replace > 0.0:
             #     print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
             #     l_noise = self.get_noise(kwargs['width'],kwargs['height'])
-            #     inverted_mask = 1.0-mask_image  # there will be 1s where the mask is
+            #     inverted_mask = 1.0-mask  # there will be 1s where the mask is
             #     masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
-            #     z_enc = z_enc * mask_image + masked_region
+            #     z_enc = z_enc * mask + masked_region

             pipeline_output = pipeline.inpaint_from_embeddings(
                 init_image=init_image,
-                mask_image=mask_image,
+                mask=1 - mask,  # expects white means "paint here."
                 strength=strength,
                 num_inference_steps=steps,
                 text_embeddings=c,
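
On the mask=1 - mask line above: image_resized_to_grid_as_tensor(..., normalize=False) maps white (keep) pixels to 1.0, while inpaint_from_embeddings treats 1.0 as "paint here", hence the inversion. In tensor form:

    import torch

    mask = torch.tensor([[1.0, 0.0],
                         [0.0, 1.0]])   # 1.0 = keep (white), 0.0 = repaint (black)
    paint_here = 1 - mask
    assert paint_here.tolist() == [[0.0, 1.0], [1.0, 0.0]]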