Merge branch 'release-candidate-2-1-3' of github.com:/invoke-ai/InvokeAI into release-candidate-2-1-3

This commit is contained in:
Lincoln Stein 2022-11-09 17:26:24 +00:00
commit 2dd6fc2b93
5 changed files with 249 additions and 145 deletions

View File

@ -802,6 +802,10 @@ class Generate:
# the model cache does the loading and offloading # the model cache does the loading and offloading
cache = self.model_cache cache = self.model_cache
if not cache.valid_model(model_name):
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return self.model
cache.print_vram_usage() cache.print_vram_usage()
# have to get rid of all references to model in order # have to get rid of all references to model in order

View File

@ -41,15 +41,22 @@ class ModelCache(object):
self.stack = [] # this is an LRU FIFO self.stack = [] # this is an LRU FIFO
self.current_model = None self.current_model = None
def valid_model(self, model_name:str)->bool:
'''
Given a model name, returns True if it is a valid
identifier.
'''
return model_name in self.config
def get_model(self, model_name:str): def get_model(self, model_name:str):
''' '''
Given a model named identified in models.yaml, return Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM. the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there. If on disk, will load from there.
''' '''
if model_name not in self.config: if not self.valid_model(model_name):
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file') print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return None return self.current_model
if self.current_model != model_name: if self.current_model != model_name:
if model_name not in self.models: # make room for a new one if model_name not in self.models: # make room for a new one

View File

