sliced cross-attention wrangler works

Repository: https://github.com/invoke-ai/InvokeAI (mirror)
Commit: 056cb0d8a8
Parent: 37a204324b
@@ -56,6 +56,13 @@ class CrossAttentionControl:
         return [module for name, module in model.named_modules() if
                 type(module).__name__ == "CrossAttention" and which_attn in name]

+    @classmethod
+    def clear_requests(cls, model):
+        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
+        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
+        for m in self_attention_modules+tokens_attention_modules:
+            m.save_last_attn_slice = False
+            m.use_last_attn_slice = False
+
     @classmethod
     def request_save_attention_maps(cls, model):
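For orientation, a minimal sketch of the request-flag protocol that clear_requests() resets. The DummyCrossAttention class and helper functions below are illustrative stand-ins, not InvokeAI code:

# Illustrative sketch: DummyCrossAttention stands in for the real CrossAttention
# modules whose flags CrossAttentionControl toggles.
class DummyCrossAttention:
    def __init__(self):
        self.save_last_attn_slice = False
        self.use_last_attn_slice = False
        self.last_attn_slice = None

modules = [DummyCrossAttention(), DummyCrossAttention()]

def clear_requests(modules):
    # same effect as CrossAttentionControl.clear_requests: drop any pending request
    for m in modules:
        m.save_last_attn_slice = False
        m.use_last_attn_slice = False

def request_save(modules):
    for m in modules:
        m.save_last_attn_slice = True

def request_apply_saved(modules):
    for m in modules:
        m.use_last_attn_slice = True

# typical prompt-editing sequence:
clear_requests(modules)
request_save(modules)         # ...then run the model with the original conditioning
clear_requests(modules)
request_apply_saved(modules)  # ...then run it again with the edited conditioning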
@@ -76,81 +83,84 @@ class CrossAttentionControl:
     def inject_attention_functions(cls, unet):
         # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

-        def new_attention(self, query, key, value):
+        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
+
+            attn_slice = suggested_attention_slice
+            if dim is not None:
+                start = offset
+                end = start+slice_size
+                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
+            #else:
+            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")

-            attention_scores = torch.functional.einsum('b i d, b j d -> b i j', query, key)
-            # calculate attention slice by taking the best scores for each latent pixel
-            attn_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)

             if self.use_last_attn_slice:
+                this_attn_slice = attn_slice
                 if self.last_attn_slice_mask is not None:
-                    base_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
+                    # indices and mask operate on dim=2, no need to slice
+                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                     base_attn_slice_mask = self.last_attn_slice_mask
-                    this_attn_slice_mask = 1 - self.last_attn_slice_mask
-                    attn_slice = attn_slice * this_attn_slice_mask + base_attn_slice * base_attn_slice_mask
-                else:
-                    attn_slice = self.last_attn_slice
+                    if dim is None:
+                        base_attn_slice = base_attn_slice_full
+                        #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
+                    elif dim == 0:
+                        base_attn_slice = base_attn_slice_full[start:end]
+                        #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
+                    elif dim == 1:
+                        base_attn_slice = base_attn_slice_full[:, start:end]
+                        #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)

-                self.use_last_attn_slice = False
+                    attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
+                                 base_attn_slice * base_attn_slice_mask
+                else:
+                    if dim is None:
+                        attn_slice = self.last_attn_slice
+                        #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
+                    elif dim == 0:
+                        attn_slice = self.last_attn_slice[start:end]
+                        #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
+                    elif dim == 1:
+                        attn_slice = self.last_attn_slice[:, start:end]
+                        #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)

             if self.save_last_attn_slice:
-                self.last_attn_slice = attn_slice
-                self.save_last_attn_slice = False
+                if dim is None:
+                    self.last_attn_slice = attn_slice
+                elif dim == 0:
+                    # dynamically grow last_attn_slice if needed
+                    if self.last_attn_slice is None:
+                        self.last_attn_slice = attn_slice
+                        #print("no last_attn_slice: shape now", self.last_attn_slice.shape)
+                    elif self.last_attn_slice.shape[0] == start:
+                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
+                        assert(self.last_attn_slice.shape[0] == end)
+                        #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
+                    else:
+                        # no need to grow
+                        self.last_attn_slice[start:end] = attn_slice
+                        #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
+
+                elif dim == 1:
+                    # dynamically grow last_attn_slice if needed
+                    if self.last_attn_slice is None:
+                        self.last_attn_slice = attn_slice
+                    elif self.last_attn_slice.shape[1] == start:
+                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
+                        assert(self.last_attn_slice.shape[1] == end)
+                    else:
+                        # no need to grow
+                        self.last_attn_slice[:, start:end] = attn_slice

             if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
-                attn_slice = attn_slice * self.last_attn_slice_weights
-                self.use_last_attn_weights = False
+                if dim is None:
+                    weights = self.last_attn_slice_weights
+                elif dim == 0:
+                    weights = self.last_attn_slice_weights[start:end]
+                elif dim == 1:
+                    weights = self.last_attn_slice_weights[:, start:end]
+                attn_slice = attn_slice * weights

-            return torch.functional.einsum('b i j, b j d -> b i d', attn_slice, value)
-
-        def new_sliced_attention(self, query, key, value, sequence_length, dim):
-
-            raise NotImplementedError("not tested yet")
-
-            batch_size_attention = query.shape[0]
-            hidden_states = torch.zeros(
-                (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
-            )
-            slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
-            for i in range(hidden_states.shape[0] // slice_size):
-                start_idx = i * slice_size
-                end_idx = (i + 1) * slice_size
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                )  # TODO: use baddbmm for better performance
-                attn_slice = attn_slice.softmax(dim=-1)
-
-                if self.use_last_attn_slice:
-                    if self.last_attn_slice_mask is not None:
-                        new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
-                        attn_slice = attn_slice * (
-                                1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
-                    else:
-                        attn_slice = self.last_attn_slice
-
-                    self.use_last_attn_slice = False
-
-                if self.save_last_attn_slice:
-                    self.last_attn_slice = attn_slice
-                    self.save_last_attn_slice = False
-
-                if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
-                    attn_slice = attn_slice * self.last_attn_slice_weights
-                    self.use_last_attn_weights = False
-
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
-
-                hidden_states[start_idx:end_idx] = attn_slice
-
-            # reshape hidden_states
-            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-            return hidden_states
-
-        def select_attention_func(module, q, k, v, dim, offset, slice_size):
-            if dim == 0 or dim == 1:
-                return new_sliced_attention(module, q, k, v, sequence_length=slice_size, dim=dim)
-            else:
-                return new_attention(module, q, k, v)
+            return attn_slice

         for name, module in unet.named_modules():
             module_name = type(module).__name__
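As a hedged illustration of the blend the wrangler performs when use_last_attn_slice is set together with a mask: the saved slice is remapped along the token axis with index_select and mixed with the fresh slice. All shapes and values below are toy examples, not values from the diff:

import torch

# Toy shapes: (batch*heads, query_tokens, key_tokens); values are made up.
this_attn_slice = torch.rand(2, 4, 6)   # freshly computed on this pass
last_attn_slice = torch.rand(2, 4, 6)   # saved on a previous pass

# remap the saved slice's token axis (dim=-1), then blend with a 0/1 mask:
# masked token positions come from the saved maps, the rest stay new
last_attn_slice_indices = torch.arange(6)   # identity remap for the toy case
last_attn_slice_mask = torch.zeros(6)
last_attn_slice_mask[2] = 1.0               # keep only token 2 from the saved maps

base_attn_slice = torch.index_select(last_attn_slice, -1, last_attn_slice_indices)
attn_slice = this_attn_slice * (1 - last_attn_slice_mask) + base_attn_slice * last_attn_slice_mask
print(attn_slice.shape)   # torch.Size([2, 4, 6])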
@@ -159,7 +169,7 @@ class CrossAttentionControl:
                 module.use_last_attn_slice = False
                 module.use_last_attn_weights = False
                 module.save_last_attn_slice = False
-                module.set_custom_attention_calculator(select_attention_func)
+                module.set_attention_slice_wrangler(attention_slice_wrangler)


 # original code below
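A usage sketch of the new hook, assuming a CrossAttention module that exposes the set_attention_slice_wrangler method added later in this commit; the identity_wrangler and toy tensors below are illustrative, not InvokeAI code:

import torch

def identity_wrangler(module, attention_scores, suggested_attention_slice,
                      dim, offset, slice_size):
    # do-nothing wrangler: accept the softmaxed slice the caller computed;
    # a real wrangler could stash it, or substitute a previously saved slice
    return suggested_attention_slice

# toy invocation with the shapes the hook receives: (batch*heads, q_tokens, k_tokens)
scores = torch.rand(2, 4, 6)
suggested = scores.softmax(dim=-1)
out = identity_wrangler(None, scores, suggested, dim=None, offset=None, slice_size=None)
assert out.shape == (2, 4, 6)

# per this commit, it is attached inside inject_attention_functions with:
#     module.set_attention_slice_wrangler(attention_slice_wrangler)
# and passing None restores the default attention calculation.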
@@ -48,6 +48,8 @@ class CFGDenoiser(nn.Module):

     def forward(self, x, sigma, uncond, cond, cond_scale):

+        CrossAttentionControl.clear_requests(self.inner_model)
+
         #print('generating unconditioned latents')
         unconditioned_latents = self.inner_model(x, sigma, cond=uncond)

@@ -61,6 +63,7 @@ class CFGDenoiser(nn.Module):
         if self.edited_conditioning is not None:
             # process x again, using the saved attention maps but the new conditioning
             # this is automatically toggled off after the model forward()
+            CrossAttentionControl.clear_requests(self.inner_model)
             CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
             #print('generating edited conditioned latents')
             conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
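Paraphrased, the forward() above now runs up to three passes through inner_model. A self-contained sketch of that structure; the StubModel, shapes, and the standard classifier-free-guidance mix at the end are illustrative assumptions, not lines from this diff:

import torch

class StubModel:
    def __call__(self, x, sigma, cond=None):
        return torch.zeros_like(x)   # stand-in for a real denoiser forward

def cfg_forward(inner_model, x, sigma, uncond, cond, edited_cond, cond_scale):
    # pass 1: unconditioned latents (attention requests cleared beforehand)
    unconditioned = inner_model(x, sigma, cond=uncond)
    # pass 2: conditioned latents; CrossAttention modules save their maps here
    conditioned = inner_model(x, sigma, cond=cond)
    if edited_cond is not None:
        # pass 3: same x with the edited conditioning, reusing the saved maps
        conditioned = inner_model(x, sigma, cond=edited_cond)
    # standard classifier-free guidance combination
    return unconditioned + (conditioned - unconditioned) * cond_scale

x = torch.zeros(1, 4, 8, 8)
out = cfg_forward(StubModel(), x, 1.0, None, None, None, 7.5)
print(out.shape)   # torch.Size([1, 4, 8, 8])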
@@ -173,59 +173,75 @@ class CrossAttention(nn.Module):

         self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)

-        self.custom_attention_calculator = None
+        self.attention_slice_wrangler = None

-    def set_custom_attention_calculator(self, callback:Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
+    def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
         '''
         Set custom attention calculator to be called when attention is calculated
-        :param callback: Callback, with args q, k, v, dim, offset, slice_size, which returns attention info.
-        q, k, v are as regular attention calculator.
+        :param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
+        which returns either the suggested_attention_slice or an adjusted equivalent.
+        self is the current CrossAttention module for which the callback is being invoked.
+        attention_scores are the scores for attention
+        suggested_attention_slice is a softmax(dim=-1) over attention_scores
+        dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
+        If dim is >= 0, offset and slice_size specify the slice start and length.

         Pass None to use the default attention calculation.
         :return:
         '''
-        self.custom_attention_calculator = callback
+        self.attention_slice_wrangler = wrangler

-    def einsum_op_slice_dim0(self, q, k, v, slice_size, callback):
+    def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
+        # calculate attention scores
+        attention_scores = einsum('b i d, b j d -> b i j', q, k)
+        # calculate attention slice by taking the best scores for each latent pixel
+        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
+        if self.attention_slice_wrangler is not None:
+            attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
+        else:
+            attention_slice = default_attention_slice
+
+        return einsum('b i j, b j d -> b i d', attention_slice, v)
+
+    def einsum_op_slice_dim0(self, q, k, v, slice_size):
         r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
         for i in range(0, q.shape[0], slice_size):
             end = i + slice_size
-            r[i:end] = callback(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
+            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
         return r

-    def einsum_op_slice_dim1(self, q, k, v, slice_size, callback):
+    def einsum_op_slice_dim1(self, q, k, v, slice_size):
         r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
-            r[:, i:end] = callback(self, q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
+            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
         return r

-    def einsum_op_mps_v1(self, q, k, v, callback):
+    def einsum_op_mps_v1(self, q, k, v):
         if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            return callback(self, q, k, v, -1, 0, 0)
+            return self.einsum_lowest_level(q, k, v, None, None, None)
         else:
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            return self.einsum_op_slice_dim1(q, k, v, slice_size, callback)
+            return self.einsum_op_slice_dim1(q, k, v, slice_size)

-    def einsum_op_mps_v2(self, q, k, v, callback):
+    def einsum_op_mps_v2(self, q, k, v):
         if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-            return callback(self, q, k, v, -1, 0, 0)
+            return self.einsum_lowest_level(q, k, v, None, None, None)
         else:
-            return self.einsum_op_slice_dim0(q, k, v, 1, callback)
+            return self.einsum_op_slice_dim0(q, k, v, 1)

-    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback):
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
         size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
         if size_mb <= max_tensor_mb:
-            return callback(self, q, k, v, offset=0)
+            return self.einsum_lowest_level(q, k, v, None, None, None)
         div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
         if div <= q.shape[0]:
-            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div, callback)
+            print("warning: untested call to einsum_op_slice_dim0")
+            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1), callback)
+        print("warning: untested call to einsum_op_slice_dim1")
+        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))

-    def einsum_op_cuda(self, q, k, v, callback):
+    def einsum_op_cuda(self, q, k, v):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
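The divisor arithmetic in einsum_op_tensor_mem above picks a power-of-two split just large enough to bring each slice under the memory budget. A standalone illustration with made-up sizes:

# Standalone illustration of the power-of-two divisor used above (toy numbers).
def choose_div(size_mb: float, max_tensor_mb: float) -> int:
    return 1 << int((size_mb - 1) / max_tensor_mb).bit_length()

print(choose_div(100, 32))   # 4 -> split into 4 slices of roughly 25 MB each
print(choose_div(33, 32))    # 2 -> just over budget, halve it
print(choose_div(32, 32))    # 1 -> already fits (the <= max_tensor_mb branch returns earlier)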
@@ -233,20 +249,20 @@ class CrossAttention(nn.Module):
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
         # Divide factor of safety as there's copying and fragmentation
-        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback)
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

-    def get_attention_mem_efficient(self, q, k, v, callback):
+    def get_attention_mem_efficient(self, q, k, v):
         if q.device.type == 'cuda':
-            return self.einsum_op_cuda(q, k, v, callback)
+            return self.einsum_op_cuda(q, k, v)

         if q.device.type == 'mps':
             if self.mem_total_gb >= 32:
-                return self.einsum_op_mps_v1(q, k, v, callback)
-            return self.einsum_op_mps_v2(q, k, v, callback)
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)

         # Smaller slices are faster due to L2/L3/SLC caches.
         # Tested on i7 with 8MB L3 cache.
-        return self.einsum_op_tensor_mem(q, k, v, 32, callback)
+        return self.einsum_op_tensor_mem(q, k, v, 32)

     def forward(self, x, context=None, mask=None):
         h = self.heads
@@ -259,23 +275,14 @@ class CrossAttention(nn.Module):

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-        def default_attention_calculator(q, k, v, **kwargs):
-            # calculate attention scores
-            attention_scores = einsum('b i d, b j d -> b i j', q, k)
-            # calculate attenion slice by taking the best scores for each latent pixel
-            attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-            return einsum('b i j, b j d -> b i d', attention_slice, v)
-
-        attention_calculator = \
-            self.custom_attention_calculator if self.custom_attention_calculator is not None \
-            else default_attention_calculator
-
-        r = self.get_attention_mem_efficient(q, k, v, attention_calculator)
+        r = self.get_attention_mem_efficient(q, k, v)

         hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(hidden_states)



 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
         super().__init__()
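For reference, a toy round-trip of the attention path in forward() above: split heads with rearrange, compute the einsum attention that einsum_lowest_level performs when no wrangler is registered, and merge heads back. The projections (to_q/to_k/to_v/to_out) and the scale factor are omitted, and the sizes are made up:

import torch
from torch import einsum
from einops import rearrange

h = 2                          # heads (made up)
b, n, inner_dim = 1, 4, 8      # batch, tokens, h * dim_head (made up)
q = k = v = torch.rand(b, n, inner_dim)

# split heads exactly as forward() does
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

# what einsum_lowest_level computes when no wrangler is registered
attention_scores = einsum('b i d, b j d -> b i j', q, k)
attention_slice = attention_scores.softmax(dim=-1)
r = einsum('b i j, b j d -> b i d', attention_slice, v)

# merge heads back; the round trip restores the original shape
out = rearrange(r, '(b h) n d -> b n (h d)', h=h)
assert out.shape == (b, n, inner_dim)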