mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix #1362 by improving VRAM usage patterns when doing .swap()
commit ef3f7a26e242b73c2beb0195c7fd8f654ef47f55 Author: damian0815 <null@damianstewart.com> Date: Tue Nov 8 12:18:37 2022 +0100 remove log spam commit 7189d649622d4668b120b0dd278388ad672142c4 Author: damian0815 <null@damianstewart.com> Date: Tue Nov 8 12:10:28 2022 +0100 change the way saved slicing strategy is applied commit 01c40f751ab72955140165c16f95ae411732265b Author: damian0815 <null@damianstewart.com> Date: Tue Nov 8 12:04:43 2022 +0100 fix slicing_strategy_getter callsite commit f8cfe25150a346958903316bc710737d99839923 Author: damian0815 <null@damianstewart.com> Date: Tue Nov 8 11:56:22 2022 +0100 cleanup, consistent dim=0 also tested commit 5bf9b1e890d48e962afd4a668a219b68271e5dc1 Author: damian0815 <null@damianstewart.com> Date: Tue Nov 8 11:34:09 2022 +0100 refactored context, tested with non-sliced cross attention control commit d58a46e39bf562e7459290d2444256e8c08ad0b6 Author: damian0815 <null@damianstewart.com> Date: Sun Nov 6 00:41:52 2022 +0100 cleanup commit 7e2c658b4c06fe239311b65b9bb16fa3adec7fd7 Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:57:31 2022 +0100 disable logs commit 20ee89d93841b070738b3d8a4385c93b097d92eb Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:36:58 2022 +0100 slice saved attention if necessary commit 0a7684a22c880ec0f48cc22bfed4526358f71546 Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:32:38 2022 +0100 raise instead of asserting commit 7083104c7f3a0d8fd96e94a2f391de50a3c942e4 Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:31:00 2022 +0100 store dim when saving slices commit f7c0808ed383ec1dc70645288a798ed2aa4fa85c Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:27:16 2022 +0100 don't retry on exception commit 749a721e939b3fe7c1741e7998dab6bd2c85a0cb Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:24:50 2022 +0100 stuff commit 032ab90e9533be8726301ec91b97137e2aadef9a Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:20:17 2022 +0100 more logging commit 3dc34b387f033482305360e605809d95a40bf6f8 Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:16:47 2022 +0100 logs commit 901c4c1aa4b9bcef695a6551867ec8149e6e6a93 Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:12:39 2022 +0100 actually set save_slicing_strategy to True commit f780e0a0a7c6b6a3db320891064da82589358c8a Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 22:10:35 2022 +0100 store slicing strategy commit 93bb6d566fd18c5c69ef7dacc8f74ba2cf671cb7 Author: damian <git@damianstewart.com> Date: Sat Nov 5 20:43:48 2022 +0100 still not it commit 5e3a9541f8ae00bde524046963910323e20c40b7 Author: damian <git@damianstewart.com> Date: Sat Nov 5 17:20:02 2022 +0100 wip offloading attention slices on-demand commit 4c2966aa856b6f3b446216da3619ae931552ef08 Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 15:47:40 2022 +0100 pre-emptive offloading, idk if it works commit 572576755e9f0a878d38e8173e485126c0efbefb Author: root <you@example.com> Date: Sat Nov 5 11:25:32 2022 +0000 push attention slices to cpu. slow but saves memory. commit b57c83a68f2ac03976ebc89ce2ff03812d6d185f Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 12:04:22 2022 +0100 verbose logging commit 3a5dae116f110a96585d9eb71d713b5ed2bc3d2b Author: damian0815 <null@damianstewart.com> Date: Sat Nov 5 11:50:48 2022 +0100 wip fixing mem strategy crash (4 test on runpod) commit 3cf237db5fae0c7b0b4cc3c47c81830bdb2ae7de Author: damian0815 <null@damianstewart.com> Date: Fri Nov 4 09:02:40 2022 +0100 wip, only works on cuda
This commit is contained in:
parent
5702271991
commit
71bbfe4a1a
@ -1,10 +1,13 @@
|
|||||||
from enum import Enum
|
import enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# adapted from bloc97's CrossAttentionControl colab
|
# adapted from bloc97's CrossAttentionControl colab
|
||||||
# https://github.com/bloc97/CrossAttentionControl
|
# https://github.com/bloc97/CrossAttentionControl
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionControl:
|
class CrossAttentionControl:
|
||||||
|
|
||||||
class Arguments:
|
class Arguments:
|
||||||
@ -27,7 +30,14 @@ class CrossAttentionControl:
|
|||||||
print('warning: cross-attention control options are not working properly for >1 edit')
|
print('warning: cross-attention control options are not working properly for >1 edit')
|
||||||
self.edit_options = non_none_edit_options[0]
|
self.edit_options = non_none_edit_options[0]
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
|
|
||||||
|
class Action(enum.Enum):
|
||||||
|
NONE = 0
|
||||||
|
SAVE = 1,
|
||||||
|
APPLY = 2
|
||||||
|
|
||||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
||||||
"""
|
"""
|
||||||
:param arguments: Arguments for the cross-attention control process
|
:param arguments: Arguments for the cross-attention control process
|
||||||
@ -36,14 +46,124 @@ class CrossAttentionControl:
|
|||||||
self.arguments = arguments
|
self.arguments = arguments
|
||||||
self.step_count = step_count
|
self.step_count = step_count
|
||||||
|
|
||||||
|
self.self_cross_attention_module_identifiers = []
|
||||||
|
self.tokens_cross_attention_module_identifiers = []
|
||||||
|
|
||||||
|
self.saved_cross_attention_maps = {}
|
||||||
|
|
||||||
|
self.clear_requests(cleanup=True)
|
||||||
|
|
||||||
|
def register_cross_attention_modules(self, model):
|
||||||
|
for name,module in CrossAttentionControl.get_attention_modules(model,
|
||||||
|
CrossAttentionControl.CrossAttentionType.SELF):
|
||||||
|
self.self_cross_attention_module_identifiers.append(name)
|
||||||
|
for name,module in CrossAttentionControl.get_attention_modules(model,
|
||||||
|
CrossAttentionControl.CrossAttentionType.TOKENS):
|
||||||
|
self.tokens_cross_attention_module_identifiers.append(name)
|
||||||
|
|
||||||
|
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
|
||||||
|
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
|
||||||
|
self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
|
||||||
|
else:
|
||||||
|
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
|
||||||
|
|
||||||
|
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
|
||||||
|
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
|
||||||
|
self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
|
||||||
|
else:
|
||||||
|
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
|
||||||
|
|
||||||
|
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||||
|
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||||
|
|
||||||
|
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||||
|
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||||
|
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
|
||||||
|
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||||
|
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||||
|
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||||
|
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
|
||||||
|
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||||
|
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
||||||
|
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||||
|
"""
|
||||||
|
Should cross-attention control be applied on the given step?
|
||||||
|
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||||
|
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||||
|
"""
|
||||||
|
if percent_through is None:
|
||||||
|
return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
|
||||||
|
|
||||||
|
opts = self.arguments.edit_options
|
||||||
|
to_control = []
|
||||||
|
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||||
|
to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
|
||||||
|
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||||
|
to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
|
||||||
|
return to_control
|
||||||
|
|
||||||
|
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
|
||||||
|
slice_size: Optional[int]):
|
||||||
|
if identifier not in self.saved_cross_attention_maps:
|
||||||
|
self.saved_cross_attention_maps[identifier] = {
|
||||||
|
'dim': dim,
|
||||||
|
'slice_size': slice_size,
|
||||||
|
'slices': {offset or 0: slice}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
|
||||||
|
|
||||||
|
def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
|
||||||
|
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||||
|
if requested_dim is None:
|
||||||
|
if saved_attention_dict['dim'] is not None:
|
||||||
|
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||||
|
return saved_attention_dict['slices'][0]
|
||||||
|
|
||||||
|
if saved_attention_dict['dim'] == requested_dim:
|
||||||
|
if slice_size != saved_attention_dict['slice_size']:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
||||||
|
return saved_attention_dict['slices'][requested_offset]
|
||||||
|
|
||||||
|
if saved_attention_dict['dim'] == None:
|
||||||
|
whole_saved_attention = saved_attention_dict['slices'][0]
|
||||||
|
if requested_dim == 0:
|
||||||
|
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
||||||
|
elif requested_dim == 1:
|
||||||
|
return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
|
||||||
|
|
||||||
|
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||||
|
|
||||||
|
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
|
||||||
|
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||||
|
if saved_attention is None:
|
||||||
|
return None, None
|
||||||
|
return saved_attention['dim'], saved_attention['slice_size']
|
||||||
|
|
||||||
|
def clear_requests(self, cleanup=True):
|
||||||
|
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE
|
||||||
|
self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE
|
||||||
|
if cleanup:
|
||||||
|
self.saved_cross_attention_maps = {}
|
||||||
|
|
||||||
|
def offload_saved_attention_slices_to_cpu(self):
|
||||||
|
for key, map_dict in self.saved_cross_attention_maps.items():
|
||||||
|
for offset, slice in map_dict['slices'].items():
|
||||||
|
map_dict[offset] = slice.to('cpu')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def remove_cross_attention_control(cls, model):
|
def remove_cross_attention_control(cls, model):
|
||||||
cls.remove_attention_function(model)
|
cls.remove_attention_function(model)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_cross_attention_control(cls, model,
|
def setup_cross_attention_control(cls, model, context: Context):
|
||||||
cross_attention_control_args: Arguments
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||||
|
|
||||||
@ -53,7 +173,7 @@ class CrossAttentionControl:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# adapted from init_attention_edit
|
# adapted from init_attention_edit
|
||||||
device = cross_attention_control_args.edited_conditioning.device
|
device = context.arguments.edited_conditioning.device
|
||||||
|
|
||||||
# urgh. should this be hardcoded?
|
# urgh. should this be hardcoded?
|
||||||
max_length = 77
|
max_length = 77
|
||||||
@ -61,141 +181,82 @@ class CrossAttentionControl:
|
|||||||
mask = torch.zeros(max_length)
|
mask = torch.zeros(max_length)
|
||||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||||
indices = torch.zeros(max_length, dtype=torch.long)
|
indices = torch.zeros(max_length, dtype=torch.long)
|
||||||
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
|
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||||
if b0 < max_length:
|
if b0 < max_length:
|
||||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||||
# these tokens have not been edited
|
# these tokens have not been edited
|
||||||
indices[b0:b1] = indices_target[a0:a1]
|
indices[b0:b1] = indices_target[a0:a1]
|
||||||
mask[b0:b1] = 1
|
mask[b0:b1] = 1
|
||||||
|
|
||||||
cls.inject_attention_function(model)
|
context.register_cross_attention_modules(model)
|
||||||
|
context.cross_attention_mask = mask.to(device)
|
||||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
context.cross_attention_index_map = indices.to(device)
|
||||||
m.last_attn_slice_mask = None
|
cls.inject_attention_function(model, context)
|
||||||
m.last_attn_slice_indices = None
|
|
||||||
|
|
||||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
|
|
||||||
m.last_attn_slice_mask = mask.to(device)
|
|
||||||
m.last_attn_slice_indices = indices.to(device)
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(Enum):
|
class CrossAttentionType(enum.Enum):
|
||||||
SELF = 1
|
SELF = 1
|
||||||
TOKENS = 2
|
TOKENS = 2
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
|
|
||||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
|
||||||
"""
|
|
||||||
Should cross-attention control be applied on the given step?
|
|
||||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
|
||||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
|
||||||
"""
|
|
||||||
if percent_through is None:
|
|
||||||
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
|
|
||||||
|
|
||||||
opts = context.arguments.edit_options
|
|
||||||
to_control = []
|
|
||||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
|
||||||
to_control.append(cls.CrossAttentionType.SELF)
|
|
||||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
|
||||||
to_control.append(cls.CrossAttentionType.TOKENS)
|
|
||||||
return to_control
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attention_modules(cls, model, which: CrossAttentionType):
|
def get_attention_modules(cls, model, which: CrossAttentionType):
|
||||||
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
||||||
return [module for name, module in model.named_modules() if
|
return [(name,module) for name, module in model.named_modules() if
|
||||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def clear_requests(cls, model, clear_attn_slice=True):
|
|
||||||
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
|
|
||||||
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
|
|
||||||
for m in self_attention_modules+tokens_attention_modules:
|
|
||||||
m.save_last_attn_slice = False
|
|
||||||
m.use_last_attn_slice = False
|
|
||||||
if clear_attn_slice:
|
|
||||||
m.last_attn_slice = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
|
||||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
|
||||||
for m in modules:
|
|
||||||
# clear out the saved slice in case the outermost dim changes
|
|
||||||
m.last_attn_slice = None
|
|
||||||
m.save_last_attn_slice = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
|
||||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
|
||||||
for m in modules:
|
|
||||||
m.use_last_attn_slice = True
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def inject_attention_function(cls, unet):
|
|
||||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||||
|
|
||||||
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
|
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
|
||||||
|
|
||||||
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
|
#memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||||
|
|
||||||
attn_slice = suggested_attention_slice
|
attention_slice = suggested_attention_slice
|
||||||
if dim is not None:
|
|
||||||
start = offset
|
|
||||||
end = start+slice_size
|
|
||||||
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
|
||||||
#else:
|
|
||||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
|
||||||
|
|
||||||
if self.use_last_attn_slice:
|
if context.get_should_save_maps(module.identifier):
|
||||||
if dim is None:
|
#print(module.identifier, "saving suggested_attention_slice of shape",
|
||||||
last_attn_slice = self.last_attn_slice
|
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||||
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
|
||||||
|
context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
|
||||||
|
elif context.get_should_apply_saved_maps(module.identifier):
|
||||||
|
#print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||||
|
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||||
|
|
||||||
|
# slice may have been offloaded to CPU
|
||||||
|
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||||
|
|
||||||
|
if context.is_tokens_cross_attention(module.identifier):
|
||||||
|
index_map = context.cross_attention_index_map
|
||||||
|
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||||
|
this_attention_slice = suggested_attention_slice
|
||||||
|
|
||||||
|
mask = context.cross_attention_mask
|
||||||
|
saved_mask = mask
|
||||||
|
this_mask = 1 - mask
|
||||||
|
attention_slice = remapped_saved_attention_slice * saved_mask + \
|
||||||
|
this_attention_slice * this_mask
|
||||||
else:
|
else:
|
||||||
last_attn_slice = self.last_attn_slice[offset]
|
|
||||||
|
|
||||||
if self.last_attn_slice_mask is None:
|
|
||||||
# just use everything
|
# just use everything
|
||||||
attn_slice = last_attn_slice
|
attention_slice = saved_attention_slice
|
||||||
else:
|
|
||||||
last_attn_slice_mask = self.last_attn_slice_mask
|
|
||||||
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
|
|
||||||
|
|
||||||
this_attn_slice = attn_slice
|
return attention_slice
|
||||||
this_attn_slice_mask = 1 - last_attn_slice_mask
|
|
||||||
attn_slice = this_attn_slice * this_attn_slice_mask + \
|
|
||||||
remapped_last_attn_slice * last_attn_slice_mask
|
|
||||||
|
|
||||||
if self.save_last_attn_slice:
|
|
||||||
if dim is None:
|
|
||||||
self.last_attn_slice = attn_slice
|
|
||||||
else:
|
|
||||||
if self.last_attn_slice is None:
|
|
||||||
self.last_attn_slice = { offset: attn_slice }
|
|
||||||
else:
|
|
||||||
self.last_attn_slice[offset] = attn_slice
|
|
||||||
|
|
||||||
return attn_slice
|
|
||||||
|
|
||||||
for name, module in unet.named_modules():
|
for name, module in unet.named_modules():
|
||||||
module_name = type(module).__name__
|
module_name = type(module).__name__
|
||||||
if module_name == "CrossAttention":
|
if module_name == "CrossAttention":
|
||||||
module.last_attn_slice = None
|
module.identifier = name
|
||||||
module.last_attn_slice_indices = None
|
|
||||||
module.last_attn_slice_mask = None
|
|
||||||
module.use_last_attn_weights = False
|
|
||||||
module.use_last_attn_slice = False
|
|
||||||
module.save_last_attn_slice = False
|
|
||||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||||
|
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
|
||||||
|
context.get_slicing_strategy(module_identifier))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def remove_attention_function(cls, unet):
|
def remove_attention_function(cls, unet):
|
||||||
|
# clear wrangler callback
|
||||||
for name, module in unet.named_modules():
|
for name, module in unet.named_modules():
|
||||||
module_name = type(module).__name__
|
module_name = type(module).__name__
|
||||||
if module_name == "CrossAttention":
|
if module_name == "CrossAttention":
|
||||||
module.set_attention_slice_wrangler(None)
|
module.set_attention_slice_wrangler(None)
|
||||||
|
module.set_slicing_strategy_getter(None)
|
||||||
|
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
|
import traceback
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
|
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
|
||||||
|
from ldm.modules.attention import get_mem_free_total
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIDiffuserComponent:
|
class InvokeAIDiffuserComponent:
|
||||||
@ -34,7 +36,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_forward_callback = model_forward_callback
|
self.model_forward_callback = model_forward_callback
|
||||||
|
self.cross_attention_control_context = None
|
||||||
|
|
||||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
||||||
self.conditioning = conditioning
|
self.conditioning = conditioning
|
||||||
@ -42,11 +44,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
arguments=self.conditioning.cross_attention_control_args,
|
arguments=self.conditioning.cross_attention_control_args,
|
||||||
step_count=step_count
|
step_count=step_count
|
||||||
)
|
)
|
||||||
CrossAttentionControl.setup_cross_attention_control(self.model,
|
CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
|
||||||
cross_attention_control_args=self.conditioning.cross_attention_control_args
|
|
||||||
)
|
|
||||||
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
|
|
||||||
#todo: apply edit_options using step_count
|
|
||||||
|
|
||||||
def remove_cross_attention_control(self):
|
def remove_cross_attention_control(self):
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
@ -54,6 +52,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||||
unconditioning: Union[torch.Tensor,dict],
|
unconditioning: Union[torch.Tensor,dict],
|
||||||
conditioning: Union[torch.Tensor,dict],
|
conditioning: Union[torch.Tensor,dict],
|
||||||
@ -70,12 +69,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
|
||||||
|
|
||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
|
context: CrossAttentionControl.Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||||
|
|
||||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||||
@ -124,7 +123,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
|
||||||
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
# slower non-batched path (20% slower on mac MPS)
|
||||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||||
@ -134,32 +133,30 @@ class InvokeAIDiffuserComponent:
|
|||||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||||
|
context:CrossAttentionControl.Context = self.cross_attention_control_context
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
# process x using the original prompt, saving the attention maps
|
||||||
for type in cross_attention_control_types_to_do:
|
#print("saving attention maps for", cross_attention_control_types_to_do)
|
||||||
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
for ca_type in cross_attention_control_types_to_do:
|
||||||
|
context.request_save_attention_maps(ca_type)
|
||||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||||
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
|
context.clear_requests(cleanup=False)
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||||
for type in cross_attention_control_types_to_do:
|
#print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
for ca_type in cross_attention_control_types_to_do:
|
||||||
|
context.request_apply_saved_attention_maps(ca_type)
|
||||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||||
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
finally:
|
||||||
|
context.clear_requests(cleanup=True)
|
||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
except RuntimeError:
|
|
||||||
# make sure we clean out the attention slices we're storing on the model
|
|
||||||
# TODO don't store things on the model
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def estimate_percent_through(self, step_index, sigma):
|
def estimate_percent_through(self, step_index, sigma):
|
||||||
if step_index is not None and self.cross_attention_control_context is not None:
|
if step_index is not None and self.cross_attention_control_context is not None:
|
||||||
# percent_through will never reach 1.0 (but this is intended)
|
# percent_through will never reach 1.0 (but this is intended)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
import math
|
import math
|
||||||
from typing import Callable
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -151,6 +151,17 @@ class SpatialSelfAttention(nn.Module):
|
|||||||
|
|
||||||
return x+h_
|
return x+h_
|
||||||
|
|
||||||
|
def get_mem_free_total(device):
|
||||||
|
#only on cuda
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return None
|
||||||
|
stats = torch.cuda.memory_stats(device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
return mem_free_total
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
@ -173,31 +184,43 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
|
|
||||||
|
self.cached_mem_free_total = None
|
||||||
self.attention_slice_wrangler = None
|
self.attention_slice_wrangler = None
|
||||||
|
self.slicing_strategy_getter = None
|
||||||
|
|
||||||
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
|
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||||
'''
|
'''
|
||||||
Set custom attention calculator to be called when attention is calculated
|
Set custom attention calculator to be called when attention is calculated
|
||||||
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
|
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||||
self is the current CrossAttention module for which the callback is being invoked.
|
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||||
attention_scores are the scores for attention
|
`suggested_attention_slice` is the default-calculated attention slice
|
||||||
suggested_attention_slice is a softmax(dim=-1) over attention_scores
|
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||||
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||||
If dim is >= 0, offset and slice_size specify the slice start and length.
|
|
||||||
|
|
||||||
Pass None to use the default attention calculation.
|
Pass None to use the default attention calculation.
|
||||||
:return:
|
:return:
|
||||||
'''
|
'''
|
||||||
self.attention_slice_wrangler = wrangler
|
self.attention_slice_wrangler = wrangler
|
||||||
|
|
||||||
|
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||||
|
self.slicing_strategy_getter = getter
|
||||||
|
|
||||||
|
def cache_free_memory_count(self, device):
|
||||||
|
self.cached_mem_free_total = get_mem_free_total(device)
|
||||||
|
print("free cuda memory: ", self.cached_mem_free_total)
|
||||||
|
|
||||||
|
def clear_cached_free_memory_count(self):
|
||||||
|
self.cached_mem_free_total = None
|
||||||
|
|
||||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
||||||
# calculate attention scores
|
# calculate attention scores
|
||||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
||||||
# calculate attenion slice by taking the best scores for each latent pixel
|
# calculate attention slice by taking the best scores for each latent pixel
|
||||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||||
if self.attention_slice_wrangler is not None:
|
attention_slice_wrangler = self.attention_slice_wrangler
|
||||||
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
|
if attention_slice_wrangler is not None:
|
||||||
|
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||||
else:
|
else:
|
||||||
attention_slice = default_attention_slice
|
attention_slice = default_attention_slice
|
||||||
|
|
||||||
@ -240,17 +263,27 @@ class CrossAttention(nn.Module):
|
|||||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||||
|
|
||||||
def einsum_op_cuda(self, q, k, v):
|
def einsum_op_cuda(self, q, k, v):
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||||
mem_active = stats['active_bytes.all.current']
|
slicing_strategy_getter = self.slicing_strategy_getter
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
if slicing_strategy_getter is not None:
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
(dim, slice_size) = slicing_strategy_getter(self)
|
||||||
mem_free_torch = mem_reserved - mem_active
|
if dim is not None:
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||||
|
if dim == 0:
|
||||||
|
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||||
|
elif dim == 1:
|
||||||
|
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||||
|
|
||||||
|
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||||
|
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||||
# Divide factor of safety as there's copying and fragmentation
|
# Divide factor of safety as there's copying and fragmentation
|
||||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||||
|
|
||||||
|
|
||||||
def get_attention_mem_efficient(self, q, k, v):
|
def get_attention_mem_efficient(self, q, k, v):
|
||||||
if q.device.type == 'cuda':
|
if q.device.type == 'cuda':
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||||
return self.einsum_op_cuda(q, k, v)
|
return self.einsum_op_cuda(q, k, v)
|
||||||
|
|
||||||
if q.device.type == 'mps':
|
if q.device.type == 'mps':
|
||||||
|
Loading…
Reference in New Issue
Block a user