preliminary support for outpainting (no masking yet)

Kevin Turner 2022-11-30 22:05:58 -08:00
parent b02f3688a6
commit ea1cf83c20
2 changed files with 127 additions and 77 deletions

ldm/invoke/generator/diffusers_pipeline.py

@@ -6,12 +6,13 @@ from dataclasses import dataclass
from typing import List, Optional, Union, Callable
import PIL.Image
import einops
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess, \
StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -40,6 +41,25 @@ _default_personalization_config_params = dict(
)
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
"""
:param image: input image
:param normalize: scale the range to [-1, 1] instead of [0, 1]
:param multiple_of: resize the input so both dimensions are a multiple of this
"""
w, h = image.size
w, h = map(lambda x: x - x % multiple_of, (w, h)) # round each dimension down to a multiple of multiple_of
transformation = T.Compose([
T.Resize((h, w), T.InterpolationMode.LANCZOS),
T.ToTensor(),
])
tensor = transformation(image)
if normalize:
tensor = tensor * 2.0 - 1.0
return tensor
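# Usage sketch for the helper above (hypothetical image and sizes, for
# illustration only): a 515x770 input is rounded down to 512x768 so both
# sides are multiples of 8, and normalize=True maps pixels to [-1, 1].
#
#   image = PIL.Image.open('photo.png')               # say, 515 x 770
#   tensor = image_resized_to_grid_as_tensor(image)
#   assert tensor.shape == (3, 768, 512)              # (C, H, W)
#   assert tensor.min() >= -1.0 and tensor.max() <= 1.0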
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -270,12 +290,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise_func=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = text_embeddings.dtype
latents_dtype = self.unet.dtype
batch_size = 1
num_images_per_prompt = 1
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image.convert('RGB'))
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
if init_image.dim() == 3:
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -297,6 +320,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
raise AssertionError("why was that an empty generator?")
return result
def inpaint_from_embeddings(
self,
init_image: torch.FloatTensor,
mask_image: torch.FloatTensor,
strength: float,
num_inference_steps: int,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*, callback: Callable[[PipelineIntermediateState], None] = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
run_id=None,
noise_func=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = self.unet.dtype
batch_size = 1
num_images_per_prompt = 1
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
if init_image.dim() == 3:
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
# 6. Prepare latent variables
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
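# timesteps[:1] is the noisiest step of the truncated schedule; the init
# latents are noised to this level so denoising starts partway through,
# exactly as img2img does.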
latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
result = None
for result in self.generate_from_embeddings(
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
run_id=run_id, **extra_step_kwargs):
if callback is not None and isinstance(result, PipelineIntermediateState):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
# because we have our own noise function
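#
# A minimal sketch of one way to complete this helper, assuming the standard
# diffusers 0.18215 VAE scaling factor and the scheduler.add_noise() API;
# not necessarily the exact body of this commit:
#
#   init_image = init_image.to(device=device, dtype=dtype)
#   init_latents = self.vae.encode(init_image).latent_dist.sample()
#   init_latents = 0.18215 * init_latents    # scale to the unet's latent space
#   noise = noise_func(init_latents)         # our own noise, for reproducible seeds
#   return self.scheduler.add_noise(init_latents, noise, timestep)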

ldm/invoke/generator/inpaint.py

@@ -1,22 +1,21 @@
'''
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''
from __future__ import annotations
import math
import PIL
import cv2 as cv
import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps, ImageChops
from einops import repeat
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.globals import Globals
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.util import debug_image
infill_methods: list[str] = list()
@@ -36,6 +35,9 @@ infill_methods.append('tile')
class Inpaint(Img2Img):
def __init__(self, model, precision):
self.inpaint_height = 0
self.inpaint_width = 0
self.enable_image_debugging = False
self.init_latent = None
self.pil_image = None
self.pil_mask = None
@@ -123,13 +125,13 @@ class Inpaint(Img2Img):
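# Detect soft gradient regions: pixels that are neither fully black nor
# fully white in the mask come out as 255 here.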
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
# Detect hard edges
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
# Combine
npmask = npgradient + npedge
# Expand
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
npmask = cv2.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
new_mask = Image.fromarray(npmask)
@@ -139,15 +141,8 @@ class Inpaint(Img2Img):
return ImageOps.invert(new_mask)
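# Pieced together, mask_edge reads roughly as follows (a sketch reconstructed
# from the fragments above; the edge_blur handling is an assumption):
#
#   def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
#       npimg = np.asarray(mask, dtype=np.uint8)
#       npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
#       npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
#       npmask = cv2.dilate(npgradient + npedge, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))
#       new_mask = Image.fromarray(npmask)
#       if edge_blur > 0:
#           new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
#       return ImageOps.invert(new_mask)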
def seam_paint(self,
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,strength,
noise,
step_callback
) -> Image.Image:
def seam_paint(self, im: Image.Image, seam_size: int, seam_blur: int, prompt, sampler, steps, cfg_scale, ddim_eta,
conditioning, strength, noise, infill_method, step_callback) -> Image.Image:
hard_mask = self.pil_image.split()[-1].copy()
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
@@ -165,7 +160,8 @@ class Inpaint(Img2Img):
seam_size = 0,
step_callback = step_callback,
inpaint_width = im.width,
inpaint_height = im.height
inpaint_height = im.height,
infill_method = infill_method
)
seam_noise = self.get_noise(im.width, im.height)
@@ -177,7 +173,10 @@ class Inpaint(Img2Img):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
conditioning,
init_image: PIL.Image.Image | torch.FloatTensor,
mask_image: PIL.Image.Image | torch.FloatTensor,
strength: float,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
@@ -223,7 +222,7 @@ class Inpaint(Img2Img):
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
# Create init tensor
init_image = self._image_to_tensor(init_filled.convert('RGB'))
init_image = image_resized_to_grid_as_tensor(init_filled.convert('RGB'))
if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image.copy()
@@ -244,85 +243,68 @@ class Inpaint(Img2Img):
),
resample=Image.Resampling.NEAREST
)
mask_image = self._image_to_tensor(mask_image,normalize=False)
mask_image = image_resized_to_grid_as_tensor(mask_image, normalize=False)
self.mask_blur_radius = mask_blur_radius
# klms samplers not supported yet, so ignore previous sampler
if isinstance(sampler,KSampler):
print(
">> Using recommended DDIM sampler for inpainting."
)
sampler = DDIMSampler(self.model, device=self.model.device)
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
# if isinstance(sampler,KSampler):
# print(
# ">> Using recommended DDIM sampler for inpainting."
# )
# sampler = DDIMSampler(self.model, device=self.model.device)
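# Broadcast the single-channel mask to the latent shape the unet sees:
# take one channel, repeat it across the 4 latent channels, and restore
# the batch dimension -> (1, 4, H, W).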
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
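# strength picks how far into the schedule to start: e.g. strength=0.75
# with 50 steps gives t_enc = int(0.75 * 50) = 37, so the init latent is
# noised to that depth and then denoised back.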
t_enc = int(strength * steps)
# todo: support cross-attention control
uc, c, _ = conditioning
print(f">> target t_enc is {t_enc} steps")
@torch.no_grad()
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
def make_image(x_T):
# FIXME: some of this z_enc and inpaint_replace stuff was probably important
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
self.init_latent,
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
# z_enc = sampler.stochastic_encode(
# self.init_latent,
# torch.tensor([t_enc]).to(self.model.device),
# noise=x_T
# )
#
# # to replace masked area with latent noise, weighted by inpaint_replace strength
# if inpaint_replace > 0.0:
# print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
# l_noise = self.get_noise(kwargs['width'],kwargs['height'])
# inverted_mask = 1.0-mask_image # there will be 1s where the mask is
# masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
# z_enc = z_enc * mask_image + masked_region
pipeline_output = pipeline.inpaint_from_embeddings(
init_image=init_image,
mask_image=mask_image,
strength=strength,
num_inference_steps=steps,
text_embeddings=c,
unconditioned_embeddings=uc,
guidance_scale=cfg_scale,
noise_func=self.get_noise_like,
callback=step_callback,
)
# to replace masked area with latent noise, weighted by inpaint_replace strength
if inpaint_replace > 0.0:
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
z_enc = z_enc * mask_image + masked_region
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
mask = mask_image,
init_latent = self.init_latent
)
result = self.sample_to_image(samples)
result = pipeline.numpy_to_pil(pipeline_output.images)[0]
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
old_image = self.pil_image or init_image
old_mask = self.pil_mask or mask_image
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T,
step_callback)
result = self.seam_paint(result, seam_size, seam_blur, prompt, sampler, seam_steps, cfg_scale, ddim_eta,
conditioning, seam_strength, x_T, infill_method, step_callback)
# Restore original settings
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,