From e33971fe2cb046d08548366bc775a4091b0739ea Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 25 Oct 2022 11:42:30 -0400 Subject: [PATCH] plms works, bugs quashed - The plms sampler now works with custom inpainting model - Quashed bug that was causing generation on normal models to fail (oops!) - Can now generate non-square images with custom inpainting model Credits for advice and assistance during porting: @any-winter-4079 (http://github.com/any-winter-4079) @db3000 (Danny Beer http://github.com/db3000) --- ldm/invoke/generator/base.py | 7 +++---- ldm/invoke/generator/omnibus.py | 4 ++-- ldm/models/diffusion/ksampler.py | 27 ++++++--------------------- ldm/models/diffusion/plms.py | 2 +- ldm/models/diffusion/sampler.py | 21 +++++++++++++++++++++ 5 files changed, 33 insertions(+), 28 deletions(-) diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index c70924449b..03f066323c 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -60,11 +60,10 @@ class Generator(): first_seed = seed seed, initial_noise = self.generate_initial_noise(seed, width, height) - scope = (scope(self.model.device.type), self.model.ema_scope()) if sampler.conditioning_key() not in ('hybrid','concat') else scope(self.model.device.type) - - with scope: + # There used to be an additional self.model.ema_scope() here, but it breaks + # the inpaint-1.5 model. Not sure what it did.... ? + with scope(self.model.device.type): for n in trange(iterations, desc='Generating'): - print('DEBUG: in iterations loop() called') x_T = None if self.variation_amount > 0: seed_everything(seed) diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py index 99fe046654..c8de01addb 100644 --- a/ldm/invoke/generator/omnibus.py +++ b/ldm/invoke/generator/omnibus.py @@ -67,8 +67,8 @@ class Omnibus(Img2Img,Txt2Img): t_enc = int(strength * steps) else: # txt2img - init_image = torch.zeros(1, 3, width, height, device=self.model.device) - mask_image = torch.ones(1, 1, width, height, device=self.model.device) + init_image = torch.zeros(1, 3, height, width, device=self.model.device) + mask_image = torch.ones(1, 1, height, width, device=self.model.device) masked_image = init_image model = self.model diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 59a3bebe4d..5f223cdf46 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -12,22 +12,6 @@ from ldm.modules.diffusionmodules.util import ( extract_into_tensor, ) -def make_cond_in(uncond, cond): - if isinstance(cond, dict): - assert isinstance(uncond, dict) - cond_in = dict() - for k in cond: - if isinstance(cond[k], list): - cond_in[k] = [ - torch.cat([uncond[k][i], cond[k][i]]) - for i in range(len(cond[k])) - ] - else: - cond_in[k] = torch.cat([uncond[k], cond[k]]) - else: - cond_in = torch.cat([uncond, cond]) - return cond_in - def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: return result @@ -43,9 +27,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0, warmup = 0): + def __init__(self, sampler, threshold = 0, warmup = 0): super().__init__() - self.inner_model = model + self.inner_model = sampler.model + self.sampler = sampler self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) @@ -53,7 +38,7 @@ class CFGDenoiser(nn.Module): def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) - cond_in = make_cond_in(uncond,cond) + cond_in = self.sampler.make_cond_in(uncond,cond) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) @@ -80,7 +65,7 @@ class KSampler(Sampler): def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) - cond_in = make_cond_in(uncond, cond) + cond_in = self.make_cond_in(uncond, cond) uncond, cond = self.inner_model( x_in, sigma_in, cond=cond_in ).chunk(2) @@ -209,7 +194,7 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg = CFGDenoiser(self, threshold=threshold, warmup=max(0.8*S,S-10)) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 9e722eb932..4261f549d2 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -45,7 +45,7 @@ class PLMSSampler(Sampler): else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) + c_in = self.make_cond_in(unconditional_conditioning, c) e_t_uncond, e_t = self.model.apply_model( x_in, t_in, c_in ).chunk(2) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 9e57bc25d4..fd7ba106c1 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -439,3 +439,24 @@ class Sampler(object): def conditioning_key(self)->str: return self.model.model.conditioning_key + + def make_cond_in(self, uncond, cond): + ''' + This handles the choice between a conditional conditioning + that is a tensor (used by cross attention) vs one that is a dict + used by 'hybrid' + ''' + if isinstance(cond, dict): + assert isinstance(uncond, dict) + cond_in = dict() + for k in cond: + if isinstance(cond[k], list): + cond_in[k] = [ + torch.cat([uncond[k][i], cond[k][i]]) + for i in range(len(cond[k])) + ] + else: + cond_in[k] = torch.cat([uncond[k], cond[k]]) + else: + cond_in = torch.cat([uncond, cond]) + return cond_in