mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor common CrossAttention stuff into a mixin so that the old ldm code can still work if necessary
This commit is contained in:
parent
c6f31e5f36
commit
69d42762de
@ -11,11 +11,11 @@ import torch
|
|||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from diffusers.models import attention
|
from diffusers.models import attention
|
||||||
|
|
||||||
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention
|
from ldm.models.diffusion.cross_attention_control import InvokeAIDiffusersCrossAttention
|
||||||
|
|
||||||
# monkeypatch diffusers CrossAttention 🙈
|
# monkeypatch diffusers CrossAttention 🙈
|
||||||
# this is to make prompt2prompt and (future) attention maps work
|
# this is to make prompt2prompt and (future) attention maps work
|
||||||
attention.CrossAttention = InvokeAICrossAttention
|
attention.CrossAttention = InvokeAIDiffusersCrossAttention
|
||||||
|
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
@ -212,7 +212,7 @@ def setup_cross_attention_control(model, context: Context):
|
|||||||
|
|
||||||
def get_attention_modules(model, which: CrossAttentionType):
|
def get_attention_modules(model, which: CrossAttentionType):
|
||||||
# cross_attention_class: type = ldm.modules.attention.CrossAttention
|
# cross_attention_class: type = ldm.modules.attention.CrossAttention
|
||||||
cross_attention_class: type = InvokeAICrossAttention
|
cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||||
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||||
isinstance(module, cross_attention_class) and which_attn in name]
|
isinstance(module, cross_attention_class) and which_attn in name]
|
||||||
@ -315,12 +315,16 @@ def get_mem_free_total(device):
|
|||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
return mem_free_total
|
return mem_free_total
|
||||||
|
|
||||||
class InvokeAICrossAttention(diffusers.models.attention.CrossAttention):
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
class InvokeAICrossAttentionMixin:
|
||||||
|
"""
|
||||||
|
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||||
|
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||||
|
and dymamic slicing strategy selection.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
|
|
||||||
self.attention_slice_wrangler = None
|
self.attention_slice_wrangler = None
|
||||||
self.slicing_strategy_getter = None
|
self.slicing_strategy_getter = None
|
||||||
|
|
||||||
@ -342,16 +346,9 @@ class InvokeAICrossAttention(diffusers.models.attention.CrossAttention):
|
|||||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||||
self.slicing_strategy_getter = getter
|
self.slicing_strategy_getter = getter
|
||||||
|
|
||||||
def _attention(self, query, key, value):
|
|
||||||
#default_result = super()._attention(query, key, value)
|
|
||||||
damian_result = self.get_attention_mem_efficient(query, key, value)
|
|
||||||
|
|
||||||
hidden_states = self.reshape_batch_dim_to_heads(damian_result)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||||
# calculate attention scores
|
# calculate attention scores
|
||||||
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
|
print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
|
||||||
attention_scores = torch.baddbmm(
|
attention_scores = torch.baddbmm(
|
||||||
@ -370,11 +367,9 @@ class InvokeAICrossAttention(diffusers.models.attention.CrossAttention):
|
|||||||
else:
|
else:
|
||||||
attention_slice = default_attention_slice
|
attention_slice = default_attention_slice
|
||||||
|
|
||||||
#return torch.einsum('b i j, b j d -> b i d', attention_slice, v)
|
|
||||||
hidden_states = torch.bmm(attention_slice, value)
|
hidden_states = torch.bmm(attention_slice, value)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
for i in range(0, q.shape[0], slice_size):
|
for i in range(0, q.shape[0], slice_size):
|
||||||
@ -424,12 +419,12 @@ class InvokeAICrossAttention(diffusers.models.attention.CrossAttention):
|
|||||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||||
|
|
||||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||||
mem_free_total = get_mem_free_total(q.device)
|
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||||
# Divide factor of safety as there's copying and fragmentation
|
# Divide factor of safety as there's copying and fragmentation
|
||||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||||
|
|
||||||
|
|
||||||
def get_attention_mem_efficient(self, q, k, v):
|
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||||
if q.device.type == 'cuda':
|
if q.device.type == 'cuda':
|
||||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||||
return self.einsum_op_cuda(q, k, v)
|
return self.einsum_op_cuda(q, k, v)
|
||||||
@ -442,3 +437,19 @@ class InvokeAICrossAttention(diffusers.models.attention.CrossAttention):
|
|||||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||||
# Tested on i7 with 8MB L3 cache.
|
# Tested on i7 with 8MB L3 cache.
|
||||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
InvokeAICrossAttentionMixin.__init__(self)
|
||||||
|
|
||||||
|
def _attention(self, query, key, value):
|
||||||
|
#default_result = super()._attention(query, key, value)
|
||||||
|
damian_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||||
|
|
||||||
|
hidden_states = self.reshape_batch_dim_to_heads(damian_result)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
|
||||||
from ldm.modules.diffusionmodules.util import checkpoint
|
from ldm.modules.diffusionmodules.util import checkpoint
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
@ -163,8 +164,7 @@ def get_mem_free_total(device):
|
|||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
return mem_free_total
|
return mem_free_total
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||||
class CrossAttention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||||
print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
|
print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
|
||||||
@ -184,117 +184,6 @@ class CrossAttention(nn.Module):
|
|||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
|
||||||
|
|
||||||
self.cached_mem_free_total = None
|
|
||||||
self.attention_slice_wrangler = None
|
|
||||||
self.slicing_strategy_getter = None
|
|
||||||
|
|
||||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
|
||||||
'''
|
|
||||||
Set custom attention calculator to be called when attention is calculated
|
|
||||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
|
||||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
|
||||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
|
||||||
`suggested_attention_slice` is the default-calculated attention slice
|
|
||||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
|
||||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
|
||||||
|
|
||||||
Pass None to use the default attention calculation.
|
|
||||||
:return:
|
|
||||||
'''
|
|
||||||
self.attention_slice_wrangler = wrangler
|
|
||||||
|
|
||||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
|
||||||
self.slicing_strategy_getter = getter
|
|
||||||
|
|
||||||
def cache_free_memory_count(self, device):
|
|
||||||
self.cached_mem_free_total = get_mem_free_total(device)
|
|
||||||
print("free cuda memory: ", self.cached_mem_free_total)
|
|
||||||
|
|
||||||
def clear_cached_free_memory_count(self):
|
|
||||||
self.cached_mem_free_total = None
|
|
||||||
|
|
||||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
|
||||||
# calculate attention scores
|
|
||||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
|
||||||
# calculate attention slice by taking the best scores for each latent pixel
|
|
||||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
|
||||||
attention_slice_wrangler = self.attention_slice_wrangler
|
|
||||||
if attention_slice_wrangler is not None:
|
|
||||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
|
||||||
else:
|
|
||||||
attention_slice = default_attention_slice
|
|
||||||
|
|
||||||
return einsum('b i j, b j d -> b i d', attention_slice, v)
|
|
||||||
|
|
||||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
|
||||||
for i in range(0, q.shape[0], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def einsum_op_mps_v1(self, q, k, v):
|
|
||||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
|
||||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
|
||||||
else:
|
|
||||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
|
||||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
|
||||||
|
|
||||||
def einsum_op_mps_v2(self, q, k, v):
|
|
||||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
|
||||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
|
||||||
else:
|
|
||||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
|
||||||
|
|
||||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
|
||||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
|
||||||
if size_mb <= max_tensor_mb:
|
|
||||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
|
||||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
|
||||||
if div <= q.shape[0]:
|
|
||||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
|
||||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
|
||||||
|
|
||||||
def einsum_op_cuda(self, q, k, v):
|
|
||||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
|
||||||
slicing_strategy_getter = self.slicing_strategy_getter
|
|
||||||
if slicing_strategy_getter is not None:
|
|
||||||
(dim, slice_size) = slicing_strategy_getter(self)
|
|
||||||
if dim is not None:
|
|
||||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
|
||||||
if dim == 0:
|
|
||||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
|
||||||
elif dim == 1:
|
|
||||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
|
||||||
|
|
||||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
|
||||||
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
|
||||||
# Divide factor of safety as there's copying and fragmentation
|
|
||||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
|
||||||
|
|
||||||
|
|
||||||
def get_attention_mem_efficient(self, q, k, v):
|
|
||||||
if q.device.type == 'cuda':
|
|
||||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
|
||||||
return self.einsum_op_cuda(q, k, v)
|
|
||||||
|
|
||||||
if q.device.type == 'mps':
|
|
||||||
if self.mem_total_gb >= 32:
|
|
||||||
return self.einsum_op_mps_v1(q, k, v)
|
|
||||||
return self.einsum_op_mps_v2(q, k, v)
|
|
||||||
|
|
||||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
|
||||||
# Tested on i7 with 8MB L3 cache.
|
|
||||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
@ -307,7 +196,11 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
r = self.get_attention_mem_efficient(q, k, v)
|
# prevent scale being applied twice
|
||||||
|
cached_scale = self.scale
|
||||||
|
self.scale = 1
|
||||||
|
r = self.get_invokeai_attention_mem_efficient(q, k, v)
|
||||||
|
self.scale = cached_scale
|
||||||
|
|
||||||
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
|
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
|
||||||
return self.to_out(hidden_states)
|
return self.to_out(hidden_states)
|
||||||
|
Loading…
Reference in New Issue
Block a user