mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
for k* samplers, estimate step_index from sigma
This commit is contained in:
parent
7d677a63b8
commit
04d93f0445
@ -80,19 +80,17 @@ class CrossAttentionControl:
|
||||
TOKENS = 2
|
||||
|
||||
@classmethod
|
||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index:int=None)\
|
||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
|
||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param step_index: The step index (counts upwards from 0), or None if unknown.
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if step_index is None:
|
||||
if percent_through is None:
|
||||
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
|
||||
|
||||
opts = context.arguments.edit_options
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
percent_through = float(step_index)/float(context.step_count)
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
to_control.append(cls.CrossAttentionType.SELF)
|
||||
|
@ -42,9 +42,9 @@ class CFGDenoiser(nn.Module):
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, step_index):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index)
|
||||
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
|
||||
# apply threshold
|
||||
if self.warmup < self.warmup_max:
|
||||
|
@ -60,9 +60,8 @@ class PLMSSampler(Sampler):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
# step_index is expected to count up while index counts down
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
# note that step_index == 0 is evaluated twice with different x
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
|
@ -57,7 +57,8 @@ class InvokeAIDiffuserComponent:
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: torch.Tensor, conditioning: torch.Tensor,
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: int=None):
|
||||
step_index: int=None
|
||||
):
|
||||
"""
|
||||
:param x: Current latents
|
||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
||||
@ -72,7 +73,17 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_control_types_to_do = []
|
||||
|
||||
if self.cross_attention_control_context is not None:
|
||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index)
|
||||
if step_index is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
else:
|
||||
# find the current sigma in the sigma sequence
|
||||
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
|
||||
# flip because sigmas[0] is for the fully denoised image
|
||||
# percent_through must be <1
|
||||
percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
|
||||
print('estimated percent_through', percent_through, 'from sigma', sigma)
|
||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
||||
|
||||
if len(cross_attention_control_types_to_do)==0:
|
||||
print('step', step_index, ': not doing cross attention control')
|
||||
|
Loading…
Reference in New Issue
Block a user