sliced cross-attention wrangler works

This commit is contained in:
Damian at mba 2022-10-18 11:48:33 +02:00
parent 37a204324b
commit 056cb0d8a8
3 changed files with 123 additions and 103 deletions

View File

@@ -56,6 +56,13 @@ class CrossAttentionControl:
return [module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def clear_requests(cls, model):
self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False
m.use_last_attn_slice = False
@classmethod
def request_save_attention_maps(cls, model):
@@ -76,81 +83,84 @@ class CrossAttentionControl:
def inject_attention_functions(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def new_attention(self, query, key, value):
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
attn_slice = suggested_attention_slice
if dim is not None:
start = offset
end = start+slice_size
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
#else:
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
attention_scores = torch.functional.einsum('b i d, b j d -> b i j', query, key)
# calculate attention slice by taking the best scores for each latent pixel
attn_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.use_last_attn_slice:
this_attn_slice = attn_slice
if self.last_attn_slice_mask is not None:
base_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
# indices and mask operate on dim=2, no need to slice
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
base_attn_slice_mask = self.last_attn_slice_mask
this_attn_slice_mask = 1 - self.last_attn_slice_mask
attn_slice = attn_slice * this_attn_slice_mask + base_attn_slice * base_attn_slice_mask
else:
attn_slice = self.last_attn_slice
if dim is None:
base_attn_slice = base_attn_slice_full
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
elif dim == 0:
base_attn_slice = base_attn_slice_full[start:end]
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
elif dim == 1:
base_attn_slice = base_attn_slice_full[:, start:end]
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
self.use_last_attn_slice = False
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
base_attn_slice * base_attn_slice_mask
else:
if dim is None:
attn_slice = self.last_attn_slice
#print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
elif dim == 0:
attn_slice = self.last_attn_slice[start:end]
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
elif dim == 1:
attn_slice = self.last_attn_slice[:, start:end]
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
if self.save_last_attn_slice:
self.last_attn_slice = attn_slice
self.save_last_attn_slice = False
if dim is None:
self.last_attn_slice = attn_slice
elif dim == 0:
# dynamically grow last_attn_slice if needed
if self.last_attn_slice is None:
self.last_attn_slice = attn_slice
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
elif self.last_attn_slice.shape[0] == start:
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
assert(self.last_attn_slice.shape[0] == end)
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
else:
# no need to grow
self.last_attn_slice[start:end] = attn_slice
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
elif dim == 1:
# dynamically grow last_attn_slice if needed
if self.last_attn_slice is None:
self.last_attn_slice = attn_slice
elif self.last_attn_slice.shape[1] == start:
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
assert(self.last_attn_slice.shape[1] == end)
else:
# no need to grow
self.last_attn_slice[:, start:end] = attn_slice
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights
self.use_last_attn_weights = False
if dim is None:
weights = self.last_attn_slice_weights
elif dim == 0:
weights = self.last_attn_slice_weights[start:end]
elif dim == 1:
weights = self.last_attn_slice_weights[:, start:end]
attn_slice = attn_slice * weights
return torch.functional.einsum('b i j, b j d -> b i d', attn_slice, value)
def new_sliced_attention(self, query, key, value, sequence_length, dim):
raise NotImplementedError("not tested yet")
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
if self.use_last_attn_slice:
if self.last_attn_slice_mask is not None:
new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
attn_slice = attn_slice * (
1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
else:
attn_slice = self.last_attn_slice
self.use_last_attn_slice = False
if self.save_last_attn_slice:
self.last_attn_slice = attn_slice
self.save_last_attn_slice = False
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights
self.use_last_attn_weights = False
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def select_attention_func(module, q, k, v, dim, offset, slice_size):
if dim == 0 or dim == 1:
return new_sliced_attention(module, q, k, v, sequence_length=slice_size, dim=dim)
else:
return new_attention(module, q, k, v)
return attn_slice
for name, module in unet.named_modules():
module_name = type(module).__name__
@@ -159,7 +169,7 @@ class CrossAttentionControl:
module.use_last_attn_slice = False
module.use_last_attn_weights = False
module.save_last_attn_slice = False
module.set_custom_attention_calculator(select_attention_func)
module.set_attention_slice_wrangler(attention_slice_wrangler)
# original code below

View File

@@ -48,6 +48,8 @@ class CFGDenoiser(nn.Module):
def forward(self, x, sigma, uncond, cond, cond_scale):
CrossAttentionControl.clear_requests(self.inner_model)
#print('generating unconditioned latents')
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
@@ -61,6 +63,7 @@ class CFGDenoiser(nn.Module):
if self.edited_conditioning is not None:
# process x again, using the saved attention maps but the new conditioning
# this is automatically toggled off after the model forward()
CrossAttentionControl.clear_requests(self.inner_model)
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
#print('generating edited conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
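Pieced together from the methods referenced in this diff, the intended request/apply sequence inside forward() looks roughly like the sketch below; the exact point at which request_save_attention_maps is issued is not shown in this hunk and is assumed here.
# Assumed call order (sketch, not verbatim from this commit):
CrossAttentionControl.clear_requests(self.inner_model)
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
CrossAttentionControl.request_save_attention_maps(self.inner_model)   # record maps for the original prompt
conditioned_latents = self.inner_model(x, sigma, cond=cond)
if self.edited_conditioning is not None:
    CrossAttentionControl.clear_requests(self.inner_model)
    CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
    # the apply flag is consumed (toggled off) inside the model forward()
    conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)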

View File

@@ -173,59 +173,75 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.custom_attention_calculator = None
self.attention_slice_wrangler = None
def set_custom_attention_calculator(self, callback:Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
'''
Set custom attention calculator to be called when attention is calculated
:param callback: Callback, with args q, k, v, dim, offset, slice_size, which returns attention info.
q, k, v are as regular attention calculator.
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked.
attention_scores are the scores for attention
suggested_attention_slice is a softmax(dim=-1) over attention_scores
dim is None if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If dim is >= 0, offset and slice_size specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.custom_attention_calculator = callback
self.attention_slice_wrangler = wrangler
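A usage sketch for the wrangler hook above (illustrative only; boost_token_wrangler, cross_attn_module, token_index and factor are hypothetical names, not part of this commit):
# Sketch: a wrangler that boosts the attention paid to one prompt token.
# The token/key axis is the last dimension of every suggested slice, whether
# the call is unsliced (dim is None) or sliced along dim 0 or dim 1.
def boost_token_wrangler(module, attention_scores, suggested_attention_slice,
                         dim, offset, slice_size, token_index=3, factor=1.5):
    attn_slice = suggested_attention_slice.clone()  # avoid mutating the suggestion in place
    attn_slice[:, :, token_index] *= factor
    return attn_slice
cross_attn_module.set_attention_slice_wrangler(boost_token_wrangler)  # cross_attn_module: one CrossAttention instance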
def einsum_op_slice_dim0(self, q, k, v, slice_size, callback):
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None:
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = callback(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size, callback):
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = callback(self, q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v, callback):
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return callback(self, q, k, v, -1, 0, 0)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size, callback)
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v, callback):
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return callback(self, q, k, v, -1, 0, 0)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1, callback)
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback):
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return callback(self, q, k, v, offset=0)
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
print("warning: untested call to einsum_op_slice_dim0")
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div, callback)
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
print("warning: untested call to einsum_op_slice_dim1")
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1), callback)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
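As a worked example of the slicing arithmetic above (the shapes and fp16 element size are assumptions for illustration, not values from this diff):
# q: (16, 4096, 40) and k: (16, 4096, 40) in fp16 -> the full score tensor would be (16, 4096, 4096)
size_mb = 16 * 4096 * 4096 * 2 // (1 << 20)                   # 512 MB for the full score tensor
max_tensor_mb = 32
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()    # int(15.97) -> 15 -> bit_length 4 -> div = 16
# div (16) <= q.shape[0] (16), so dim 0 is sliced in chunks of 16 // 16 = 1 head-batch at a time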
def einsum_op_cuda(self, q, k, v, callback):
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
@@ -233,20 +249,20 @@ class CrossAttention(nn.Module):
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide by a factor of safety, as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback)
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_attention_mem_efficient(self, q, k, v, callback):
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v, callback)
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v, callback)
return self.einsum_op_mps_v2(q, k, v, callback)
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32, callback)
return self.einsum_op_tensor_mem(q, k, v, 32)
def forward(self, x, context=None, mask=None):
h = self.heads
@@ -259,23 +275,14 @@ class CrossAttention(nn.Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
def default_attention_calculator(q, k, v, **kwargs):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
return einsum('b i j, b j d -> b i d', attention_slice, v)
attention_calculator = \
self.custom_attention_calculator if self.custom_attention_calculator is not None \
else default_attention_calculator
r = self.get_attention_mem_efficient(q, k, v, attention_calculator)
r = self.get_attention_mem_efficient(q, k, v)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()