@ -1,10 +1,13 @@
from enum import Enum import enum
from typing import Optional
import torch import torch
# adapted from bloc97's CrossAttentionControl colab # adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl # https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl: class CrossAttentionControl:
class Arguments: class Arguments:
@ -27,7 +30,14 @@ class CrossAttentionControl:
print('warning: cross-attention control options are not working properly for >1 edit') print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0] self.edit_options = non_none_edit_options[0]
class Context: class Context:
class Action(enum.Enum):
NONE = 0
SAVE = 1,
APPLY = 2
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
""" """
:param arguments: Arguments for the cross-attention control process :param arguments: Arguments for the cross-attention control process
@ -36,14 +46,124 @@ class CrossAttentionControl:
self.arguments = arguments self.arguments = arguments
self.step_count = step_count self.step_count = step_count
self.self_cross_attention_module_identifiers = []
self.tokens_cross_attention_module_identifiers = []
self.saved_cross_attention_maps = {}
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name,module in CrossAttentionControl.get_attention_modules(model,
CrossAttentionControl.CrossAttentionType.SELF):
self.self_cross_attention_module_identifiers.append(name)
for name,module in CrossAttentionControl.get_attention_modules(model,
CrossAttentionControl.CrossAttentionType.TOKENS):
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
else:
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
else:
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
def is_tokens_cross_attention(self, module_identifier) -> bool:
return module_identifier in self.tokens_cross_attention_module_identifiers
def get_should_save_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
return False
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
return False
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
opts = self.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
return to_control
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
slice_size: Optional[int]):
if identifier not in self.saved_cross_attention_maps:
self.saved_cross_attention_maps[identifier] = {
'dim': dim,
'slice_size': slice_size,
'slices': {offset or 0: slice}
}
else:
self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict['dim'] is not None:
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict['slices'][0]
if saved_attention_dict['dim'] == requested_dim:
if slice_size != saved_attention_dict['slice_size']:
raise RuntimeError(
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
return saved_attention_dict['slices'][requested_offset]
if saved_attention_dict['dim'] == None:
whole_saved_attention = saved_attention_dict['slices'][0]
if requested_dim == 0:
return whole_saved_attention[requested_offset:requested_offset + slice_size]
elif requested_dim == 1:
return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
return saved_attention['dim'], saved_attention['slice_size']
def clear_requests(self, cleanup=True):
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE
self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE
if cleanup:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict['slices'].items():
map_dict[offset] = slice.to('cpu')
@classmethod @classmethod
def remove_cross_attention_control(cls, model): def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model) cls.remove_attention_function(model)
@classmethod @classmethod
def setup_cross_attention_control(cls, model, def setup_cross_attention_control(cls, model, context: Context):
cross_attention_control_args: Arguments
):
""" """
Inject attention parameters and functions into the passed in model to enable cross attention editing. Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -53,7 +173,7 @@ class CrossAttentionControl:
""" """
# adapted from init_attention_edit # adapted from init_attention_edit
device = cross_attention_control_args.edited_conditioning.device device = context.arguments.edited_conditioning.device
# urgh. should this be hardcoded? # urgh. should this be hardcoded?
max_length = 77 max_length = 77
@ -61,141 +181,82 @@ class CrossAttentionControl:
mask = torch.zeros(max_length) mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long) indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long) indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes: for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length: if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited # these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1] indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1 mask[b0:b1] = 1
cls.inject_attention_function(model) context.register_cross_attention_modules(model)
context.cross_attention_mask = mask.to(device)
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): context.cross_attention_index_map = indices.to(device)
m.last_attn_slice_mask = None cls.inject_attention_function(model, context)
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
class CrossAttentionType(Enum): class CrossAttentionType(enum.Enum):
SELF = 1 SELF = 1
TOKENS = 2 TOKENS = 2
@classmethod
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
opts = context.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(cls.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(cls.CrossAttentionType.TOKENS)
return to_control
@classmethod @classmethod
def get_attention_modules(cls, model, which: CrossAttentionType): def get_attention_modules(cls, model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
return [module for name, module in model.named_modules() if return [(name,module) for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name] type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def clear_requests(cls, model, clear_attn_slice=True):
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False
m.use_last_attn_slice = False
if clear_attn_slice:
m.last_attn_slice = None
@classmethod @classmethod
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType): def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
# clear out the saved slice in case the outermost dim changes
m.last_attn_slice = None
m.save_last_attn_slice = True
@classmethod
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
m.use_last_attn_slice = True
@classmethod
def inject_attention_function(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
attn_slice = suggested_attention_slice attention_slice = suggested_attention_slice
if dim is not None:
start = offset
end = start+slice_size
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
#else:
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
if self.use_last_attn_slice: if context.get_should_save_maps(module.identifier):
if dim is None: #print(module.identifier, "saving suggested_attention_slice of shape",
last_attn_slice = self.last_attn_slice # suggested_attention_slice.shape, "dim", dim, "offset", offset)
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
elif context.get_should_apply_saved_maps(module.identifier):
#print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
# slice may have been offloaded to CPU
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice
mask = context.cross_attention_mask
saved_mask = mask
this_mask = 1 - mask
attention_slice = remapped_saved_attention_slice * saved_mask + \
this_attention_slice * this_mask
else: else:
last_attn_slice = self.last_attn_slice[offset]
if self.last_attn_slice_mask is None:
# just use everything # just use everything
attn_slice = last_attn_slice attention_slice = saved_attention_slice
else:
last_attn_slice_mask = self.last_attn_slice_mask
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
this_attn_slice = attn_slice return attention_slice
this_attn_slice_mask = 1 - last_attn_slice_mask
attn_slice = this_attn_slice * this_attn_slice_mask + \
remapped_last_attn_slice * last_attn_slice_mask
if self.save_last_attn_slice:
if dim is None:
self.last_attn_slice = attn_slice
else:
if self.last_attn_slice is None:
self.last_attn_slice = { offset: attn_slice }
else:
self.last_attn_slice[offset] = attn_slice
return attn_slice
for name, module in unet.named_modules(): for name, module in unet.named_modules():
module_name = type(module).__name__ module_name = type(module).__name__
if module_name == "CrossAttention": if module_name == "CrossAttention":
module.last_attn_slice = None module.identifier = name
module.last_attn_slice_indices = None
module.last_attn_slice_mask = None
module.use_last_attn_weights = False
module.use_last_attn_slice = False
module.save_last_attn_slice = False
module.set_attention_slice_wrangler(attention_slice_wrangler) module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
context.get_slicing_strategy(module_identifier))
@classmethod @classmethod
def remove_attention_function(cls, unet): def remove_attention_function(cls, unet):
# clear wrangler callback
for name, module in unet.named_modules(): for name, module in unet.named_modules():
module_name = type(module).__name__ module_name = type(module).__name__
if module_name == "CrossAttention": if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None) module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)

View File

