add support for runwayML custom inpainting model

This is still a work in progress but seems functional. It supports
inpainting, txt2img and img2img on the ddim and k* samplers (plms
still needs work, but I know what to do).

To test this, get the file `sd-v1-5-inpainting.ckpt' from
https://huggingface.co/runwayml/stable-diffusion-inpainting and place it
at `models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt`

Launch invoke.py with --model inpainting-1.5 and proceed as usual.

Caveats:

1. The inpainting model takes about 800 Mb more memory than the standard
   1.5 model. This model will not work on 4 GB cards.

2. The inpainting model is temperamental. It wants you to describe the
   entire scene and not just the masked area to replace. So if you want
   to replace the parrot on a man's shoulder with a crow, the prompt
   "crow" may fail. Try "man with a crow on shoulder" instead. The
   symptom of a failed inpainting is that the area will be erased and
   replaced with background.

3. This has not been tested well. Please report bugs.
This commit is contained in:
Lincoln Stein
2022-10-25 10:45:12 -04:00
parent aaf7a4f1d3
commit b101be041b
3 changed files with 30 additions and 10 deletions

View File

@ -12,6 +12,22 @@ from ldm.modules.diffusionmodules.util import (
extract_into_tensor,
)
def make_cond_in(uncond, cond):
if isinstance(cond, dict):
assert isinstance(uncond, dict)
cond_in = dict()
for k in cond:
if isinstance(cond[k], list):
cond_in[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_in[k] = torch.cat([uncond[k], cond[k]])
else:
cond_in = torch.cat([uncond, cond])
return cond_in
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
return result
@ -37,7 +53,7 @@ class CFGDenoiser(nn.Module):
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
cond_in = make_cond_in(uncond,cond)
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@ -64,13 +80,12 @@ class KSampler(Sampler):
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
cond_in = make_cond_in(uncond, cond)
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
def make_schedule(
self,
ddim_num_steps,
@ -283,3 +298,4 @@ class KSampler(Sampler):
def conditioning_key(self)->str:
return self.model.inner_model.model.conditioning_key