Mirror of https://github.com/invoke-ai/InvokeAI
add cross-attention control support to diffusers (fails on MPS)
For unknown reasons, MPS produces garbage output with .swap(). For now, pass the --always_use_cpu arg to invoke.py to test this code on MPS.
parent f48706efee
commit ff42027a00
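For context, a minimal usage sketch of what this commit enables; the .swap() prompt syntax and the interactive prompt marker below come from InvokeAI's prompt-to-prompt documentation and are assumptions for illustration, not part of this diff:

    # start the CLI on the CPU as a temporary workaround for the MPS issue noted above
    python invoke.py --always_use_cpu
    # then ask for a cross-attention-controlled edit from the prompt, e.g.
    invoke> a cat.swap(dog) sitting on a park bench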
@@ -9,6 +9,14 @@ import PIL.Image
 import einops
 import torch
 import torchvision.transforms as T
+
+from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention
+
+from diffusers.models import attention
+# monkeypatch diffusers CrossAttention 🙈
+# this is to make prompt2prompt and (future) attention maps work
+attention.CrossAttention = InvokeAICrossAttention
+
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
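The monkeypatch in the hunk above works by rebinding the CrossAttention attribute on the diffusers.models.attention module before any UNet is constructed, so attention layers built through that attribute come out as the InvokeAICrossAttention subclass (which is also what the isinstance check in get_attention_modules below relies on). A minimal, self-contained sketch of the pattern, using stand-in classes rather than the real diffusers module:

    from types import SimpleNamespace

    class CrossAttention:                         # stand-in for diffusers' CrossAttention
        def forward(self):
            return "base attention"

    class PatchedCrossAttention(CrossAttention):  # stand-in for InvokeAICrossAttention
        def forward(self):
            return "wrangled attention"

    # stand-in for the diffusers.models.attention module object
    attention = SimpleNamespace(CrossAttention=CrossAttention)

    # rebind the attribute *before* any model is built; code that later constructs
    # attention.CrossAttention(...) now gets the subclass, and isinstance checks
    # against the subclass succeed
    attention.CrossAttention = PatchedCrossAttention

    layer = attention.CrossAttention()
    print(layer.forward())  # -> "wrangled attention"

Note that the patch only affects code that looks the class up through the module attribute; anything that grabbed the class earlier with a from-import keeps the original, which is one reason the error path in get_attention_modules below warns about monkey-patching having failed.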
@@ -1,8 +1,11 @@
 import enum
-import warnings
-from typing import Optional
+import math
+from typing import Optional, Callable
 
 import psutil
 import torch
+import diffusers
 from torch import nn
 
 
 # adapted from bloc97's CrossAttentionControl colab
@@ -66,8 +69,12 @@ class Context:
 
     def register_cross_attention_modules(self, model):
         for name,module in get_attention_modules(model, CrossAttentionType.SELF):
+            if name in self.self_cross_attention_module_identifiers:
+                assert False, f"name {name} cannot appear more than once"
             self.self_cross_attention_module_identifiers.append(name)
         for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
+            if name in self.tokens_cross_attention_module_identifiers:
+                assert False, f"name {name} cannot appear more than once"
             self.tokens_cross_attention_module_identifiers.append(name)
 
     def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@@ -189,7 +196,7 @@ def setup_cross_attention_control(model, context: Context):
     # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
     mask = torch.zeros(max_length)
     indices_target = torch.arange(max_length, dtype=torch.long)
-    indices = torch.zeros(max_length, dtype=torch.long)
+    indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
         if b0 < max_length:
             if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
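The edit_opcodes consumed in the loop above follow the (tag, a0, a1, b0, b1) convention of difflib.SequenceMatcher: 'equal' spans are token ranges shared by the original and edited prompts, and for those spans the code sets mask=1 so the base prompt's attention is reused, while everything else keeps mask=0 and uses the edited prompt's attention. A small illustration of the opcode format (the token lists are made up for the example):

    from difflib import SequenceMatcher

    base_tokens   = ["a", "cat", "sitting", "on", "a", "bench"]
    edited_tokens = ["a", "dog", "sitting", "on", "a", "bench"]

    print(SequenceMatcher(None, base_tokens, edited_tokens).get_opcodes())
    # [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 6, 2, 6)]
    # the 'equal' spans are where mask[b0:b1] would be set to 1 (keep the base
    # prompt's attention); the 'replace' span stays 0 (use the edited prompt)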
@@ -204,9 +211,22 @@ def setup_cross_attention_control(model, context: Context):
 
 
 def get_attention_modules(model, which: CrossAttentionType):
+    # cross_attention_class: type = ldm.modules.attention.CrossAttention
+    cross_attention_class: type = InvokeAICrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
-    return [(name,module) for name, module in model.named_modules() if
-            type(module).__name__ == "CrossAttention" and which_attn in name]
+    attention_module_tuples = [(name,module) for name, module in model.named_modules() if
+                               isinstance(module, cross_attention_class) and which_attn in name]
+    cross_attention_modules_in_model_count = len(attention_module_tuples)
+    expected_count = 16
+    if cross_attention_modules_in_model_count != expected_count:
+        # non-fatal error but .swap() won't work.
+        print(f"Error! CrossAttentionControl found an unexpected number of InvokeAICrossAttention modules in the model " +
+              f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
+              f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
+              f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
+              f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
+              f"work properly until it is fixed.")
+    return attention_module_tuples
 
 
 def inject_attention_function(unet, context: Context):
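The expected_count of 16 matches the Stable Diffusion 1.x UNet layout: 16 BasicTransformerBlocks, each holding one self-attention module (attn1) and one token cross-attention module (attn2). A quick diagnostic sketch for checking that number against a loaded UNet (not part of the commit; the model id is a placeholder and any SD 1.x checkpoint with a unet subfolder would do):

    from diffusers import UNet2DConditionModel

    # placeholder checkpoint id, assumed to be available locally or on the Hub
    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

    attn1 = [name for name, _ in unet.named_modules() if name.endswith("attn1")]
    attn2 = [name for name, _ in unet.named_modules() if name.endswith("attn2")]
    print(len(attn1), len(attn2))  # 16 16 for an SD 1.x UNet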
@@ -246,8 +266,7 @@ def inject_attention_function(unet, context: Context):
 
         return attention_slice
 
-    cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
-                               if type(module).__name__ == "CrossAttention"]
+    cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
     for identifier, module in cross_attention_modules:
         module.identifier = identifier
         try:
@@ -257,22 +276,21 @@ def inject_attention_function(unet, context: Context):
             )
         except AttributeError as e:
             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
-                warnings.warn(f"TODO: implement for {type(module)}") # TODO
+                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
             else:
                 raise
 
 
 def remove_attention_function(unet):
-    cross_attention_modules = [module for (_, module) in unet.named_modules()
-                               if type(module).__name__ == "CrossAttention"]
-    for module in cross_attention_modules:
+    cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
+    for identifier, module in cross_attention_modules:
         try:
             # clear wrangler callback
             module.set_attention_slice_wrangler(None)
             module.set_slicing_strategy_getter(None)
         except AttributeError as e:
             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
-                warnings.warn(f"TODO: implement for {type(module)}") # TODO
+                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
             else:
                 raise
 
@@ -282,3 +300,145 @@ def is_attribute_error_about(error: AttributeError, attribute: str):
         return error.name == attribute
     else: # Python 3.9
         return attribute in str(error)
+
+
+
+def get_mem_free_total(device):
+    #only on cuda
+    if not torch.cuda.is_available():
+        return None
+    stats = torch.cuda.memory_stats(device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+    return mem_free_total
+
+class InvokeAICrossAttention(diffusers.models.attention.CrossAttention):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+        self.attention_slice_wrangler = None
+        self.slicing_strategy_getter = None
+
+    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
+        '''
+        Set custom attention calculator to be called when attention is calculated
+        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
+            which returns either the suggested_attention_slice or an adjusted equivalent.
+            `module` is the current CrossAttention module for which the callback is being invoked.
+            `suggested_attention_slice` is the default-calculated attention slice
+            `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
+            If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
+
+        Pass None to use the default attention calculation.
+        :return:
+        '''
+        self.attention_slice_wrangler = wrangler
+
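To make the wrangler contract above concrete, a hedged sketch of a callback matching the documented signature; the halving is arbitrary and only shows that the callback must return a tensor shaped like suggested_attention_slice (a real prompt-to-prompt wrangler splices saved attention into the masked token positions instead):

    import torch

    def halve_attention_wrangler(module, suggested_attention_slice: torch.Tensor,
                                 dim: int, offset: int, slice_size: int) -> torch.Tensor:
        # toy adjustment: damp every attention weight by half
        return suggested_attention_slice * 0.5

    # for each (name, module) pair returned by get_attention_modules(...):
    #     module.set_attention_slice_wrangler(halve_attention_wrangler)
    # and later, to restore the stock calculation:
    #     module.set_attention_slice_wrangler(None)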
+    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
+        self.slicing_strategy_getter = getter
+
+    def _attention(self, query, key, value):
+        #default_result = super()._attention(query, key, value)
+        damian_result = self.get_attention_mem_efficient(query, key, value)
+
+        hidden_states = self.reshape_batch_dim_to_heads(damian_result)
+        return hidden_states
+
+    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
+        # calculate attention scores
+        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+        if dim is not None:
+            print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+
+        # calculate attention slice by taking the best scores for each latent pixel
+        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
+        attention_slice_wrangler = self.attention_slice_wrangler
+        if attention_slice_wrangler is not None:
+            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
+        else:
+            attention_slice = default_attention_slice
+
+        #return torch.einsum('b i j, b j d -> b i d', attention_slice, v)
+        hidden_states = torch.bmm(attention_slice, value)
+        return hidden_states
+
+
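The commented-out einsum in einsum_lowest_level is the reference formulation; torch.baddbmm with beta=0 and alpha=self.scale computes the same scaled dot-product scores batched over heads, writing into a preallocated buffer. A quick equivalence check with arbitrary shapes (scale stands in for self.scale):

    import torch

    b, i, j, d = 2, 8, 8, 16              # batch*heads, query tokens, key tokens, head dim
    q, k = torch.randn(b, i, d), torch.randn(b, j, d)
    scale = d ** -0.5                     # stands in for self.scale

    via_einsum = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    via_baddbmm = torch.baddbmm(
        torch.empty(b, i, j, dtype=q.dtype, device=q.device),  # contents ignored since beta=0
        q, k.transpose(-1, -2),
        beta=0, alpha=scale,
    )
    print(torch.allclose(via_einsum, via_baddbmm, atol=1e-6))  # True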
+    def einsum_op_slice_dim0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
+        return r
+
+    def einsum_op_slice_dim1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
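For scale, the 4096 threshold in einsum_op_mps_v1 is the query length of a 512x512 generation: the VAE downsamples by 8, so the latent is 512/8 = 64 on a side and q.shape[1] = 64 * 64 = 4096 query positions. Above that, slice_size = floor(2^30 / (q.shape[0] * q.shape[1])) keeps each sliced score block (q.shape[0] x slice_size x k.shape[1]) near 2^30 elements when k is the same length as q; for example, assuming a batch of 2 prompts x 8 heads and a 768x768 image (latent 96x96), slice_size = floor(2^30 / (16 * 9216)) = 7281.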
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        else:
+            return self.einsum_op_slice_dim0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
+
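To see how the budget in einsum_op_tensor_mem plays out, a worked example under assumed shapes (2 prompts x 8 heads gives 16 batch rows, 512x512 latents give 4096 tokens, float32 self-attention):

    import torch

    q = torch.randn(16, 4096, 40)
    k = torch.randn(16, 4096, 40)
    max_tensor_mb = 32   # the CPU budget used by get_attention_mem_efficient

    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    print(size_mb)                            # 1024 -> the full score tensor would be 1 GiB
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    print(div)                                # 32
    print(div <= q.shape[0])                  # False -> slice along dim 1, not dim 0
    print(max(q.shape[1] // div, 1))          # 128 query rows per slice, ~32 MiB of scores each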
+    def einsum_op_cuda(self, q, k, v):
+        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
+        slicing_strategy_getter = self.slicing_strategy_getter
+        if slicing_strategy_getter is not None:
+            (dim, slice_size) = slicing_strategy_getter(self)
+            if dim is not None:
+                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
+                if dim == 0:
+                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
+                elif dim == 1:
+                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
+        # fallback for when there is no saved strategy, or saved strategy does not slice
+        mem_free_total = get_mem_free_total(q.device)
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+
+    def get_attention_mem_efficient(self, q, k, v):
+        if q.device.type == 'cuda':
+            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
+            return self.einsum_op_cuda(q, k, v)
+
+        if q.device.type == 'mps' or q.device.type == 'cpu':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)
+
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)
+
@@ -165,7 +165,9 @@ def get_mem_free_total(device):
 
 
 class CrossAttention(nn.Module):
+
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)