@ -1,9 +1,11 @@
import traceback
from math import ceil from math import ceil
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import torch import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
from ldm.modules.attention import get_mem_free_total
class InvokeAIDiffuserComponent: class InvokeAIDiffuserComponent:
@ -34,7 +36,7 @@ class InvokeAIDiffuserComponent:
""" """
self.model = model self.model = model
self.model_forward_callback = model_forward_callback self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning self.conditioning = conditioning
@ -42,11 +44,7 @@ class InvokeAIDiffuserComponent:
arguments=self.conditioning.cross_attention_control_args, arguments=self.conditioning.cross_attention_control_args,
step_count=step_count step_count=step_count
) )
CrossAttentionControl.setup_cross_attention_control(self.model, CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count
def remove_cross_attention_control(self): def remove_cross_attention_control(self):
self.conditioning = None self.conditioning = None
@ -54,6 +52,7 @@ class InvokeAIDiffuserComponent:
CrossAttentionControl.remove_cross_attention_control(self.model) CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict], unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict],
@ -70,12 +69,12 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
""" """
CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: CrossAttentionControl.Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma) percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_hybrid_conditioning = isinstance(conditioning, dict) wants_hybrid_conditioning = isinstance(conditioning, dict)
@ -124,7 +123,7 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS) # slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@ -134,32 +133,32 @@ class InvokeAIDiffuserComponent:
# representing batched uncond + cond, but then when it comes to applying the saved attention, the # representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
context:CrossAttentionControl.Context = self.cross_attention_control_context
try: try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps # process x using the original prompt, saving the attention maps
for type in cross_attention_control_types_to_do: #print("saving attention maps for", cross_attention_control_types_to_do)
CrossAttentionControl.request_save_attention_maps(self.model, type) for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning) _ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False) context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do: #print("applying saved attention maps for", cross_attention_control_types_to_do)
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) for ca_type in cross_attention_control_types_to_do:
context.request_apply_saved_attention_maps(ca_type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
context.clear_requests(cleanup=True)
CrossAttentionControl.clear_requests(self.model) except:
context.clear_requests(cleanup=True)
raise
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
except RuntimeError:
# make sure we clean out the attention slices we're storing on the model
# TODO don't store things on the model
CrossAttentionControl.clear_requests(self.model)
raise
def estimate_percent_through(self, step_index, sigma): def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None: if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended) # percent_through will never reach 1.0 (but this is intended)

View File

@ -1,6 +1,6 @@
from inspect import isfunction from inspect import isfunction
import math import math
from typing import Callable from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -151,6 +151,17 @@ class SpatialSelfAttention(nn.Module):
return x+h_ return x+h_
def get_mem_free_total(device):
#only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
@ -173,31 +184,43 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.cached_mem_free_total = None
self.attention_slice_wrangler = None self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
''' '''
Set custom attention calculator to be called when attention is calculated Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size), :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent. which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked. `module` is the current CrossAttention module for which the callback is being invoked.
attention_scores are the scores for attention `suggested_attention_slice` is the default-calculated attention slice
suggested_attention_slice is a softmax(dim=-1) over attention_scores `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
If dim is >= 0, offset and slice_size specify the slice start and length.
Pass None to use the default attention calculation. Pass None to use the default attention calculation.
:return: :return:
''' '''
self.attention_slice_wrangler = wrangler self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def cache_free_memory_count(self, device):
self.cached_mem_free_total = get_mem_free_total(device)
print("free cuda memory: ", self.cached_mem_free_total)
def clear_cached_free_memory_count(self):
self.cached_mem_free_total = None
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size): def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores # calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k) attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attenion slice by taking the best scores for each latent pixel # calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None: attention_slice_wrangler = self.attention_slice_wrangler
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size) if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else: else:
attention_slice = default_attention_slice attention_slice = default_attention_slice
@ -240,17 +263,27 @@ class CrossAttention(nn.Module):
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v): def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device) # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
mem_active = stats['active_bytes.all.current'] slicing_strategy_getter = self.slicing_strategy_getter
mem_reserved = stats['reserved_bytes.all.current'] if slicing_strategy_getter is not None:
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) (dim, slice_size) = slicing_strategy_getter(self)
mem_free_torch = mem_reserved - mem_active if dim is not None:
mem_free_total = mem_free_cuda + mem_free_torch # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation # Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_attention_mem_efficient(self, q, k, v): def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda': if q.device.type == 'cuda':
torch.cuda.empty_cache()
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v) return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps': if q.device.type == 'mps':