diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py
index d829162f35..bcc6c8cc94 100644
--- a/ldm/models/diffusion/cross_attention.py
+++ b/ldm/models/diffusion/cross_attention.py
@@ -75,11 +75,12 @@ class CrossAttentionControl:
     @classmethod
     def inject_attention_functions(cls, unet):
         # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
+
         def new_attention(self, query, key, value):
-            # TODO: use baddbmm for better performance
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
-            attn_slice = attention_scores.softmax(dim=-1)
-            # compute attention output
+
+            attention_scores = torch.functional.einsum('b i d, b j d -> b i j', query, key)
+            # calculate attention slice by taking the best scores for each latent pixel
+            attn_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
 
             if self.use_last_attn_slice:
                 if self.last_attn_slice_mask is not None:
@@ -100,13 +101,12 @@ class CrossAttentionControl:
                 attn_slice = attn_slice * self.last_attn_slice_weights
                 self.use_last_attn_weights = False
 
-            hidden_states = torch.matmul(attn_slice, value)
-            # reshape hidden_states
-            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-            return hidden_states
+            return torch.functional.einsum('b i j, b j d -> b i d', attn_slice, value)
 
         def new_sliced_attention(self, query, key, value, sequence_length, dim):
+            raise NotImplementedError("not tested yet")
+
             batch_size_attention = query.shape[0]
             hidden_states = torch.zeros(
                 (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
@@ -146,6 +146,12 @@ class CrossAttentionControl:
             hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
             return hidden_states
 
+        def select_attention_func(module, q, k, v, dim, offset, slice_size):
+            if dim == 0 or dim == 1:
+                return new_sliced_attention(module, q, k, v, sequence_length=slice_size, dim=dim)
+            else:
+                return new_attention(module, q, k, v)
+
         for name, module in unet.named_modules():
             module_name = type(module).__name__
             if module_name == "CrossAttention":
@@ -153,8 +159,7 @@ class CrossAttentionControl:
                 module.use_last_attn_slice = False
                 module.use_last_attn_weights = False
                 module.save_last_attn_slice = False
-                module._sliced_attention = new_sliced_attention.__get__(module, type(module))
-                module._attention = new_attention.__get__(module, type(module))
+                module.set_custom_attention_calculator(select_attention_func)
 
                 # original code below
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index e5d521f33f..f5af25fc04 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -48,21 +48,21 @@ class CFGDenoiser(nn.Module):
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
 
-        print('generating unconditioned latents')
+        #print('generating unconditioned latents')
         unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
 
         # process x using the original prompt, saving the attention maps if required
         if self.edited_conditioning is not None:
             # this is automatically toggled off after the model forward()
             CrossAttentionControl.request_save_attention_maps(self.inner_model)
-        print('generating conditioned latents')
+        #print('generating conditioned latents')
        conditioned_latents = self.inner_model(x, sigma, cond=cond)
 
        if self.edited_conditioning is not None:
            # process x again, using the saved attention maps but the new conditioning
            # this is automatically toggled off after the model forward()
            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
-           print('generating edited conditioned latents')
+           #print('generating edited conditioned latents')
            conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
 
        if self.warmup < self.warmup_max:
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index d00b95b1af..1ee0795fdd 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -1,367 +1,158 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+from inspect import isfunction
 import math
-from typing import Optional
+from typing import Callable
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+import psutil
+
+def exists(val):
+    return val is not None
 
 
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
-    to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention.
-
-    Parameters:
-        channels (:obj:`int`): The number of channels in the input and output.
-        num_head_channels (:obj:`int`, *optional*):
-            The number of channels in each head. If None, then `num_heads` = 1.
-        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
-        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
-        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
- """ - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - # get scores - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - - # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states +def uniq(arr): + return{el: True for el in arr}.keys() -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. 
- """ - - def __init__( - self, - in_channels: int, - n_heads: int, - d_head: int, - depth: int = 1, - dropout: float = 0.0, - num_groups: int = 32, - context_dim: Optional[int] = None, - ): - super().__init__() - self.n_heads = n_heads - self.d_head = d_head - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth) - ] - ) - - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) - - def forward(self, hidden_states, context=None): - # note: if no context is given, cross-attention defaults to self-attention - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - return hidden_states + residual +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. 
- """ - - def __init__( - self, - dim: int, - n_heads: int, - d_head: int, - dropout=0.0, - context_dim: Optional[int] = None, - gated_ff: bool = True, - checkpoint: bool = True, - ): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def _set_attention_slice(self, slice_size): - self.attn1._slice_size = slice_size - self.attn2._slice_size = slice_size - - def forward(self, hidden_states, context=None): - hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - return hidden_states +def max_neg_value(t): + return -torch.finfo(t.dtype).max -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): - The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
- """ - - def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 - ): - super().__init__() - inner_dim = dim_head * heads - context_dim = context_dim if context_dim is not None else query_dim - - self.scale = dim_head**-0.5 - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self._slice_size = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def forward(self, hidden_states, context=None, mask=None): - batch_size, sequence_length, _ = hidden_states.shape - - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) - - dim = query.shape[-1] - - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - # TODO(PVP) - mask is currently never used. 
Remember to re-implement when used - - # attention, what we cannot get enough of - - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) - - return self.to_out(hidden_states) - - def _attention(self, query, key, value): - # TODO: use baddbmm for better performance - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale - attention_probs = attention_scores.softmax(dim=-1) - # compute attention output - hidden_states = torch.matmul(attention_probs, value) - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _sliced_attention(self, query, key, value, sequence_length, dim): - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - attn_slice = ( - torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale - ) # TODO: use baddbmm for better performance - attn_slice = attn_slice.softmax(dim=-1) - attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ - - def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - project_in = GEGLU(dim, inner_dim) - - self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - - def forward(self, hidden_states): - return self.net(hidden_states) +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor # feedforward class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. 
- """ - - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * F.gelu(gate) -''' + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() @@ -382,39 +173,51 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.cross_attention_callback = None + self.custom_attention_calculator = None + + def set_custom_attention_calculator(self, callback:Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): + ''' + Set custom attention calculator to be called when attention is calculated + 
:param callback: Callback, with args q, k, v, dim, offset, slice_size, which returns attention info. + q, k, v are as regular attention calculator. + dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. + If dim is >= 0, offset and slice_size specify the slice start and length. + Pass None to use the default attention calculation. + :return: + ''' + self.custom_attention_calculator = callback def einsum_op_slice_dim0(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = callback(q[i:end], k[i:end], v[i:end], offset=i) + r[i:end] = callback(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) return r def einsum_op_slice_dim1(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = callback(q[:, i:end], k, v, offset=i) + r[:, i:end] = callback(self, q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) return r def einsum_op_mps_v1(self, q, k, v, callback): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return callback(q, k, v) + return callback(self, q, k, v, -1, 0, 0) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) return self.einsum_op_slice_dim1(q, k, v, slice_size, callback) def einsum_op_mps_v2(self, q, k, v, callback): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return callback(q, k, v, offset=0) + return callback(self, q, k, v, -1, 0, 0) else: return self.einsum_op_slice_dim0(q, k, v, 1, callback) def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return callback(q, k, v, offset=0) + return callback(self, q, k, v, offset=0) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: print("warning: untested call to einsum_op_slice_dim0") @@ -433,12 +236,6 @@ class CrossAttention(nn.Module): return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback) def get_attention_mem_efficient(self, q, k, v, callback): - """ - Calculate attention by slicing q, k, and v for memory efficiency then calling - callback(q, k, v, offset=offset) - multiple times if necessary. The offset argument is something - """ - if q.device.type == 'cuda': return self.einsum_op_cuda(q, k, v, callback) @@ -479,7 +276,6 @@ class CrossAttention(nn.Module): return self.to_out(hidden_states) - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() @@ -547,4 +343,3 @@ class SpatialTransformer(nn.Module): x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in -''' \ No newline at end of file
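
Usage note (illustrative sketch, not part of the patch): set_custom_attention_calculator() expects a callable that receives the CrossAttention module plus q, k, v, dim, offset and slice_size, and returns the attention result for that call or slice; dim 0 corresponds to slicing along the batch axis and dim 1 along the sequence axis, with -1 meaning an unsliced call. The sketch below registers such a callback on every CrossAttention module of a hypothetical `unet`; the `unet` variable, the function name and the logging are assumptions for illustration only, and the unsliced math follows the einsum form used by new_attention() in cross_attention.py above.

# Illustrative sketch only; assumes a loaded `unet` whose CrossAttention modules
# expose set_custom_attention_calculator() as introduced by this patch.
import torch

def logging_attention_calculator(module, q, k, v, dim, offset, slice_size):
    # dim == -1 means an unsliced call; 0 or 1 identify the sliced dimension.
    if dim >= 0:
        print(f"sliced attention call: dim={dim} offset={offset} slice_size={slice_size}")
    # same einsum form as new_attention() in cross_attention.py
    scores = torch.einsum('b i d, b j d -> b i j', q, k)
    weights = scores.softmax(dim=-1, dtype=scores.dtype)
    return torch.einsum('b i j, b j d -> b i d', weights, v)

for name, module in unet.named_modules():
    if type(module).__name__ == "CrossAttention":
        module.set_custom_attention_calculator(logging_attention_calculator)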