Fix #1362 by improving VRAM usage patterns when doing .swap()

commit ef3f7a26e242b73c2beb0195c7fd8f654ef47f55
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 12:18:37 2022 +0100

    remove log spam

commit 7189d649622d4668b120b0dd278388ad672142c4
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 12:10:28 2022 +0100

    change the way saved slicing strategy is applied

commit 01c40f751ab72955140165c16f95ae411732265b
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 12:04:43 2022 +0100

    fix slicing_strategy_getter callsite

commit f8cfe25150a346958903316bc710737d99839923
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 11:56:22 2022 +0100

    cleanup, consistent dim=0 also tested

commit 5bf9b1e890d48e962afd4a668a219b68271e5dc1
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 11:34:09 2022 +0100

    refactored context, tested with non-sliced cross attention control

commit d58a46e39bf562e7459290d2444256e8c08ad0b6
Author: damian0815 <null@damianstewart.com>
Date:   Sun Nov 6 00:41:52 2022 +0100

    cleanup

commit 7e2c658b4c06fe239311b65b9bb16fa3adec7fd7
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:57:31 2022 +0100

    disable logs

commit 20ee89d93841b070738b3d8a4385c93b097d92eb
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:36:58 2022 +0100

    slice saved attention if necessary

commit 0a7684a22c880ec0f48cc22bfed4526358f71546
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:32:38 2022 +0100

    raise instead of asserting

commit 7083104c7f3a0d8fd96e94a2f391de50a3c942e4
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:31:00 2022 +0100

    store dim when saving slices

commit f7c0808ed383ec1dc70645288a798ed2aa4fa85c
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:27:16 2022 +0100

    don't retry on exception

commit 749a721e939b3fe7c1741e7998dab6bd2c85a0cb
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:24:50 2022 +0100

    stuff

commit 032ab90e9533be8726301ec91b97137e2aadef9a
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:20:17 2022 +0100

    more logging

commit 3dc34b387f033482305360e605809d95a40bf6f8
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:16:47 2022 +0100

    logs

commit 901c4c1aa4b9bcef695a6551867ec8149e6e6a93
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:12:39 2022 +0100

    actually set save_slicing_strategy to True

commit f780e0a0a7c6b6a3db320891064da82589358c8a
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:10:35 2022 +0100

    store slicing strategy

commit 93bb6d566fd18c5c69ef7dacc8f74ba2cf671cb7
Author: damian <git@damianstewart.com>
Date:   Sat Nov 5 20:43:48 2022 +0100

    still not it

commit 5e3a9541f8ae00bde524046963910323e20c40b7
Author: damian <git@damianstewart.com>
Date:   Sat Nov 5 17:20:02 2022 +0100

    wip offloading attention slices on-demand

commit 4c2966aa856b6f3b446216da3619ae931552ef08
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 15:47:40 2022 +0100

    pre-emptive offloading, idk if it works

commit 572576755e9f0a878d38e8173e485126c0efbefb
Author: root <you@example.com>
Date:   Sat Nov 5 11:25:32 2022 +0000

    push attention slices to cpu. slow but saves memory.

commit b57c83a68f2ac03976ebc89ce2ff03812d6d185f
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 12:04:22 2022 +0100

    verbose logging

commit 3a5dae116f110a96585d9eb71d713b5ed2bc3d2b
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 11:50:48 2022 +0100

    wip fixing mem strategy crash (4 test on runpod)

commit 3cf237db5fae0c7b0b4cc3c47c81830bdb2ae7de
Author: damian0815 <null@damianstewart.com>
Date:   Fri Nov 4 09:02:40 2022 +0100

    wip, only works on cuda
This commit is contained in:
damian0815
2022-11-08 12:59:34 +01:00
committed by Lincoln Stein
parent 4b5a96501d
commit 178f0c78d8
3 changed files with 234 additions and 143 deletions

View File

@ -1,6 +1,6 @@
from inspect import isfunction
import math
from typing import Callable
from typing import Callable, Optional
import torch
import torch.nn.functional as F
@ -151,6 +151,17 @@ class SpatialSelfAttention(nn.Module):
return x+h_
def get_mem_free_total(device):
#only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
class CrossAttention(nn.Module):
@ -173,31 +184,43 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.cached_mem_free_total = None
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked.
attention_scores are the scores for attention
suggested_attention_slice is a softmax(dim=-1) over attention_scores
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If dim is >= 0, offset and slice_size specify the slice start and length.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def cache_free_memory_count(self, device):
self.cached_mem_free_total = get_mem_free_total(device)
print("free cuda memory: ", self.cached_mem_free_total)
def clear_cached_free_memory_count(self):
self.cached_mem_free_total = None
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attenion slice by taking the best scores for each latent pixel
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None:
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
@ -240,17 +263,27 @@ class CrossAttention(nn.Module):
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
torch.cuda.empty_cache()
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps':