resolve conflicts between PR #1108 and #1243

This commit is contained in:
Lincoln Stein
2022-10-26 15:37:24 -04:00
17 changed files with 444 additions and 38 deletions

View File

@ -23,9 +23,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0):
def __init__(self, sampler, threshold = 0, warmup = 0):
super().__init__()
self.inner_model = model
self.inner_model = sampler.model
self.sampler = sampler
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
@ -43,10 +44,14 @@ class CFGDenoiser(nn.Module):
def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
# apply threshold
if isinstance(cond,dict): # hybrid model
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = self.sampler.make_cond_in(uncond,cond)
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
next_x = uncond + (cond - uncond) * cond_scale
else: # cross attention model
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
@ -56,8 +61,6 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs):
denoiser = K.external.CompVisDenoiser(model)
@ -286,3 +289,6 @@ class KSampler(Sampler):
'''
return self.model.inner_model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.inner_model.model.conditioning_key