Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
commit e33971fe2c (parent b101be041b)

plms works, bugs quashed

- The plms sampler now works with the custom inpainting model
- Quashed a bug that was causing generation on normal models to fail (oops!)
- Can now generate non-square images with the custom inpainting model

Credits for advice and assistance during porting:
@any-winter-4079 (http://github.com/any-winter-4079)
@db3000 (Danny Beer, http://github.com/db3000)
@@ -60,11 +60,10 @@ class Generator():
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        scope = (scope(self.model.device.type), self.model.ema_scope()) if sampler.conditioning_key() not in ('hybrid','concat') else scope(self.model.device.type)
-        with scope:
+        # There used to be an additional self.model.ema_scope() here, but it breaks
+        # the inpaint-1.5 model. Not sure what it did.... ?
+        with scope(self.model.device.type):
             for n in trange(iterations, desc='Generating'):
-                print('DEBUG: in iterations loop() called')
                 x_T = None
                 if self.variation_amount > 0:
                     seed_everything(seed)
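A side note on the hunk above, not part of the commit: the removed line built a tuple of two context managers for non-inpainting models, but a tuple is not itself a context manager, so only the single-scope branch could ever have worked as a `with` target. If both scopes were wanted again, contextlib.ExitStack is the usual pattern; a hypothetical sketch, assuming `scope` is an autocast-style context-manager factory and `ema_scope()` the usual LDM EMA-weights context:

from contextlib import ExitStack

def generation_scope(model, sampler, scope):
    # Hypothetical helper, not from the commit: enter the device scope,
    # and additionally the EMA-weights scope for non-hybrid models.
    stack = ExitStack()
    stack.enter_context(scope(model.device.type))
    if sampler.conditioning_key() not in ('hybrid', 'concat'):
        # Skipped for the inpainting model, which the EMA weights break.
        stack.enter_context(model.ema_scope())
    return stack

# usage: with generation_scope(self.model, sampler, scope): ...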
@@ -67,8 +67,8 @@ class Omnibus(Img2Img,Txt2Img):
             t_enc = int(strength * steps)

         else: # txt2img
-            init_image = torch.zeros(1, 3, width, height, device=self.model.device)
-            mask_image = torch.ones(1, 1, width, height, device=self.model.device)
+            init_image = torch.zeros(1, 3, height, width, device=self.model.device)
+            mask_image = torch.ones(1, 1, height, width, device=self.model.device)
             masked_image = init_image

         model = self.model
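The swap from (width, height) to (height, width) is what enables the non-square generation mentioned in the commit message: PyTorch image tensors are NCHW, so the old argument order transposed the canvas whenever width != height, while square images masked the bug. A minimal illustration (device argument omitted):

import torch

# PyTorch image tensors are NCHW: (batch, channels, height, width).
height, width = 512, 768  # a non-square target
init_image = torch.zeros(1, 3, height, width)
mask_image = torch.ones(1, 1, height, width)
assert init_image.shape == (1, 3, 512, 768)
assert mask_image.shape == (1, 1, 512, 768)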
@@ -12,22 +12,6 @@ from ldm.modules.diffusionmodules.util import (
     extract_into_tensor,
 )

-def make_cond_in(uncond, cond):
-    if isinstance(cond, dict):
-        assert isinstance(uncond, dict)
-        cond_in = dict()
-        for k in cond:
-            if isinstance(cond[k], list):
-                cond_in[k] = [
-                    torch.cat([uncond[k][i], cond[k][i]])
-                    for i in range(len(cond[k]))
-                ]
-            else:
-                cond_in[k] = torch.cat([uncond[k], cond[k]])
-    else:
-        cond_in = torch.cat([uncond, cond])
-    return cond_in
-
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
         return result
@@ -43,9 +27,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):


 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, sampler, threshold = 0, warmup = 0):
         super().__init__()
-        self.inner_model = model
+        self.inner_model = sampler.model
+        self.sampler = sampler
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
@@ -53,7 +38,7 @@ class CFGDenoiser(nn.Module):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = make_cond_in(uncond,cond)
+        cond_in = self.sampler.make_cond_in(uncond,cond)
         uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
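For context on the pattern above: doubling x and sigma lets a single batched model call produce both the unconditional and conditional predictions. The guidance combination itself lies outside this hunk; a minimal sketch of the standard classifier-free-guidance formula, with `inner_model` standing in for the wrapped model and plain tensor conditioning assumed:

import torch

def cfg_denoise(inner_model, x, sigma, uncond, cond, cond_scale):
    # Run the unconditional and conditional halves in one batched call,
    # then split the two predictions back apart.
    x_in = torch.cat([x] * 2)
    sigma_in = torch.cat([sigma] * 2)
    cond_in = torch.cat([uncond, cond])  # tensor conditioning only
    uncond_out, cond_out = inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
    # Classifier-free guidance: push the prediction away from the
    # unconditional output, toward the conditional one.
    return uncond_out + (cond_out - uncond_out) * cond_scale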
@@ -80,7 +65,7 @@ class KSampler(Sampler):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = make_cond_in(uncond, cond)
+        cond_in = self.make_cond_in(uncond, cond)
         uncond, cond = self.inner_model(
             x_in, sigma_in, cond=cond_in
         ).chunk(2)
@@ -209,7 +194,7 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg = CFGDenoiser(self, threshold=threshold, warmup=max(0.8*S,S-10))
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
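The warmup argument feeds the ramp shown in CFGDenoiser.forward above; a worked example with assumed values:

# Values assumed for illustration; only the formulas come from the diff.
S = 50                             # sampling steps
warmup_max = max(0.8 * S, S - 10)  # 40.0
threshold = 5.0
for warmup in (4.0, 20.0, 40.0):   # self.warmup grows during sampling
    thresh = max(1, 1 + (threshold - 1) * (warmup / warmup_max))
    # thresh -> 1.4, 3.0, 5.0: the clamp threshold eases in over ~40 steps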
@@ -45,7 +45,7 @@ class PLMSSampler(Sampler):
         else:
             x_in = torch.cat([x] * 2)
             t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
+            c_in = self.make_cond_in(unconditional_conditioning, c)
             e_t_uncond, e_t = self.model.apply_model(
                 x_in, t_in, c_in
             ).chunk(2)
@@ -439,3 +439,24 @@ class Sampler(object):

     def conditioning_key(self)->str:
         return self.model.model.conditioning_key
+
+    def make_cond_in(self, uncond, cond):
+        '''
+        This handles the choice between a conditional conditioning
+        that is a tensor (used by cross attention) vs one that is a dict
+        used by 'hybrid'
+        '''
+        if isinstance(cond, dict):
+            assert isinstance(uncond, dict)
+            cond_in = dict()
+            for k in cond:
+                if isinstance(cond[k], list):
+                    cond_in[k] = [
+                        torch.cat([uncond[k][i], cond[k][i]])
+                        for i in range(len(cond[k]))
+                    ]
+                else:
+                    cond_in[k] = torch.cat([uncond[k], cond[k]])
+        else:
+            cond_in = torch.cat([uncond, cond])
+        return cond_in
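To make the new method concrete, here is a self-contained sketch that mirrors its logic as a free function and exercises both branches; the shapes and dict keys are illustrative (typical of latent-diffusion hybrid conditioning), not taken from the diff:

import torch

def make_cond_in(uncond, cond):
    # Standalone mirror of the new Sampler.make_cond_in, for illustration only.
    if isinstance(cond, dict):
        assert isinstance(uncond, dict)
        cond_in = dict()
        for k in cond:
            if isinstance(cond[k], list):
                cond_in[k] = [torch.cat([uncond[k][i], cond[k][i]])
                              for i in range(len(cond[k]))]
            else:
                cond_in[k] = torch.cat([uncond[k], cond[k]])
        return cond_in
    return torch.cat([uncond, cond])

# Cross-attention models pass plain tensors:
t = make_cond_in(torch.zeros(1, 77, 768), torch.ones(1, 77, 768))
assert t.shape == (2, 77, 768)

# 'hybrid' inpainting models pass dicts of per-key lists, concatenated key by key:
u = {'c_crossattn': [torch.zeros(1, 77, 768)], 'c_concat': [torch.zeros(1, 5, 64, 64)]}
c = {'c_crossattn': [torch.ones(1, 77, 768)], 'c_concat': [torch.ones(1, 5, 64, 64)]}
d = make_cond_in(u, c)
assert d['c_crossattn'][0].shape == (2, 77, 768)
assert d['c_concat'][0].shape == (2, 5, 64, 64)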