add support for runwayML custom inpainting model

This is still a work in progress but seems functional. It supports
inpainting, txt2img and img2img on the ddim and k* samplers (plms
still needs work, but I know what to do).

To test this, get the file `sd-v1-5-inpainting.ckpt` from
https://huggingface.co/runwayml/stable-diffusion-inpainting and place it
at `models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt`.
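
If you prefer to script the download, here is a minimal sketch using the
`huggingface_hub` package (the package, and running it from the repository
root, are assumptions; downloading the file manually from the page above
works just as well):

    # Sketch: fetch the checkpoint and copy it to where invoke.py expects it.
    # You may need to accept the model license on the Hugging Face page and
    # run `huggingface-cli login` first.
    import shutil
    from pathlib import Path
    from huggingface_hub import hf_hub_download

    src = hf_hub_download(
        repo_id="runwayml/stable-diffusion-inpainting",
        filename="sd-v1-5-inpainting.ckpt",
    )
    dest = Path("models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt")
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(src, dest)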

Launch `invoke.py` with `--model inpainting-1.5` and proceed as usual.

Caveats:

1. The inpainting model takes about 800 MB more memory than the standard
   1.5 model. This model will not work on 4 GB cards.

2. The inpainting model is temperamental. It wants you to describe the
   entire scene, not just the masked area you want to replace. So if you
   want to replace the parrot on a man's shoulder with a crow, the prompt
   "crow" may fail. Try "man with a crow on shoulder" instead. The
   symptom of a failed inpainting is that the masked area is erased and
   filled in with background.

3. This has not been tested well. Please report bugs.
Lincoln Stein 2022-10-25 10:45:12 -04:00
parent aaf7a4f1d3
commit b101be041b
3 changed files with 30 additions and 10 deletions


@@ -13,6 +13,13 @@ stable-diffusion-1.4:
     default: true
     width: 512
     height: 512
+inpainting-1.5:
+    description: runwayML tuned inpainting model v1.5
+    weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+    config: configs/stable-diffusion/v1-inpainting-inference.yaml
+    # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    width: 512
+    height: 512
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
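
(A quick way to sanity-check the new stanza is to load the models file with
plain PyYAML; the `configs/models.yaml` path below is an assumption for
illustration, so adjust it to wherever your model registry lives.)

    import yaml

    with open("configs/models.yaml") as f:   # path is an assumption
        models = yaml.safe_load(f)
    print(models["inpainting-1.5"]["weights"])
    # -> models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt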


@@ -4,6 +4,7 @@ import torch
 import numpy as np
 from einops import repeat
 from PIL import Image, ImageOps
+from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.generator.txt2img import Txt2Img
@@ -60,7 +61,7 @@ class Omnibus(Img2Img,Txt2Img):
                 ) # move to latent space
             # create a completely black mask (1s)
-            mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
+            mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
             # and the masked image is just a copy of the original
             masked_image = init_image
             t_enc = int(strength * steps)
@@ -74,7 +75,8 @@ class Omnibus(Img2Img,Txt2Img):
         def make_image(x_T):
             with torch.no_grad():
-                with torch.autocast("cuda"):
+                scope = choose_autocast(self.precision)
+                with scope(self.model.device.type):
                     batch = self.make_batch_sd(
                         init_image,
@@ -139,15 +141,10 @@ class Omnibus(Img2Img,Txt2Img):
             "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
             "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
         }
-        print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
-        print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
-        print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
         return batch

     def get_noise(self, width:int, height:int):
         if self.init_latent:
-            print('DEBUG: returning Img2Img.getnoise()')
             return super(Img2Img,self).get_noise(width,height)
         else:
-            print('DEBUG: returning Txt2Img.getnoise()')
             return super(Txt2Img,self).get_noise(width,height)
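
For reference, here is a minimal sketch of the batch that make_batch_sd()
builds for the txt2img/img2img path of the inpainting model (shapes and the
512x512 size are illustrative): an all-ones mask plus an unmasked copy of
the image tells the model to regenerate the whole frame.

    # Illustrative only: mirrors the image/mask/masked_image keys used above.
    import torch
    from einops import repeat

    num_samples  = 1
    init_image   = torch.randn(1, 3, 512, 512)   # stand-in for the scaled input image
    mask_image   = torch.ones(1, 1, 512, 512)    # all-ones mask, as in the diff above
    masked_image = init_image                    # nothing is hidden from the model

    batch = {
        "image":        repeat(init_image,   "1 ... -> n ...", n=num_samples),
        "mask":         repeat(mask_image,   "1 ... -> n ...", n=num_samples),
        "masked_image": repeat(masked_image, "1 ... -> n ...", n=num_samples),
    }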


@@ -12,6 +12,22 @@ from ldm.modules.diffusionmodules.util import (
     extract_into_tensor,
 )

+def make_cond_in(uncond, cond):
+    if isinstance(cond, dict):
+        assert isinstance(uncond, dict)
+        cond_in = dict()
+        for k in cond:
+            if isinstance(cond[k], list):
+                cond_in[k] = [
+                    torch.cat([uncond[k][i], cond[k][i]])
+                    for i in range(len(cond[k]))
+                ]
+            else:
+                cond_in[k] = torch.cat([uncond[k], cond[k]])
+    else:
+        cond_in = torch.cat([uncond, cond])
+    return cond_in
+
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
         return result
@@ -37,7 +53,7 @@ class CFGDenoiser(nn.Module):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
+        cond_in = make_cond_in(uncond,cond)
         uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -64,13 +80,12 @@ class KSampler(Sampler):
         def forward(self, x, sigma, uncond, cond, cond_scale):
             x_in = torch.cat([x] * 2)
             sigma_in = torch.cat([sigma] * 2)
-            cond_in = torch.cat([uncond, cond])
+            cond_in = make_cond_in(uncond, cond)
             uncond, cond = self.inner_model(
                 x_in, sigma_in, cond=cond_in
             ).chunk(2)
             return uncond + (cond - uncond) * cond_scale

     def make_schedule(
         self,
         ddim_num_steps,
@@ -283,3 +298,4 @@ class KSampler(Sampler):
     def conditioning_key(self)->str:
         return self.model.inner_model.model.conditioning_key
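
For context, make_cond_in() exists because the runwayML inpainting model uses
hybrid conditioning: instead of a single cross-attention tensor, the
conditioning passed to the sampler is a dict of tensor lists, so the
unconditional and conditional halves have to be concatenated key by key. A
minimal sketch of the shape of data it handles (key names and tensor shapes
are illustrative, not taken from this diff):

    # Illustrative only: dict-shaped conditioning of the kind make_cond_in() handles.
    # Assumes make_cond_in from the patch above is in scope.
    import torch

    cond   = {"c_concat":    [torch.randn(1, 5, 64, 64)],
              "c_crossattn": [torch.randn(1, 77, 768)]}
    uncond = {"c_concat":    [torch.randn(1, 5, 64, 64)],
              "c_crossattn": [torch.randn(1, 77, 768)]}

    cond_in = make_cond_in(uncond, cond)
    # each entry now has batch size 2: the unconditional and conditional halves
    assert cond_in["c_concat"][0].shape[0] == 2
    assert cond_in["c_crossattn"][0].shape[0] == 2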