mirror of https://github.com/invoke-ai/InvokeAI
for k* samplers, estimate step_index from sigma
parent 7d677a63b8
commit 04d93f0445
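The crux of the change: k* (k-diffusion) samplers never hand the model a step counter, so progress is instead estimated from the current noise level sigma. A minimal standalone sketch of that estimate — the helper name estimate_percent_through is hypothetical, and an ascending sigma table with sigmas[0] belonging to the fully denoised image is assumed, per the comment in the diff below:

import torch

def estimate_percent_through(sigmas: torch.Tensor, sigma: float) -> float:
    # Index of the largest tabulated sigma that is still <= the current sigma.
    # (As in the diff, this assumes sigma is not below every entry in the table.)
    sigma_index = torch.nonzero(sigmas <= sigma)[-1]
    # Flip: sigmas[0] is the fully denoised end, so a low index means we are
    # nearly done. The +1 keeps the result strictly below 1.0.
    return 1.0 - float(sigma_index.item() + 1) / float(sigmas.shape[0])

# An ascending 10-entry table standing in for self.model.sigmas:
sigmas = torch.linspace(0.03, 14.6, steps=10)
print(estimate_percent_through(sigmas, 14.6))  # 0.0  (pure noise, first step)
print(estimate_percent_through(sigmas, 0.05))  # 0.9  (almost fully denoised)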
@@ -80,19 +80,17 @@ class CrossAttentionControl:
         TOKENS = 2
 
     @classmethod
-    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index:int=None)\
+    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
             -> list['CrossAttentionControl.CrossAttentionType']:
         """
         Should cross-attention control be applied on the given step?
-        :param step_index: The step index (counts upwards from 0), or None if unknown.
+        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
         :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
         """
-        if step_index is None:
+        if percent_through is None:
             return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
 
         opts = context.arguments.edit_options
-        # percent_through will never reach 1.0 (but this is intended)
-        percent_through = float(step_index)/float(context.step_count)
         to_control = []
         if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
             to_control.append(cls.CrossAttentionType.SELF)
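To make the windowing above concrete, here is a hypothetical standalone mirror of the gating. Only the SELF branch is visible in the hunk, so the TOKENS branch with t_start/t_end keys is assumed by analogy and may not match the real file:

# Hypothetical mirror of the gating logic; t_start/t_end for TOKENS control
# is an assumption (only the SELF branch appears in the hunk above).
def active_types(opts: dict, percent_through: float) -> list[str]:
    to_control = []
    if opts['s_start'] <= percent_through < opts['s_end']:
        to_control.append('SELF')
    if opts['t_start'] <= percent_through < opts['t_end']:
        to_control.append('TOKENS')
    return to_control

opts = {'s_start': 0.0, 's_end': 0.3, 't_start': 0.1, 't_end': 1.0}
print(active_types(opts, 0.05))  # ['SELF']           : early steps only
print(active_types(opts, 0.2))   # ['SELF', 'TOKENS'] : windows overlap
print(active_types(opts, 0.5))   # ['TOKENS']         : late steps only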
@@ -42,9 +42,9 @@ class CFGDenoiser(nn.Module):
         self.invokeai_diffuser.remove_cross_attention_control()
 
 
-    def forward(self, x, sigma, uncond, cond, cond_scale, step_index):
+    def forward(self, x, sigma, uncond, cond, cond_scale):
 
-        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index)
+        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
 
         # apply threshold
         if self.warmup < self.warmup_max:
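Why forward can drop step_index: k-diffusion's sampling loops call the wrapped denoiser with only the latents and the current sigma, never the loop counter. A toy stand-in for that call shape (a sketch, not k-diffusion's actual source):

import torch

# Toy stand-in for a k-diffusion sampling loop: the model call receives only
# (x, sigma, ...) -- the loop counter i is never passed to the denoiser,
# which is why do_diffusion_step must infer progress from sigma itself.
def toy_sampler(model, x, sigmas, extra_args):
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        x = denoised  # real samplers do a proper ODE/SDE update here
    return x

model = lambda x, sigma, **kw: 0.9 * x  # placeholder denoiser
out = toy_sampler(model, torch.randn(1, 4, 8, 8), torch.linspace(14.6, 0.03, 10), {})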
@@ -60,9 +60,8 @@ class PLMSSampler(Sampler):
             # damian0815 would like to know when/if this code path is used
             e_t = self.model.apply_model(x, t, c)
         else:
-            # step_index is expected to count up while index counts down
+            # step_index counts in the opposite direction to index
             step_index = step_count-(index+1)
-            # note that step_index == 0 is evaluated twice with different x
             e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
                                                            unconditional_conditioning, c,
                                                            unconditional_guidance_scale,
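The flipped counter in this hunk, worked through for a 50-step run (illustrative only, not repository code):

# PLMS iterates `index` downwards over the timesteps while step_index is
# expected to count upwards from 0, hence the flip:
step_count = 50
for index in (49, 48, 1, 0):
    step_index = step_count - (index + 1)
    print(index, '->', step_index)  # 49 -> 0, 48 -> 1, 1 -> 48, 0 -> 49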
@@ -57,7 +57,8 @@ class InvokeAIDiffuserComponent:
     def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                           unconditioning: torch.Tensor, conditioning: torch.Tensor,
                           unconditional_guidance_scale: float,
-                          step_index: int=None):
+                          step_index: int=None
+                          ):
         """
         :param x: Current latents
         :param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -72,7 +73,17 @@ class InvokeAIDiffuserComponent:
         cross_attention_control_types_to_do = []
 
         if self.cross_attention_control_context is not None:
-            cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index)
+            if step_index is not None:
+                # percent_through will never reach 1.0 (but this is intended)
+                percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
+            else:
+                # find the current sigma in the sigma sequence
+                sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
+                # flip because sigmas[0] is for the fully denoised image
+                # percent_through must be <1
+                percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
+                print('estimated percent_through', percent_through, 'from sigma', sigma)
+            cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
 
         if len(cross_attention_control_types_to_do)==0:
             print('step', step_index, ': not doing cross attention control')
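A quick sanity check (hypothetical, not part of the commit): when the sigma table and the step schedule line up, the step-count branch and the sigma branch above land on the same percent_through:

import torch

# With a 10-entry ascending sigma table and a sampler that walks it from the
# top down, the step-count estimate and the sigma-based estimate coincide.
sigmas = torch.linspace(0.03, 14.6, steps=10)  # stand-in for self.model.sigmas
step_count = 10
for step_index in range(step_count):
    sigma = sigmas[step_count - 1 - step_index]  # high sigma first
    from_steps = float(step_index) / float(step_count)
    sigma_index = torch.nonzero(sigmas <= sigma)[-1]
    from_sigma = 1.0 - float(sigma_index.item() + 1) / float(sigmas.shape[0])
    assert abs(from_steps - from_sigma) < 1e-6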