mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
go back to using InvokeAI attention
This commit is contained in:
parent
1fc1f8bf05
commit
37a204324b
@ -75,11 +75,12 @@ class CrossAttentionControl:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def inject_attention_functions(cls, unet):
|
def inject_attention_functions(cls, unet):
|
||||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||||
|
|
||||||
def new_attention(self, query, key, value):
|
def new_attention(self, query, key, value):
|
||||||
# TODO: use baddbmm for better performance
|
|
||||||
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
attention_scores = torch.functional.einsum('b i d, b j d -> b i j', query, key)
|
||||||
attn_slice = attention_scores.softmax(dim=-1)
|
# calculate attention slice by taking the best scores for each latent pixel
|
||||||
# compute attention output
|
attn_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||||
|
|
||||||
if self.use_last_attn_slice:
|
if self.use_last_attn_slice:
|
||||||
if self.last_attn_slice_mask is not None:
|
if self.last_attn_slice_mask is not None:
|
||||||
@ -100,13 +101,12 @@ class CrossAttentionControl:
|
|||||||
attn_slice = attn_slice * self.last_attn_slice_weights
|
attn_slice = attn_slice * self.last_attn_slice_weights
|
||||||
self.use_last_attn_weights = False
|
self.use_last_attn_weights = False
|
||||||
|
|
||||||
hidden_states = torch.matmul(attn_slice, value)
|
return torch.functional.einsum('b i j, b j d -> b i d', attn_slice, value)
|
||||||
# reshape hidden_states
|
|
||||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def new_sliced_attention(self, query, key, value, sequence_length, dim):
|
def new_sliced_attention(self, query, key, value, sequence_length, dim):
|
||||||
|
|
||||||
|
raise NotImplementedError("not tested yet")
|
||||||
|
|
||||||
batch_size_attention = query.shape[0]
|
batch_size_attention = query.shape[0]
|
||||||
hidden_states = torch.zeros(
|
hidden_states = torch.zeros(
|
||||||
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
||||||
@ -146,6 +146,12 @@ class CrossAttentionControl:
|
|||||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def select_attention_func(module, q, k, v, dim, offset, slice_size):
|
||||||
|
if dim == 0 or dim == 1:
|
||||||
|
return new_sliced_attention(module, q, k, v, sequence_length=slice_size, dim=dim)
|
||||||
|
else:
|
||||||
|
return new_attention(module, q, k, v)
|
||||||
|
|
||||||
for name, module in unet.named_modules():
|
for name, module in unet.named_modules():
|
||||||
module_name = type(module).__name__
|
module_name = type(module).__name__
|
||||||
if module_name == "CrossAttention":
|
if module_name == "CrossAttention":
|
||||||
@ -153,8 +159,7 @@ class CrossAttentionControl:
|
|||||||
module.use_last_attn_slice = False
|
module.use_last_attn_slice = False
|
||||||
module.use_last_attn_weights = False
|
module.use_last_attn_weights = False
|
||||||
module.save_last_attn_slice = False
|
module.save_last_attn_slice = False
|
||||||
module._sliced_attention = new_sliced_attention.__get__(module, type(module))
|
module.set_custom_attention_calculator(select_attention_func)
|
||||||
module._attention = new_attention.__get__(module, type(module))
|
|
||||||
|
|
||||||
|
|
||||||
# original code below
|
# original code below
|
||||||
|
@ -48,21 +48,21 @@ class CFGDenoiser(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
|
||||||
print('generating unconditioned latents')
|
#rint('generating unconditioned latents')
|
||||||
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
|
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps if required
|
# process x using the original prompt, saving the attention maps if required
|
||||||
if self.edited_conditioning is not None:
|
if self.edited_conditioning is not None:
|
||||||
# this is automatically toggled off after the model forward()
|
# this is automatically toggled off after the model forward()
|
||||||
CrossAttentionControl.request_save_attention_maps(self.inner_model)
|
CrossAttentionControl.request_save_attention_maps(self.inner_model)
|
||||||
print('generating conditioned latents')
|
#print('generating conditioned latents')
|
||||||
conditioned_latents = self.inner_model(x, sigma, cond=cond)
|
conditioned_latents = self.inner_model(x, sigma, cond=cond)
|
||||||
|
|
||||||
if self.edited_conditioning is not None:
|
if self.edited_conditioning is not None:
|
||||||
# process x again, using the saved attention maps but the new conditioning
|
# process x again, using the saved attention maps but the new conditioning
|
||||||
# this is automatically toggled off after the model forward()
|
# this is automatically toggled off after the model forward()
|
||||||
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
|
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
|
||||||
print('generating edited conditioned latents')
|
#print('generating edited conditioned latents')
|
||||||
conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
|
conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
|
||||||
|
|
||||||
if self.warmup < self.warmup_max:
|
if self.warmup < self.warmup_max:
|
||||||
|
@ -1,367 +1,158 @@
|
|||||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
from inspect import isfunction
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn, einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import checkpoint
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
def uniq(arr):
|
||||||
"""
|
return{el: True for el in arr}.keys()
|
||||||
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
|
||||||
to the N-d case.
|
|
||||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
|
||||||
Uses three q, k, v linear layers to compute attention.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
channels (:obj:`int`): The number of channels in the input and output.
|
|
||||||
num_head_channels (:obj:`int`, *optional*):
|
|
||||||
The number of channels in each head. If None, then `num_heads` = 1.
|
|
||||||
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
|
||||||
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
|
||||||
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channels: int,
|
|
||||||
num_head_channels: Optional[int] = None,
|
|
||||||
num_groups: int = 32,
|
|
||||||
rescale_output_factor: float = 1.0,
|
|
||||||
eps: float = 1e-5,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
|
||||||
self.num_head_size = num_head_channels
|
|
||||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
|
||||||
|
|
||||||
# define q,k,v as linear layers
|
|
||||||
self.query = nn.Linear(channels, channels)
|
|
||||||
self.key = nn.Linear(channels, channels)
|
|
||||||
self.value = nn.Linear(channels, channels)
|
|
||||||
|
|
||||||
self.rescale_output_factor = rescale_output_factor
|
|
||||||
self.proj_attn = nn.Linear(channels, channels, 1)
|
|
||||||
|
|
||||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
|
||||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
|
||||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
|
||||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
|
||||||
return new_projection
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
residual = hidden_states
|
|
||||||
batch, channel, height, width = hidden_states.shape
|
|
||||||
|
|
||||||
# norm
|
|
||||||
hidden_states = self.group_norm(hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
# proj to q, k, v
|
|
||||||
query_proj = self.query(hidden_states)
|
|
||||||
key_proj = self.key(hidden_states)
|
|
||||||
value_proj = self.value(hidden_states)
|
|
||||||
|
|
||||||
# transpose
|
|
||||||
query_states = self.transpose_for_scores(query_proj)
|
|
||||||
key_states = self.transpose_for_scores(key_proj)
|
|
||||||
value_states = self.transpose_for_scores(value_proj)
|
|
||||||
|
|
||||||
# get scores
|
|
||||||
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
|
|
||||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
|
|
||||||
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
|
||||||
|
|
||||||
# compute attention output
|
|
||||||
hidden_states = torch.matmul(attention_probs, value_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
|
||||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
|
||||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
|
||||||
|
|
||||||
# compute next hidden_states
|
|
||||||
hidden_states = self.proj_attn(hidden_states)
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
|
||||||
|
|
||||||
# res connect and rescale
|
|
||||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialTransformer(nn.Module):
|
def default(val, d):
|
||||||
"""
|
if exists(val):
|
||||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
return val
|
||||||
standard transformer action. Finally, reshape to image.
|
return d() if isfunction(d) else d
|
||||||
|
|
||||||
Parameters:
|
|
||||||
in_channels (:obj:`int`): The number of channels in the input and output.
|
|
||||||
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
|
||||||
d_head (:obj:`int`): The number of channels in each head.
|
|
||||||
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
|
||||||
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
|
||||||
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
n_heads: int,
|
|
||||||
d_head: int,
|
|
||||||
depth: int = 1,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
num_groups: int = 32,
|
|
||||||
context_dim: Optional[int] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.n_heads = n_heads
|
|
||||||
self.d_head = d_head
|
|
||||||
self.in_channels = in_channels
|
|
||||||
inner_dim = n_heads * d_head
|
|
||||||
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
|
||||||
|
|
||||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
|
||||||
for d in range(depth)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
|
|
||||||
def _set_attention_slice(self, slice_size):
|
|
||||||
for block in self.transformer_blocks:
|
|
||||||
block._set_attention_slice(slice_size)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, context=None):
|
|
||||||
# note: if no context is given, cross-attention defaults to self-attention
|
|
||||||
batch, channel, height, weight = hidden_states.shape
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
hidden_states = self.proj_in(hidden_states)
|
|
||||||
inner_dim = hidden_states.shape[1]
|
|
||||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
|
||||||
for block in self.transformer_blocks:
|
|
||||||
hidden_states = block(hidden_states, context=context)
|
|
||||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
|
||||||
hidden_states = self.proj_out(hidden_states)
|
|
||||||
return hidden_states + residual
|
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
def max_neg_value(t):
|
||||||
r"""
|
return -torch.finfo(t.dtype).max
|
||||||
A basic Transformer block.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
dim (:obj:`int`): The number of channels in the input and output.
|
|
||||||
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
|
||||||
d_head (:obj:`int`): The number of channels in each head.
|
|
||||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
|
||||||
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
|
|
||||||
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
|
|
||||||
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
n_heads: int,
|
|
||||||
d_head: int,
|
|
||||||
dropout=0.0,
|
|
||||||
context_dim: Optional[int] = None,
|
|
||||||
gated_ff: bool = True,
|
|
||||||
checkpoint: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.attn1 = CrossAttention(
|
|
||||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
|
||||||
) # is a self-attention
|
|
||||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
|
||||||
self.attn2 = CrossAttention(
|
|
||||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
|
||||||
) # is self-attn if context is none
|
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
|
||||||
self.norm3 = nn.LayerNorm(dim)
|
|
||||||
self.checkpoint = checkpoint
|
|
||||||
|
|
||||||
def _set_attention_slice(self, slice_size):
|
|
||||||
self.attn1._slice_size = slice_size
|
|
||||||
self.attn2._slice_size = slice_size
|
|
||||||
|
|
||||||
def forward(self, hidden_states, context=None):
|
|
||||||
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
|
|
||||||
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
|
|
||||||
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
|
|
||||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
def init_(tensor):
|
||||||
r"""
|
dim = tensor.shape[-1]
|
||||||
A cross attention layer.
|
std = 1 / math.sqrt(dim)
|
||||||
|
tensor.uniform_(-std, std)
|
||||||
Parameters:
|
|
||||||
query_dim (:obj:`int`): The number of channels in the query.
|
|
||||||
context_dim (:obj:`int`, *optional*):
|
|
||||||
The number of channels in the context. If not given, defaults to `query_dim`.
|
|
||||||
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
|
||||||
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
|
|
||||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = context_dim if context_dim is not None else query_dim
|
|
||||||
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.heads = heads
|
|
||||||
# for slice_size > 0 the attention score computation
|
|
||||||
# is split across the batch axis to save memory
|
|
||||||
# You can set slice_size with `set_attention_slice`
|
|
||||||
self._slice_size = None
|
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
|
||||||
|
|
||||||
def reshape_heads_to_batch_dim(self, tensor):
|
|
||||||
batch_size, seq_len, dim = tensor.shape
|
|
||||||
head_size = self.heads
|
|
||||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
|
||||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def reshape_batch_dim_to_heads(self, tensor):
|
|
||||||
batch_size, seq_len, dim = tensor.shape
|
|
||||||
head_size = self.heads
|
|
||||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
|
||||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
def forward(self, hidden_states, context=None, mask=None):
|
|
||||||
batch_size, sequence_length, _ = hidden_states.shape
|
|
||||||
|
|
||||||
query = self.to_q(hidden_states)
|
|
||||||
context = context if context is not None else hidden_states
|
|
||||||
key = self.to_k(context)
|
|
||||||
value = self.to_v(context)
|
|
||||||
|
|
||||||
dim = query.shape[-1]
|
|
||||||
|
|
||||||
query = self.reshape_heads_to_batch_dim(query)
|
|
||||||
key = self.reshape_heads_to_batch_dim(key)
|
|
||||||
value = self.reshape_heads_to_batch_dim(value)
|
|
||||||
|
|
||||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
|
||||||
|
|
||||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
|
||||||
hidden_states = self._attention(query, key, value)
|
|
||||||
else:
|
|
||||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
|
||||||
|
|
||||||
return self.to_out(hidden_states)
|
|
||||||
|
|
||||||
def _attention(self, query, key, value):
|
|
||||||
# TODO: use baddbmm for better performance
|
|
||||||
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
|
||||||
attention_probs = attention_scores.softmax(dim=-1)
|
|
||||||
# compute attention output
|
|
||||||
hidden_states = torch.matmul(attention_probs, value)
|
|
||||||
# reshape hidden_states
|
|
||||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def _sliced_attention(self, query, key, value, sequence_length, dim):
|
|
||||||
batch_size_attention = query.shape[0]
|
|
||||||
hidden_states = torch.zeros(
|
|
||||||
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
|
||||||
)
|
|
||||||
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
|
||||||
for i in range(hidden_states.shape[0] // slice_size):
|
|
||||||
start_idx = i * slice_size
|
|
||||||
end_idx = (i + 1) * slice_size
|
|
||||||
attn_slice = (
|
|
||||||
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
|
|
||||||
) # TODO: use baddbmm for better performance
|
|
||||||
attn_slice = attn_slice.softmax(dim=-1)
|
|
||||||
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx] = attn_slice
|
|
||||||
|
|
||||||
# reshape hidden_states
|
|
||||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
r"""
|
|
||||||
A feed-forward layer.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
dim (:obj:`int`): The number of channels in the input.
|
|
||||||
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
|
||||||
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
|
||||||
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
|
|
||||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
dim_out = dim_out if dim_out is not None else dim
|
|
||||||
project_in = GEGLU(dim, inner_dim)
|
|
||||||
|
|
||||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
return self.net(hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
r"""
|
def __init__(self, dim_in, dim_out):
|
||||||
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
dim_in (:obj:`int`): The number of channels in the input.
|
|
||||||
dim_out (:obj:`int`): The number of channels in the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim_in: int, dim_out: int):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, x):
|
||||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
return hidden_states * F.gelu(gate)
|
return x * F.gelu(gate)
|
||||||
'''
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = default(dim_out, dim)
|
||||||
|
project_in = nn.Sequential(
|
||||||
|
nn.Linear(dim, inner_dim),
|
||||||
|
nn.GELU()
|
||||||
|
) if not glu else GEGLU(dim, inner_dim)
|
||||||
|
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
project_in,
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(inner_dim, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
def zero_module(module):
|
||||||
|
"""
|
||||||
|
Zero out the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().zero_()
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels):
|
||||||
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttention(nn.Module):
|
||||||
|
def __init__(self, dim, heads=4, dim_head=32):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
hidden_dim = dim_head * heads
|
||||||
|
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||||
|
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
qkv = self.to_qkv(x)
|
||||||
|
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||||
|
k = k.softmax(dim=-1)
|
||||||
|
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||||
|
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||||
|
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialSelfAttention(nn.Module):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = Normalize(in_channels)
|
||||||
|
self.q = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.k = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.v = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b,c,h,w = q.shape
|
||||||
|
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||||
|
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||||
|
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||||
|
|
||||||
|
w_ = w_ * (int(c)**(-0.5))
|
||||||
|
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||||
|
w_ = rearrange(w_, 'b i j -> b j i')
|
||||||
|
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||||
|
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||||
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
return x+h_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -382,39 +173,51 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
|
|
||||||
self.cross_attention_callback = None
|
self.custom_attention_calculator = None
|
||||||
|
|
||||||
|
def set_custom_attention_calculator(self, callback:Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
|
||||||
|
'''
|
||||||
|
Set custom attention calculator to be called when attention is calculated
|
||||||
|
:param callback: Callback, with args q, k, v, dim, offset, slice_size, which returns attention info.
|
||||||
|
q, k, v are as regular attention calculator.
|
||||||
|
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||||
|
If dim is >= 0, offset and slice_size specify the slice start and length.
|
||||||
|
Pass None to use the default attention calculation.
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
self.custom_attention_calculator = callback
|
||||||
|
|
||||||
def einsum_op_slice_dim0(self, q, k, v, slice_size, callback):
|
def einsum_op_slice_dim0(self, q, k, v, slice_size, callback):
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
for i in range(0, q.shape[0], slice_size):
|
for i in range(0, q.shape[0], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
r[i:end] = callback(q[i:end], k[i:end], v[i:end], offset=i)
|
r[i:end] = callback(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def einsum_op_slice_dim1(self, q, k, v, slice_size, callback):
|
def einsum_op_slice_dim1(self, q, k, v, slice_size, callback):
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
r[:, i:end] = callback(q[:, i:end], k, v, offset=i)
|
r[:, i:end] = callback(self, q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def einsum_op_mps_v1(self, q, k, v, callback):
|
def einsum_op_mps_v1(self, q, k, v, callback):
|
||||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||||
return callback(q, k, v)
|
return callback(self, q, k, v, -1, 0, 0)
|
||||||
else:
|
else:
|
||||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||||
return self.einsum_op_slice_dim1(q, k, v, slice_size, callback)
|
return self.einsum_op_slice_dim1(q, k, v, slice_size, callback)
|
||||||
|
|
||||||
def einsum_op_mps_v2(self, q, k, v, callback):
|
def einsum_op_mps_v2(self, q, k, v, callback):
|
||||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||||
return callback(q, k, v, offset=0)
|
return callback(self, q, k, v, -1, 0, 0)
|
||||||
else:
|
else:
|
||||||
return self.einsum_op_slice_dim0(q, k, v, 1, callback)
|
return self.einsum_op_slice_dim0(q, k, v, 1, callback)
|
||||||
|
|
||||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback):
|
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback):
|
||||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||||
if size_mb <= max_tensor_mb:
|
if size_mb <= max_tensor_mb:
|
||||||
return callback(q, k, v, offset=0)
|
return callback(self, q, k, v, offset=0)
|
||||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||||
if div <= q.shape[0]:
|
if div <= q.shape[0]:
|
||||||
print("warning: untested call to einsum_op_slice_dim0")
|
print("warning: untested call to einsum_op_slice_dim0")
|
||||||
@ -433,12 +236,6 @@ class CrossAttention(nn.Module):
|
|||||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback)
|
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback)
|
||||||
|
|
||||||
def get_attention_mem_efficient(self, q, k, v, callback):
|
def get_attention_mem_efficient(self, q, k, v, callback):
|
||||||
"""
|
|
||||||
Calculate attention by slicing q, k, and v for memory efficiency then calling
|
|
||||||
callback(q, k, v, offset=offset)
|
|
||||||
multiple times if necessary. The offset argument is something
|
|
||||||
"""
|
|
||||||
|
|
||||||
if q.device.type == 'cuda':
|
if q.device.type == 'cuda':
|
||||||
return self.einsum_op_cuda(q, k, v, callback)
|
return self.einsum_op_cuda(q, k, v, callback)
|
||||||
|
|
||||||
@ -479,7 +276,6 @@ class CrossAttention(nn.Module):
|
|||||||
return self.to_out(hidden_states)
|
return self.to_out(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -547,4 +343,3 @@ class SpatialTransformer(nn.Module):
|
|||||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
return x + x_in
|
return x + x_in
|
||||||
'''
|
|
Loading…
Reference in New Issue
Block a user