mirror of https://github.com/invoke-ai/InvokeAI
Refactored single diffusion path; seems to be working for all samplers
@@ -34,18 +34,17 @@ class CFGDenoiser(nn.Module):
     def prepare_to_sample(self, t_enc, **kwargs):
 
-        edited_conditioning = kwargs.get('edited_conditioning', None)
+        structured_conditioning = kwargs.get('structured_conditioning', None)
 
-        if edited_conditioning is not None:
-            conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
-            self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
+        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
         else:
-            self.invokeai_diffuser.cleanup_cross_attention_control()
+            self.invokeai_diffuser.remove_cross_attention_control()
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
 
-        final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
+        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
 
         # apply threshold
         if self.warmup < self.warmup_max:
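For context: this hunk collapses the two loose kwargs (edited_conditioning, conditioning_edit_opcodes) into a single structured_conditioning object that itself knows whether cross-attention control is wanted. A minimal sketch of such a container, purely illustrative — the class name, fields, and property below are assumptions, not the actual InvokeAI type:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class StructuredConditioning:
        """Illustrative container only; the real InvokeAI type may differ."""
        # embeddings for the prompt being sampled
        conditioning: torch.Tensor
        # embeddings for an edited prompt plus token-alignment opcodes,
        # present only when prompt-to-prompt editing was requested
        edited_conditioning: Optional[torch.Tensor] = None
        edit_opcodes: Optional[list] = None

        @property
        def wants_cross_attention_control(self) -> bool:
            # control is only needed when an edited prompt exists
            return self.edited_conditioning is not None

Bundling the opcodes with the edited embeddings also removes the failure mode where a caller passes one kwarg but forgets the other.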
@@ -55,7 +54,7 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        return cfg_apply_threshold(final_next_x, thresh)
+        return cfg_apply_threshold(next_x, thresh)
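cfg_apply_threshold itself is not part of this diff. A hedged sketch of what such a latent-thresholding helper could look like, assuming a simple symmetric clamp (the real helper may rescale rather than hard-clamp):

    import torch

    def cfg_apply_threshold(result: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
        # sketch only: treat threshold <= 0 as "disabled", otherwise
        # clamp the denoised latents into [-threshold, threshold]
        if threshold <= 0.0:
            return result
        return result.clamp(-threshold, threshold)

Note that the guard above (`if thresh > self.threshold`) caps the effective thresh at self.threshold regardless of the warmup state.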
@@ -165,8 +164,7 @@ class KSampler(Sampler):
         log_every_t=100,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        edited_conditioning=None,
-        conditioning_edit_opcodes=None,
+        structured_conditioning=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
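With the new signature, a caller supplies one keyword where there used to be two. A hypothetical call site — names and values below are placeholders, not taken from the diff:

    # hypothetical invocation of the refactored sampler
    samples, intermediates = sampler.sample(
        S=30,                                     # number of denoising steps
        batch_size=1,
        shape=(4, 64, 64),                        # latent shape (C, H, W)
        conditioning=cond,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uncond,
        structured_conditioning=structured_cond,  # replaces the two removed kwargs
    )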
@@ -199,7 +197,7 @@ class KSampler(Sampler):
         x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
         model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
-        model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
+        model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
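The extra_args dict (truncated by this hunk) follows the k-diffusion calling convention: the sampler splats it into the wrapped model on every step, which is how cond/uncond reach CFGDenoiser.forward. A sketch of the downstream call, assuming an LMS schedule and a cond_scale key completing the dict:

    import k_diffusion as K

    extra_args = {
        'cond': conditioning,
        'uncond': unconditional_conditioning,
        'cond_scale': unconditional_guidance_scale,  # assumed remaining key
    }
    # each step effectively calls model_wrap_cfg(x, sigma, **extra_args)
    samples = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)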
@@ -226,8 +224,7 @@ class KSampler(Sampler):
         index,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        edited_conditioning=None,
-        conditioning_edit_opcodes=None,
+        structured_conditioning=None,
         **kwargs,
     ):
         if self.model_wrap is None:
@@ -253,7 +250,7 @@ class KSampler(Sampler):
         # so the actual formula for indexing into sigmas:
         # sigma_index = (steps-index)
         s_index = t_enc - index - 1
-        self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
+        self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
         img = K.sampling.__dict__[f'_{self.schedule}'](
             self.model_wrap,
             img,
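The s_index arithmetic above implements the reverse indexing described in the comment: step 0 of the loop should see the far end of the sigma schedule. A worked example:

    # with t_enc = 10 denoising steps:
    t_enc = 10
    for index in range(t_enc):
        s_index = t_enc - index - 1
        # index 0 -> s_index 9, index 9 -> s_index 0:
        # the sigma schedule is walked from the opposite end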