with diffusers cac, always run the original prompt on the first step

This commit is contained in:
Damian Stewart 2023-01-30 14:50:57 +01:00
parent 5e7ed964d2
commit 27ee939e4b
2 changed files with 10 additions and 2 deletions

View File

@ -108,7 +108,7 @@ class Context:
return self.tokens_cross_attention_action == Context.Action.APPLY
return False
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
def get_active_cross_attention_control_types_for_step(self, percent_through:Optional[float]=None, step_size:Optional[float]=None)\
-> list[CrossAttentionType]:
"""
Should cross-attention control be applied on the given step?
@ -117,6 +117,11 @@ class Context:
"""
if percent_through is None:
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
if step_size is not None:
# adjust percent_through to ignore the first step
percent_through = (percent_through - step_size) / (1.0 - step_size)
if percent_through < 0:
return []
opts = self.arguments.edit_options
to_control = []

View File

@ -141,13 +141,16 @@ class InvokeAIDiffuserComponent:
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
step_size_percent = 1 / total_step_count
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
# legacy code path supports s_* so we don't need step_size_percent
step_size_percent = None
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through, step_size=step_size_percent)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_hybrid_conditioning = isinstance(conditioning, dict)