Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
rename override/restore methods to better reflect what they actually do

commit d044d4c577, parent 17d73d09c0
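
In short, the commit renames two pairs of names: the module-level helpers in cross_attention_control and the matching InvokeAIDiffuserComponent methods. A condensed before/after view of a sampler call site, taken from the hunks below (extra_conditioning_info and all_timesteps_count come from the surrounding sampler code):

    # before this commit:
    #   self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count=all_timesteps_count)
    #   ...
    #   self.invokeai_diffuser.remove_cross_attention_control()
    #
    # after this commit:
    #   self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count=all_timesteps_count)
    #   ...
    #   self.invokeai_diffuser.restore_default_cross_attention()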
@@ -306,7 +306,7 @@ class InvokeAICrossAttentionMixin:



-def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
+def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
     if is_running_diffusers:
         unet = model
         unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
@@ -314,7 +314,7 @@ def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
         remove_attention_function(model)


-def setup_cross_attention_control(model, context: Context, is_running_diffusers = False):
+def override_cross_attention(model, context: Context, is_running_diffusers = False):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.

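
The docstring above describes the injection side; its counterpart, restore_default_cross_attention, undoes it. A minimal sketch of the intended round trip with the renamed module-level pair, assuming a diffusers unet and an already-built Context, and assuming (as the later InvokeAIDiffuserComponent hunks suggest) that override_cross_attention returns the processors it displaced. The helper name and the denoise callable are hypothetical, not part of this commit:

    # Hypothetical sketch only: apply a cross-attention edit around one denoising run.
    def run_with_cross_attention_edit(unet, ctx: Context, denoise):
        # inject the .swap attention machinery described in the docstring above
        old_processors = override_cross_attention(unet, ctx, is_running_diffusers=True)
        try:
            return denoise()
        finally:
            # reinstall the displaced processors; with None, a stock CrossAttnProcessor is used
            restore_default_cross_attention(unet, is_running_diffusers=True,
                                            restore_attention_processor=old_processors)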
@@ -19,9 +19,9 @@ class DDIMSampler(Sampler):
         all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()


         # This is the central routine
@@ -43,9 +43,9 @@ class CFGDenoiser(nn.Module):
         extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()


     def forward(self, x, sigma, uncond, cond, cond_scale):
@@ -21,9 +21,9 @@ class PLMSSampler(Sampler):
         all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()


         # this is the essential routine
@@ -9,7 +9,7 @@ import torch

 from diffusers.models.cross_attention import AttnProcessor
 from ldm.models.diffusion.cross_attention_control import Arguments, \
-    remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
+    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

@@ -64,17 +64,17 @@ class InvokeAIDiffuserComponent:
         do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
         old_attn_processor = None
         if do_swap:
-            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
                                                                 step_count=step_count)
         try:
             yield None
         finally:
             if old_attn_processor is not None:
-                self.remove_cross_attention_control(old_attn_processor)
+                self.restore_default_cross_attention(old_attn_processor)
             # TODO resuscitate attention map saving
             #self.remove_attention_map_saving()

-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
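
As the docstring notes, on the diffusers path the override swaps the attention processor and hands back the one it displaced, and restore_default_cross_attention reinstalls it (or falls back to a stock CrossAttnProcessor when given nothing, per the first hunk). A hedged caller-side sketch of that contract; component, info and n_steps are placeholders rather than code from this commit:

    # Hypothetical caller of the renamed InvokeAIDiffuserComponent methods.
    def edit_then_restore(component, info, n_steps):
        # keep whatever attention processors were active before the override
        saved = component.override_cross_attention(info, step_count=n_steps)
        try:
            pass  # ... run the cross-attention-edited denoising steps here ...
        finally:
            # hand the saved processors back so the default behaviour returns
            component.restore_default_cross_attention(saved)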
@@ -84,16 +84,16 @@ class InvokeAIDiffuserComponent:
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        return setup_cross_attention_control(self.model,
+        return override_cross_attention(self.model,
                                              self.cross_attention_control_context,
                                              is_running_diffusers=self.is_running_diffusers)

-    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
+    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model,
+        restore_default_cross_attention(self.model,
                                        is_running_diffusers=self.is_running_diffusers,
                                        restore_attention_processor=restore_attention_processor)

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):