refactored single diffusion path seems to be working for all samplers

This commit is contained in:
Damian at mba
2022-10-19 19:57:20 +02:00
parent 147d39cb7c
commit 1ffd4a9e06
7 changed files with 57 additions and 52 deletions

View File

@ -34,18 +34,17 @@ class CFGDenoiser(nn.Module):
def prepare_to_sample(self, t_enc, **kwargs):
edited_conditioning = kwargs.get('edited_conditioning', None)
structured_conditioning = kwargs.get('structured_conditioning', None)
if edited_conditioning is not None:
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
self.invokeai_diffuser.remove_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale):
final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
# apply threshold
if self.warmup < self.warmup_max:
@ -55,7 +54,7 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(final_next_x, thresh)
return cfg_apply_threshold(next_x, thresh)
@ -165,8 +164,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
edited_conditioning=None,
conditioning_edit_opcodes=None,
structured_conditioning=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -199,7 +197,7 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@ -226,8 +224,7 @@ class KSampler(Sampler):
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
edited_conditioning=None,
conditioning_edit_opcodes=None,
structured_conditioning=None,
**kwargs,
):
if self.model_wrap is None:
@ -253,7 +250,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,