add support for runwayML custom inpainting model

This is still a work in progress but seems functional. It supports
inpainting, txt2img and img2img on the ddim and k* samplers (plms
still needs work, but I know what to do).

To test this, get the file `sd-v1-5-inpainting.ckpt` from
https://huggingface.co/runwayml/stable-diffusion-inpainting and place it
at `models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt`.

Launch invoke.py with `--model inpainting-1.5` and proceed as usual.
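
For example, from the top of the source tree (the script location shown
here is an assumption; use whatever path your install puts invoke.py at):

   python scripts/invoke.py --model inpainting-1.5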

Caveats:

1. The inpainting model takes about 800 MB more memory than the standard
   1.5 model. It will not work on 4 GB cards.

2. The inpainting model is temperamental. It wants you to describe the
   entire scene, not just the masked area you want replaced. So if you
   want to replace the parrot on a man's shoulder with a crow, the prompt
   "crow" may fail. Try "man with a crow on shoulder" instead. The
   symptom of a failed inpainting is that the masked area is simply
   erased and filled in with background.

3. This has not been thoroughly tested yet. Please report bugs.

Commit: b101be041b (parent aaf7a4f1d3)
Author: Lincoln Stein
Date:   2022-10-25 10:45:12 -04:00
3 changed files with 30 additions and 10 deletions


@@ -13,6 +13,13 @@ stable-diffusion-1.4:
     default: true
     width: 512
     height: 512
+inpainting-1.5:
+    description: runwayML tuned inpainting model v1.5
+    weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+    config: configs/stable-diffusion/v1-inpainting-inference.yaml
+#   vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    width: 512
+    height: 512
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt


@@ -4,6 +4,7 @@ import torch
 import numpy as np
 from einops import repeat
 from PIL import Image, ImageOps
+from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.generator.txt2img import Txt2Img
@@ -60,7 +61,7 @@ class Omnibus(Img2Img,Txt2Img):
            ) # move to latent space
            # create a completely black mask (1s)
-           mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
+           mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
            # and the masked image is just a copy of the original
            masked_image = init_image
            t_enc = int(strength * steps)
@@ -74,7 +75,8 @@ class Omnibus(Img2Img,Txt2Img):
        def make_image(x_T):
            with torch.no_grad():
-               with torch.autocast("cuda"):
+               scope = choose_autocast(self.precision)
+               with scope(self.model.device.type):
                    batch = self.make_batch_sd(
                        init_image,
@@ -139,15 +141,10 @@ class Omnibus(Img2Img,Txt2Img):
            "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
            "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
        }
-       print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
-       print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
-       print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
        return batch

    def get_noise(self, width:int, height:int):
        if self.init_latent:
-           print('DEBUG: returning Img2Img.getnoise()')
            return super(Img2Img,self).get_noise(width,height)
        else:
-           print('DEBUG: returning Txt2Img.getnoise()')
            return super(Txt2Img,self).get_noise(width,height)
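
For reference, the mask_image change above fixes the degenerate "no user
mask" case (running plain txt2img/img2img through the inpainting model):
the old line indexed .width/.height on a tensor and used three channels,
while the new one builds a single-channel (1, 1, H, W) tensor of ones and
treats the unmasked input itself as the "masked image". A minimal sketch
of the resulting conditioning batch, with illustrative names and shapes
rather than the exact helper from the diff:

    import torch
    from einops import repeat

    def sketch_inpainting_batch(init_image: torch.Tensor, num_samples: int = 1) -> dict:
        # init_image: a (1, C, H, W) tensor. With no user-supplied mask we
        # "mask nothing": a single-channel mask of ones, and a masked image
        # identical to the input.
        mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3],
                                device=init_image.device)
        masked_image = init_image

        # Each tensor is repeated across the requested sample count,
        # mirroring the repeat(...) calls shown in the diff.
        return {
            "image":        repeat(init_image,   "1 ... -> n ...", n=num_samples),
            "mask":         repeat(mask_image,   "1 ... -> n ...", n=num_samples),
            "masked_image": repeat(masked_image, "1 ... -> n ...", n=num_samples),
        }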


@@ -12,6 +12,22 @@ from ldm.modules.diffusionmodules.util import (
     extract_into_tensor,
 )

+def make_cond_in(uncond, cond):
+    if isinstance(cond, dict):
+        assert isinstance(uncond, dict)
+        cond_in = dict()
+        for k in cond:
+            if isinstance(cond[k], list):
+                cond_in[k] = [
+                    torch.cat([uncond[k][i], cond[k][i]])
+                    for i in range(len(cond[k]))
+                ]
+            else:
+                cond_in[k] = torch.cat([uncond[k], cond[k]])
+    else:
+        cond_in = torch.cat([uncond, cond])
+    return cond_in
+
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
         return result
@@ -37,7 +53,7 @@ class CFGDenoiser(nn.Module):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
+        cond_in = make_cond_in(uncond,cond)
         uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -64,13 +80,12 @@ class KSampler(Sampler):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
+        cond_in = make_cond_in(uncond, cond)
         uncond, cond = self.inner_model(
             x_in, sigma_in, cond=cond_in
         ).chunk(2)
         return uncond + (cond - uncond) * cond_scale
-

     def make_schedule(
         self,
         ddim_num_steps,
@@ -283,3 +298,4 @@ class KSampler(Sampler):
     def conditioning_key(self)->str:
         return self.model.inner_model.model.conditioning_key
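
For reference, make_cond_in is needed because the inpainting model's
conditioning is no longer a bare tensor: with hybrid conditioning it
arrives as a dict of tensor lists (the c_concat / c_crossattn keys and
shapes below are illustrative assumptions, not taken from the diff), so
classifier-free guidance has to concatenate the unconditional and
conditional halves key by key. A small self-contained sketch of how the
helper behaves:

    import torch

    # make_cond_in as added in the diff above, repeated so the sketch runs standalone.
    def make_cond_in(uncond, cond):
        if isinstance(cond, dict):
            assert isinstance(uncond, dict)
            cond_in = dict()
            for k in cond:
                if isinstance(cond[k], list):
                    cond_in[k] = [torch.cat([uncond[k][i], cond[k][i]])
                                  for i in range(len(cond[k]))]
                else:
                    cond_in[k] = torch.cat([uncond[k], cond[k]])
        else:
            cond_in = torch.cat([uncond, cond])
        return cond_in

    # Plain txt2img conditioning: bare tensors, the old torch.cat behaviour.
    uncond = torch.zeros(1, 77, 768)
    cond   = torch.ones(1, 77, 768)
    assert make_cond_in(uncond, cond).shape == (2, 77, 768)

    # Inpainting-style hybrid conditioning: dicts of tensor lists
    # (keys and shapes are made up for illustration).
    uncond_d = {"c_concat":    [torch.zeros(1, 5, 64, 64)],
                "c_crossattn": [torch.zeros(1, 77, 768)]}
    cond_d   = {"c_concat":    [torch.ones(1, 5, 64, 64)],
                "c_crossattn": [torch.ones(1, 77, 768)]}
    batched = make_cond_in(uncond_d, cond_d)
    assert batched["c_concat"][0].shape == (2, 5, 64, 64)

Because the else branch falls back to the original torch.cat call, the
non-inpainting models keep their existing behaviour, which is why both
the CFGDenoiser and KSampler forward passes can share the helper.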