diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 770c71f110..37f0ebfa1d 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -306,7 +306,7 @@ class InvokeAICrossAttentionMixin:
 
 
 
-def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
+def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
     if is_running_diffusers:
         unet = model
         unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
@@ -314,7 +314,7 @@ def remove_cross_attention_control(model, is_running_diffusers: bool, restore_at
         remove_attention_function(model)
 
 
-def setup_cross_attention_control(model, context: Context, is_running_diffusers = False):
+def override_cross_attention(model, context: Context, is_running_diffusers = False):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
 
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index e5a502f977..304009c1d3 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -19,9 +19,9 @@ class DDIMSampler(Sampler):
         all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
 
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()
 
 
     # This is the central routine
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 0038c481e8..f98ca8de21 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -43,9 +43,9 @@ class CFGDenoiser(nn.Module):
         extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
 
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()
 
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 5124badcd1..9edd333780 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -21,9 +21,9 @@ class PLMSSampler(Sampler):
         all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
 
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()
 
 
     # this is the essential routine
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index 1ecbd1c488..f37bec789e 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -9,7 +9,7 @@ import torch
 
 from diffusers.models.cross_attention import AttnProcessor
 from ldm.models.diffusion.cross_attention_control import Arguments, \
-    remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
+    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
@@ -64,17 +64,17 @@ class InvokeAIDiffuserComponent:
         do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
         old_attn_processor = None
         if do_swap:
-            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
-                                                                    step_count=step_count)
+            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
+                                                               step_count=step_count)
         try:
             yield None
         finally:
             if old_attn_processor is not None:
-                self.remove_cross_attention_control(old_attn_processor)
+                self.restore_default_cross_attention(old_attn_processor)
             # TODO resuscitate attention map saving
             #self.remove_attention_map_saving()
 
-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
@@ -84,16 +84,16 @@ class InvokeAIDiffuserComponent:
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        return setup_cross_attention_control(self.model,
-                                             self.cross_attention_control_context,
-                                             is_running_diffusers=self.is_running_diffusers)
+        return override_cross_attention(self.model,
+                                        self.cross_attention_control_context,
+                                        is_running_diffusers=self.is_running_diffusers)
 
-    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
+    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model,
-                                       is_running_diffusers=self.is_running_diffusers,
-                                       restore_attention_processor=restore_attention_processor)
+        restore_default_cross_attention(self.model,
+                                        is_running_diffusers=self.is_running_diffusers,
+                                        restore_attention_processor=restore_attention_processor)
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
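
For reviewers, here is a minimal sketch of how the renamed pair is intended to be driven, mirroring the sampler call sites and the try/finally restore pattern in this diff. The names `diffuser`, `conditioning_info`, `denoise`, and `run_with_cross_attention_control` are hypothetical stand-ins for illustration only, not identifiers from the codebase.

```python
# Sketch only: `diffuser` stands in for an InvokeAIDiffuserComponent-like object,
# `conditioning_info` for its extra conditioning info, and `denoise` for the
# sampling loop that runs while the override is active.

def run_with_cross_attention_control(diffuser, conditioning_info, denoise, step_count):
    old_attn_processor = None
    if conditioning_info is not None and conditioning_info.wants_cross_attention_control:
        # override_cross_attention() swaps in the .swap attention machinery; on the
        # diffusers path it returns the attention processor(s) it displaced.
        old_attn_processor = diffuser.override_cross_attention(conditioning_info,
                                                               step_count=step_count)
    try:
        return denoise()
    finally:
        if old_attn_processor is not None:
            # Hand the saved processor back so the unet ends up exactly as it started.
            diffuser.restore_default_cross_attention(old_attn_processor)
```

The design choice the rename reflects: `override_cross_attention` describes what the call does to the model's attention (it overrides it, returning what it replaced), and `restore_default_cross_attention` describes the inverse, rather than the older setup/remove wording.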