diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 99ef1d49bc..cecbc869e9 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -108,7 +108,7 @@ class Context: return self.tokens_cross_attention_action == Context.Action.APPLY return False - def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ + def get_active_cross_attention_control_types_for_step(self, percent_through:Optional[float]=None, step_size:Optional[float]=None)\ -> list[CrossAttentionType]: """ Should cross-attention control be applied on the given step? @@ -117,6 +117,11 @@ class Context: """ if percent_through is None: return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] + if step_size is not None: + # adjust percent_through to ignore the first step + percent_through = (percent_through - step_size) / (1.0 - step_size) + if percent_through < 0: + return [] opts = self.arguments.edit_options to_control = [] diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0c91df9528..c4a571a21d 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -141,13 +141,16 @@ class InvokeAIDiffuserComponent: if step_index is not None and total_step_count is not None: # 🧨diffusers codepath percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate + step_size_percent = 1 / total_step_count else: # legacy compvis codepath # TODO remove when compvis codepath support is dropped if step_index is None and sigma is None: raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.") percent_through = self.estimate_percent_through(step_index, sigma) - cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) + # legacy code path supports s_* so we don't need step_size_percent + step_size_percent = None + cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through, step_size=step_size_percent) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_hybrid_conditioning = isinstance(conditioning, dict)