Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30)

commit 875312080d (parent ff42027a00)

diffusers support for the inpainting model
ldm/invoke/generator/diffusers_pipeline.py

@@ -9,10 +9,10 @@ import PIL.Image
 import einops
 import torch
 import torchvision.transforms as T
-from diffusers.models import attention

 from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention

+from diffusers.models import attention
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
 attention.CrossAttention = InvokeAICrossAttention
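Note (not part of the commit): a minimal sketch of the monkeypatch mechanics above, assuming a diffusers version contemporary with this branch, where `CrossAttention` still lives at `diffusers.models.attention`. The handle kept for restoring, and the `PatchedCrossAttention` stand-in for `InvokeAICrossAttention`, are illustrative only:

```python
from diffusers.models import attention

_original_cross_attention = attention.CrossAttention  # keep a handle so the patch can be undone

class PatchedCrossAttention(_original_cross_attention):  # hypothetical stand-in for InvokeAICrossAttention
    """Behaves like CrossAttention; a real patch would override forward()."""

attention.CrossAttention = PatchedCrossAttention      # modules constructed after this use the patch
attention.CrossAttention = _original_cross_attention  # restore when finished
```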
@@ -23,6 +23,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -49,6 +50,21 @@ _default_personalization_config_params = dict(
 )


+@dataclass
+class AddsMaskLatents:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
+        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
+        return self.forward(model_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
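Note (not part of the commit): a self-contained sketch of the tensor shapes `AddsMaskLatents` produces. The inpainting UNet expects 9 input channels: 4 denoising latents, 1 mask channel, and 4 VAE-encoded masked-image latents:

```python
import einops
import torch

latents = torch.randn(2, 4, 64, 64)        # denoising latents, batch of 2
mask = torch.ones(1, 1, 64, 64)            # single-channel inpaint mask
mask_latents = torch.randn(1, 4, 64, 64)   # VAE-encoded masked init image

batch_size = latents.size(0)
mask = einops.repeat(mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')

assert model_input.shape == (2, 9, 64, 64)  # 4 + 1 + 4 channels along dim 1
```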
@@ -57,7 +73,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
     w, h = image.size
-    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
+    w, h = map(lambda x: x - x % multiple_of, (w, h))  # resize to integer multiple of 8
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
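Note (not part of the commit): a quick worked example of the rounding the changed line performs — each dimension is floored to the nearest multiple, so odd sizes shrink slightly instead of breaking the 8x-downsampling VAE:

```python
multiple_of = 8
w, h = 513, 769
w, h = map(lambda x: x - x % multiple_of, (w, h))
assert (w, h) == (512, 768)
```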
@@ -68,6 +84,10 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     return tensor


+def is_inpainting_model(unet: UNet2DConditionModel):
+    return unet.conv_in.in_channels == 9
+
+
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
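Note (not part of the commit): the 9-channel test reflects how inpainting checkpoints are wired. A sketch of observing it directly, assuming the runwayml/stable-diffusion-inpainting weights:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-inpainting", subfolder="unet"
)
print(unet.conv_in.in_channels)  # 9 for the inpainting UNet; plain SD UNets report 4
```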
@@ -314,7 +334,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)

         result = None
         for result in self.generate_from_embeddings(
@@ -331,7 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def inpaint_from_embeddings(
             self,
             init_image: torch.FloatTensor,
-            mask_image: torch.FloatTensor,
+            mask: torch.FloatTensor,
             strength: float,
             num_inference_steps: int,
             text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
@@ -349,8 +369,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

+        init_image = init_image.to(device=device, dtype=latents_dtype)
+
         if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
+            init_image = init_image.unsqueeze(0)

         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
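Note (not part of the commit): the removed `einops.rearrange` and the new `unsqueeze(0)` are equivalent ways of adding a batch dimension:

```python
import einops
import torch

t = torch.randn(3, 64, 64)  # c h w
assert torch.equal(einops.rearrange(t, 'c h w -> 1 c h w'), t.unsqueeze(0))
```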
@@ -358,22 +380,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)

+        if is_inpainting_model(self.unet):
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(0)
+            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
+                .to(device=device, dtype=latents_dtype)
+
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskLatents(self._unet_forward, mask, init_image_latents)
+        else:
+            # FIXME: need to add guidance that applies mask
+            pass
+
         result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+        try:
+            for result in self.generate_from_embeddings(
+                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                    extra_conditioning_info=extra_conditioning_info,
+                    timesteps=timesteps,
+                    run_id=run_id, **extra_step_kwargs):
+                if callback is not None and isinstance(result, PipelineIntermediateState):
+                    callback(result)
+            if result is None:
+                raise AssertionError("why was that an empty generator?")
+            return result
+        finally:
+            self.invokeai_diffuser.model_forward_callback = self._unet_forward

-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
         # because we have our own noise function
         init_image = init_image.to(device=device, dtype=dtype)
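Note (not part of the commit): the mask arrives at image resolution but must match the latents, which are 8x smaller for Stable Diffusion; a sketch of the `tv_resize` step in isolation:

```python
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import resize as tv_resize

latents = torch.randn(1, 4, 64, 64)   # latents for a 512x512 image
mask = torch.ones(1, 1, 512, 512)

mask = tv_resize(mask, list(latents.shape[-2:]), T.InterpolationMode.BILINEAR)
assert mask.shape == (1, 1, 64, 64)
```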
@@ -383,8 +421,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_latents = 0.18215 * init_latents

         noise = noise_func(init_latents)

-        return self.scheduler.add_noise(init_latents, noise, timestep)
+        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        return noised_latents, init_latents

     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
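Note (not part of the commit): why the method now returns two values — `scheduler.add_noise` yields the partially-noised starting latents for img2img, while the clean `init_latents` double as the `mask_latents` input of `AddsMaskLatents`. A sketch with a default-configured scheduler standing in for the pipeline's:

```python
import torch
from diffusers.schedulers import DDIMScheduler

scheduler = DDIMScheduler()
init_latents = 0.18215 * torch.randn(1, 4, 64, 64)  # 0.18215 is SD's VAE scaling factor
noise = torch.randn_like(init_latents)
timestep = torch.tensor([800])                      # high-noise end of a 1000-step schedule

noised_latents = scheduler.add_noise(init_latents, noise, timestep)
# noised_latents seeds the denoising loop; init_latents feeds the mask branch
```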
ldm/invoke/generator/inpaint.py

@@ -10,9 +10,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image, ImageFilter, ImageOps, ImageChops
-from einops import repeat

-from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.globals import Globals
@@ -154,7 +152,7 @@ class Inpaint(Img2Img):
            ddim_eta,
            conditioning,
            init_image = im.copy().convert('RGBA'),
-           mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
+           mask_image = mask,
            strength = strength,
            mask_blur_radius = 0,
            seam_size = 0,
@@ -228,7 +226,11 @@ class Inpaint(Img2Img):
             self.pil_mask = mask_image.copy()
             debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)

-            mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
+            init_alpha = self.pil_image.getchannel("A")
+            if mask_image.mode != "L":
+                # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+                mask_image = mask_image.convert("L")
+            mask_image = ImageChops.multiply(mask_image, init_alpha)
             self.pil_mask = mask_image

             # Resize if requested for inpainting
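Note (not part of the commit): a self-contained sketch of the new alpha handling above — the mask is forced to single-channel "L" mode, then multiplied by the init image's alpha so the mask is clipped wherever the image is transparent:

```python
from PIL import Image, ImageChops

pil_image = Image.new("RGBA", (64, 64), (255, 0, 0, 255))  # fully opaque init image
mask_image = Image.new("RGB", (64, 64), (255, 255, 255))   # masks arrive as RGB in practice

init_alpha = pil_image.getchannel("A")        # single-channel alpha plane
if mask_image.mode != "L":
    mask_image = mask_image.convert("L")      # ImageChops.multiply needs matching modes
mask_image = ImageChops.multiply(mask_image, init_alpha)
```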
@@ -236,57 +238,32 @@ class Inpaint(Img2Img):
             mask_image = mask_image.resize((inpaint_width, inpaint_height))

             debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
-            mask_image = mask_image.resize(
-                (
-                    mask_image.width // downsampling,
-                    mask_image.height // downsampling
-                ),
-                resample=Image.Resampling.NEAREST
-            )
-            mask_image = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+            mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        else:
+            mask: torch.FloatTensor = mask_image

         self.mask_blur_radius = mask_blur_radius

-        # klms samplers not supported yet, so ignore previous sampler
-        # if isinstance(sampler,KSampler):
-        #     print(
-        #         ">> Using recommended DDIM sampler for inpainting."
-        #     )
-        #     sampler = DDIMSampler(self.model, device=self.model.device)
-
-        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
-        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
-
         t_enc = int(strength * steps)
         # todo: support cross-attention control
         uc, c, _ = conditioning

         print(f">> target t_enc is {t_enc} steps")

         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
         pipeline.scheduler = sampler

         def make_image(x_T):
+            # FIXME: some of this z_enc and inpaint_replace stuff was probably important
             # encode (scaled latent)
             # z_enc = sampler.stochastic_encode(
             #     self.init_latent,
             #     torch.tensor([t_enc]).to(self.model.device),
             #     noise=x_T
             # )
             #
             # # to replace masked area with latent noise, weighted by inpaint_replace strength
             # if inpaint_replace > 0.0:
             #     print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
             #     l_noise = self.get_noise(kwargs['width'],kwargs['height'])
-            #     inverted_mask = 1.0-mask_image  # there will be 1s where the mask is
+            #     inverted_mask = 1.0-mask  # there will be 1s where the mask is
             #     masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
-            #     z_enc = z_enc * mask_image + masked_region
+            #     z_enc = z_enc * mask + masked_region

             pipeline_output = pipeline.inpaint_from_embeddings(
                 init_image=init_image,
-                mask_image=mask_image,
+                mask=1 - mask,  # expects white means "paint here."
                 strength=strength,
                 num_inference_steps=steps,
                 text_embeddings=c,
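Note (not part of the commit): on the polarity flip at the call site — per the inline comment, the pipeline treats white (1.0) as "paint here", while the mask tensor built above uses the opposite convention, hence `mask=1 - mask`:

```python
import torch

mask = torch.tensor([[0.0, 1.0]])  # 0 = repaint, 1 = keep (convention of the mask built above)
pipeline_mask = 1 - mask           # 1 = repaint, 0 = keep (what inpaint_from_embeddings expects)
assert pipeline_mask.tolist() == [[1.0, 0.0]]
```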