for k* samplers, estimate step_index from sigma

Damian at mba 2022-10-23 16:26:50 +02:00
parent 7d677a63b8
commit 04d93f0445
4 changed files with 19 additions and 11 deletions

View File

@@ -80,19 +80,17 @@ class CrossAttentionControl:
         TOKENS = 2
     @classmethod
-    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index:int=None)\
+    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
            -> list['CrossAttentionControl.CrossAttentionType']:
         """
         Should cross-attention control be applied on the given step?
-        :param step_index: The step index (counts upwards from 0), or None if unknown.
+        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
         :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
         """
-        if step_index is None:
+        if percent_through is None:
             return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
         opts = context.arguments.edit_options
-        # percent_through will never reach 1.0 (but this is intended)
-        percent_through = float(step_index)/float(context.step_count)
         to_control = []
         if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
             to_control.append(cls.CrossAttentionType.SELF)
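To illustrate the windowing logic in the hunk above: percent_through is compared against the s_start/s_end edit options to decide whether self-attention control is active at this point of sampling, and None (unknown progress) falls back to applying everything. A minimal standalone sketch, assuming the 0.0..<1.0 range from the docstring; the helper name and the example option values are made up, and the TOKENS window is elided:

def active_self_control(percent_through, edit_options) -> bool:
    # None means progress is unknown; the original falls back to applying all control types
    if percent_through is None:
        return True
    # active only while s_start <= percent_through < s_end
    return edit_options['s_start'] <= percent_through < edit_options['s_end']

# e.g. with a window of 0.0..0.3, step 5 of 50 (percent_through = 0.1) is controlled,
# while step 30 of 50 (percent_through = 0.6) is not
opts = {'s_start': 0.0, 's_end': 0.3}
print(active_self_control(5 / 50, opts), active_self_control(30 / 50, opts))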

View File

@@ -42,9 +42,9 @@ class CFGDenoiser(nn.Module):
         self.invokeai_diffuser.remove_cross_attention_control()
-    def forward(self, x, sigma, uncond, cond, cond_scale, step_index):
+    def forward(self, x, sigma, uncond, cond, cond_scale):
-        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index)
+        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
         # apply threshold
         if self.warmup < self.warmup_max:

View File

@@ -60,9 +60,8 @@ class PLMSSampler(Sampler):
             # damian0815 would like to know when/if this code path is used
             e_t = self.model.apply_model(x, t, c)
         else:
-            # step_index is expected to count up while index counts down
+            # step_index counts in the opposite direction to index
             step_index = step_count-(index+1)
-            # note that step_index == 0 is evaluated twice with different x
             e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
                                                            unconditional_conditioning, c,
                                                            unconditional_guidance_scale,
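The comment change above is only about counting direction: PLMS walks index downwards through the timesteps while step_index counts up from 0. A minimal sketch with an assumed step_count of 5 (the value is not from the commit):

step_count = 5
for index in reversed(range(step_count)):    # 4, 3, 2, 1, 0
    step_index = step_count - (index + 1)    # 0, 1, 2, 3, 4
    print(f'index={index} -> step_index={step_index}')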

View File

@@ -57,7 +57,8 @@ class InvokeAIDiffuserComponent:
     def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                           unconditioning: torch.Tensor, conditioning: torch.Tensor,
                           unconditional_guidance_scale: float,
-                          step_index: int=None):
+                          step_index: int=None
+                          ):
         """
         :param x: Current latents
         :param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -72,7 +73,17 @@
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index)
+            if step_index is not None:
+                # percent_through will never reach 1.0 (but this is intended)
+                percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
+            else:
+                # find the current sigma in the sigma sequence
+                sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
+                # flip because sigmas[0] is for the fully denoised image
+                # percent_through must be <1
+                percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
+                print('estimated percent_through', percent_through, 'from sigma', sigma)
+            cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
         if len(cross_attention_control_types_to_do)==0:
             print('step', step_index, ': not doing cross attention control')
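This last hunk carries the change named in the commit title: when a k* sampler calls do_diffusion_step it no longer supplies step_index, so progress is estimated by locating the current sigma in the sampler's sigma schedule and flipping the index. A standalone sketch of that estimate, using a made-up ascending schedule in place of self.model.sigmas (the real schedule comes from the wrapped k-diffusion model, with sigmas[0] corresponding to the fully denoised image):

import torch

# made-up ascending sigma schedule; sigmas[0] ~ fully denoised, sigmas[-1] ~ pure noise
sigmas = torch.tensor([0.0, 0.125, 0.25, 0.5, 2.0, 5.0, 15.0])

def estimate_percent_through(sigma: float, sigmas: torch.Tensor) -> float:
    # index of the largest schedule entry that is <= the current sigma
    sigma_index = torch.nonzero(sigmas <= sigma)[-1]
    # flip the direction: a large sigma (early, noisy step) maps near 0.0,
    # a small sigma (late step) maps close to 1.0, and the result stays < 1.0
    return 1.0 - float(sigma_index.item() + 1) / float(sigmas.shape[0])

print(estimate_percent_through(15.0, sigmas))   # 0.0   -> start of sampling
print(estimate_percent_through(0.125, sigmas))  # ~0.71 -> most of the way through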