PLMS works, bugs quashed

- The PLMS sampler now works with the custom inpainting model
- Quashed a bug that caused generation on normal models to fail (oops!)
- Non-square images can now be generated with the custom inpainting model

Credits for advice and assistance during porting:

@any-winter-4079 (http://github.com/any-winter-4079)
@db3000 (Danny Beer http://github.com/db3000)
Author: Lincoln Stein
Date:   2022-10-25 11:42:30 -04:00
Parent: b101be041b
Commit: e33971fe2c
5 changed files with 33 additions and 28 deletions

@@ -12,22 +12,6 @@ from ldm.modules.diffusionmodules.util import (
     extract_into_tensor,
 )
 
-def make_cond_in(uncond, cond):
-    if isinstance(cond, dict):
-        assert isinstance(uncond, dict)
-        cond_in = dict()
-        for k in cond:
-            if isinstance(cond[k], list):
-                cond_in[k] = [
-                    torch.cat([uncond[k][i], cond[k][i]])
-                    for i in range(len(cond[k]))
-                ]
-            else:
-                cond_in[k] = torch.cat([uncond[k], cond[k]])
-    else:
-        cond_in = torch.cat([uncond, cond])
-    return cond_in
-
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
         return result
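
For context: the helper deleted above builds the batched conditioning for classifier-free guidance, with the dict branch covering the extra conditioning tensors that the custom inpainting model supplies. The later hunks call it as self.sampler.make_cond_in(...) and self.make_cond_in(...), so the function has evidently moved onto the sampler. A minimal sketch of the relocated method, assuming it now lives on the shared Sampler base class (that placement is inferred from this diff, not shown in it):

    import torch

    class Sampler:
        # Sketch: make_cond_in moved from a module-level function to a
        # sampler method, so denoiser wrappers can reach it through a
        # sampler reference.
        def make_cond_in(self, uncond, cond):
            if isinstance(cond, dict):
                # Dict-style conditioning (e.g. from the custom
                # inpainting model): concatenate entries pairwise.
                assert isinstance(uncond, dict)
                cond_in = dict()
                for k in cond:
                    if isinstance(cond[k], list):
                        cond_in[k] = [
                            torch.cat([uncond[k][i], cond[k][i]])
                            for i in range(len(cond[k]))
                        ]
                    else:
                        cond_in[k] = torch.cat([uncond[k], cond[k]])
            else:
                # Plain tensor conditioning: stack the unconditional
                # and conditional halves into one batch.
                cond_in = torch.cat([uncond, cond])
            return cond_in
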
@@ -43,9 +27,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
 
 
 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, sampler, threshold = 0, warmup = 0):
         super().__init__()
-        self.inner_model = model
+        self.inner_model = sampler.model
+        self.sampler = sampler
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
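
Design note: handing CFGDenoiser the whole sampler, rather than a bare model, is what lets forward() reach make_cond_in after its move onto the sampler; the wrapper still keeps inner_model for the actual denoising call. A tiny sketch of the resulting invariants (the standalone setup is hypothetical; attribute names are from the hunk above):

    # `sampler` is assumed to be any Sampler subclass wrapping a model.
    model_wrap_cfg = CFGDenoiser(sampler, threshold=5, warmup=10)
    assert model_wrap_cfg.inner_model is sampler.model  # denoising target
    assert model_wrap_cfg.sampler is sampler            # route to make_cond_in
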
@@ -53,7 +38,7 @@ class CFGDenoiser(nn.Module):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = make_cond_in(uncond,cond)
+        cond_in = self.sampler.make_cond_in(uncond,cond)
         uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
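
For readers new to the pattern, this forward pass is standard classifier-free guidance: latents and sigmas are doubled so a single model call serves both branches, then the stacked prediction is split and recombined with the guidance scale. A self-contained sketch of the plain-tensor case (function and variable names are illustrative, not this repository's API; the final combine line is the conventional formula rather than the thresholded variant used here):

    import torch

    def cfg_denoise(inner_model, x, sigma, uncond, cond, cond_scale):
        # One batched call covers the unconditional and conditional branches.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])  # plain-tensor case of make_cond_in
        e_uncond, e_cond = inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        # Push the unconditional prediction toward the conditional one.
        return e_uncond + (e_cond - e_uncond) * cond_scale
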
@@ -80,7 +65,7 @@ class KSampler(Sampler):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = make_cond_in(uncond, cond)
+        cond_in = self.make_cond_in(uncond, cond)
         uncond, cond = self.inner_model(
             x_in, sigma_in, cond=cond_in
         ).chunk(2)
@@ -209,7 +194,7 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg = CFGDenoiser(self, threshold=threshold, warmup=max(0.8*S,S-10))
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,