rename override/restore methods to better reflect what they actually do

This commit is contained in:
Damian Stewart 2023-01-30 16:23:44 +01:00
parent 17d73d09c0
commit d044d4c577
5 changed files with 20 additions and 20 deletions

View File

@ -306,7 +306,7 @@ class InvokeAICrossAttentionMixin:
def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None): def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
if is_running_diffusers: if is_running_diffusers:
unet = model unet = model
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor()) unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
@ -314,7 +314,7 @@ def remove_cross_attention_control(model, is_running_diffusers: bool, restore_at
remove_attention_function(model) remove_attention_function(model)
def setup_cross_attention_control(model, context: Context, is_running_diffusers = False): def override_cross_attention(model, context: Context, is_running_diffusers = False):
""" """
Inject attention parameters and functions into the passed in model to enable cross attention editing. Inject attention parameters and functions into the passed in model to enable cross attention editing.

View File

@ -19,9 +19,9 @@ class DDIMSampler(Sampler):
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
else: else:
self.invokeai_diffuser.remove_cross_attention_control() self.invokeai_diffuser.restore_default_cross_attention()
# This is the central routine # This is the central routine

View File

@ -43,9 +43,9 @@ class CFGDenoiser(nn.Module):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None) extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
else: else:
self.invokeai_diffuser.remove_cross_attention_control() self.invokeai_diffuser.restore_default_cross_attention()
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):

View File

@ -21,9 +21,9 @@ class PLMSSampler(Sampler):
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
else: else:
self.invokeai_diffuser.remove_cross_attention_control() self.invokeai_diffuser.restore_default_cross_attention()
# this is the essential routine # this is the essential routine

View File

@ -9,7 +9,7 @@ import torch
from diffusers.models.cross_attention import AttnProcessor from diffusers.models.cross_attention import AttnProcessor
from ldm.models.diffusion.cross_attention_control import Arguments, \ from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \ restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@ -64,17 +64,17 @@ class InvokeAIDiffuserComponent:
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
old_attn_processor = None old_attn_processor = None
if do_swap: if do_swap:
old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info, old_attn_processor = self.override_cross_attention(extra_conditioning_info,
step_count=step_count) step_count=step_count)
try: try:
yield None yield None
finally: finally:
if old_attn_processor is not None: if old_attn_processor is not None:
self.remove_cross_attention_control(old_attn_processor) self.restore_default_cross_attention(old_attn_processor)
# TODO resuscitate attention map saving # TODO resuscitate attention map saving
#self.remove_attention_map_saving() #self.remove_attention_map_saving()
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]: def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
""" """
setup cross attention .swap control. for diffusers this replaces the attention processor, so setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later. the previous attention processor is returned so that the caller can restore it later.
@ -84,16 +84,16 @@ class InvokeAIDiffuserComponent:
arguments=self.conditioning.cross_attention_control_args, arguments=self.conditioning.cross_attention_control_args,
step_count=step_count step_count=step_count
) )
return setup_cross_attention_control(self.model, return override_cross_attention(self.model,
self.cross_attention_control_context, self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers) is_running_diffusers=self.is_running_diffusers)
def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None): def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
self.conditioning = None self.conditioning = None
self.cross_attention_control_context = None self.cross_attention_control_context = None
remove_cross_attention_control(self.model, restore_default_cross_attention(self.model,
is_running_diffusers=self.is_running_diffusers, is_running_diffusers=self.is_running_diffusers,
restore_attention_processor=restore_attention_processor) restore_attention_processor=restore_attention_processor)
def setup_attention_map_saving(self, saver: AttentionMapSaver): def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key): def callback(slice, dim, offset, slice_size, key):