cross-attention working with placeholder {} syntax

This commit is contained in:
Damian at mba 2022-10-17 21:15:03 +02:00
parent 8ff507b03b
commit 1fc1f8bf05
8 changed files with 534 additions and 237 deletions

View File

@ -400,7 +400,7 @@ class Generate:
mask_image = None mask_image = None
try: try:
uc, c, ec = get_uc_and_c_and_ec( uc, c, ec, ec_index_map = get_uc_and_c_and_ec(
prompt, model =self.model, prompt, model =self.model,
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
log_tokens =self.log_tokenization log_tokens =self.log_tokenization
@ -438,7 +438,7 @@ class Generate:
sampler=self.sampler, sampler=self.sampler,
steps=steps, steps=steps,
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
conditioning=(uc, c, ec), conditioning=(uc, c, ec, ec_index_map),
ddim_eta=ddim_eta, ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated step_callback=step_callback, # called after each intermediate image is generated

View File

@ -10,6 +10,8 @@ log_tokenization() print out colour-coded tokens and warn if trunca
''' '''
import re import re
from difflib import SequenceMatcher
import torch import torch
def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False):
@ -35,32 +37,46 @@ def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False):
clean_prompt = edited_regex_compile.sub(' ', prompt) clean_prompt = edited_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt) prompt = re.sub(' +', ' ', clean_prompt)
uc = model.get_learned_conditioning([unconditioned_words])
ec = None
if edited_words is not None:
ec = model.get_learned_conditioning([edited_words])
# get weighted sub-prompts # get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts( weighted_subprompts = split_weighted_subprompts(
prompt, skip_normalize prompt, skip_normalize
) )
ec = None
edit_opcodes = None
uc, _ = model.get_learned_conditioning([unconditioned_words])
if len(weighted_subprompts) > 1: if len(weighted_subprompts) > 1:
# i dont know if this is correct.. but it works # i dont know if this is correct.. but it works
c = torch.zeros_like(uc) c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it # normalize each "sub prompt" and add it
for subprompt, weight in weighted_subprompts: for subprompt, weight in weighted_subprompts:
log_tokenization(subprompt, model, log_tokens, weight) log_tokenization(subprompt, model, log_tokens, weight)
subprompt_embeddings, _ = model.get_learned_conditioning([subprompt])
c = torch.add( c = torch.add(
c, c,
model.get_learned_conditioning([subprompt]), subprompt_embeddings,
alpha=weight, alpha=weight,
) )
if edited_words is not None:
print("can't do cross-attention control with blends just yet, ignoring edits")
else: # just standard 1 prompt else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens, 1) log_tokenization(prompt, model, log_tokens, 1)
c = model.get_learned_conditioning([prompt]) c, c_tokens = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning([unconditioned_words]) if edited_words is not None:
return (uc, c, ec) ec, ec_tokens = model.get_learned_conditioning([edited_words])
edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens)
return (uc, c, ec, edit_opcodes)
def build_token_edit_opcodes(c_tokens, ec_tokens):
tokens = c_tokens.cpu().numpy()[0]
tokens_edit = ec_tokens.cpu().numpy()[0]
opcodes = SequenceMatcher(None, tokens, tokens_edit).get_opcodes()
return opcodes
def split_weighted_subprompts(text, skip_normalize=False)->list: def split_weighted_subprompts(text, skip_normalize=False)->list:
""" """

View File

@ -19,7 +19,7 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height' kwargs are 'width' and 'height'
""" """
self.perlin = perlin self.perlin = perlin
uc, c, ec = conditioning uc, c, ec, edit_index_map = conditioning
@torch.no_grad() @torch.no_grad()
def make_image(x_T): def make_image(x_T):
@ -44,6 +44,7 @@ class Txt2Img(Generator):
unconditional_guidance_scale = cfg_scale, unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc, unconditional_conditioning = uc,
edited_conditioning = ec, edited_conditioning = ec,
edit_token_index_map = edit_index_map,
eta = ddim_eta, eta = ddim_eta,
img_callback = step_callback, img_callback = step_callback,
threshold = threshold, threshold = threshold,

View File

@ -2,89 +2,99 @@ from enum import Enum
import torch import torch
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl: class CrossAttentionControl:
class AttentionType(Enum): class AttentionType(Enum):
SELF = 1 SELF = 1
TOKENS = 2 TOKENS = 2
@classmethod @classmethod
def get_attention_module(cls, model, which: AttentionType): def setup_attention_editing(cls, model,
which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
module = next(module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name)
return module
@classmethod
def setup_attention_editing(cls, model, original_tokens_length: int,
substitute_conditioning: torch.Tensor = None, substitute_conditioning: torch.Tensor = None,
token_indices_to_edit: list = None): edit_opcodes: list = None):
"""
:param model: The unet model to inject into.
:param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
:param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
conditionings map to the "edited" conditionings.
:return:
"""
# adapted from init_attention_edit # adapted from init_attention_edit
self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
if substitute_conditioning is not None: if substitute_conditioning is not None:
device = substitute_conditioning.device device = substitute_conditioning.device
# this is not very torch-y max_length = model.inner_model.cond_stage_model.max_length
mask = torch.zeros(original_tokens_length) # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
for i in token_indices_to_edit: mask = torch.zeros(max_length)
mask[i] = 1 indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
self_attention_module.last_attn_slice_mask = None for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
self_attention_module.last_attn_slice_indices = None m.last_attn_slice_mask = None
tokens_attention_module.last_attn_slice_mask = mask.to(device) m.last_attn_slice_indices = None
tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device)
for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
cls.inject_attention_functions(model) cls.inject_attention_functions(model)
@classmethod
def get_attention_modules(cls, model, which: AttentionType):
which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
return [module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod @classmethod
def request_save_attention_maps(cls, model): def request_save_attention_maps(cls, model):
self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
self_attention_module.save_last_attn_slice = True for m in self_attention_modules+tokens_attention_modules:
tokens_attention_module.save_last_attn_slice = True m.save_last_attn_slice = True
@classmethod @classmethod
def request_apply_saved_attention_maps(cls, model): def request_apply_saved_attention_maps(cls, model):
self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
self_attention_module.use_last_attn_slice = True for m in self_attention_modules+tokens_attention_modules:
tokens_attention_module.use_last_attn_slice = True m.use_last_attn_slice = True
@classmethod @classmethod
def inject_attention_functions(cls, unet): def inject_attention_functions(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def new_attention(self, query, key, value): def new_attention(self, query, key, value):
# TODO: use baddbmm for better performance # TODO: use baddbmm for better performance
print(f"entered new_attention")
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attn_slice = attention_scores.softmax(dim=-1) attn_slice = attention_scores.softmax(dim=-1)
# compute attention output # compute attention output
if self.use_last_attn_slice: if self.use_last_attn_slice:
if self.last_attn_slice_mask is not None: if self.last_attn_slice_mask is not None:
print('using masked last_attn_slice') base_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
base_attn_slice_mask = self.last_attn_slice_mask
new_attn_slice = (torch.index_select(self.last_attn_slice, -1, this_attn_slice_mask = 1 - self.last_attn_slice_mask
self.last_attn_slice_indices)) attn_slice = attn_slice * this_attn_slice_mask + base_attn_slice * base_attn_slice_mask
attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+ new_attn_slice * self.last_attn_slice_mask)
else: else:
print('using unmasked last_attn_slice')
attn_slice = self.last_attn_slice attn_slice = self.last_attn_slice
self.use_last_attn_slice = False self.use_last_attn_slice = False
else:
print('not using last_attn_slice')
if self.save_last_attn_slice: if self.save_last_attn_slice:
print('saving last_attn_slice')
self.last_attn_slice = attn_slice self.last_attn_slice = attn_slice
self.save_last_attn_slice = False self.save_last_attn_slice = False
else:
print('not saving last_attn_slice')
if self.use_last_attn_weights and self.last_attn_slice_weights is not None: if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights attn_slice = attn_slice * self.last_attn_slice_weights
@ -92,16 +102,59 @@ class CrossAttentionControl:
hidden_states = torch.matmul(attn_slice, value) hidden_states = torch.matmul(attn_slice, value)
# reshape hidden_states # reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states return hidden_states
for _, module in unet.named_modules(): def new_sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
if self.use_last_attn_slice:
if self.last_attn_slice_mask is not None:
new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
attn_slice = attn_slice * (
1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
else:
attn_slice = self.last_attn_slice
self.use_last_attn_slice = False
if self.save_last_attn_slice:
self.last_attn_slice = attn_slice
self.save_last_attn_slice = False
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
attn_slice = attn_slice * self.last_attn_slice_weights
self.use_last_attn_weights = False
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
for name, module in unet.named_modules():
module_name = type(module).__name__ module_name = type(module).__name__
if module_name == 'CrossAttention': if module_name == "CrossAttention":
module.last_attn_slice = None module.last_attn_slice = None
module.use_last_attn_slice = False module.use_last_attn_slice = False
module.use_last_attn_weights = False module.use_last_attn_weights = False
module.save_last_attn_slice = False module.save_last_attn_slice = False
module.cross_attention_callback = new_attention.__get__(module, type(module)) module._sliced_attention = new_sliced_attention.__get__(module, type(module))
module._attention = new_attention.__get__(module, type(module))
# original code below # original code below

View File

@ -30,7 +30,7 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None): def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.threshold = threshold self.threshold = threshold
@ -39,24 +39,23 @@ class CFGDenoiser(nn.Module):
self.edited_conditioning = edited_conditioning self.edited_conditioning = edited_conditioning
if self.edited_conditioning is not None: if edited_conditioning is not None:
initial_tokens_count = 77 # '<start> a cat sitting on a car <end>' # <start> a cat sitting on a car <end>
token_indices_to_edit = [2] # 'cat' CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit) else:
# pass through the attention func but don't act on it
CrossAttentionControl.setup_attention_editing(self.inner_model)
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
print('generating new unconditioned latents') print('generating unconditioned latents')
unconditioned_latents = self.inner_model(x, sigma, cond=uncond) unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
# process x using the original prompt, saving the attention maps if required # process x using the original prompt, saving the attention maps if required
if self.edited_conditioning is not None: if self.edited_conditioning is not None:
# this is automatically toggled off after the model forward() # this is automatically toggled off after the model forward()
CrossAttentionControl.request_save_attention_maps(self.inner_model) CrossAttentionControl.request_save_attention_maps(self.inner_model)
print('generating new conditioned latents') print('generating conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=cond) conditioned_latents = self.inner_model(x, sigma, cond=cond)
if self.edited_conditioning is not None: if self.edited_conditioning is not None:
@ -192,6 +191,7 @@ class KSampler(Sampler):
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
edited_conditioning=None, edited_conditioning=None,
edit_token_index_map=None,
threshold = 0, threshold = 0,
perlin = 0, perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -223,7 +223,8 @@ class KSampler(Sampler):
else: else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning) model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10),
edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map)
extra_args = { extra_args = {
'cond': conditioning, 'cond': conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,

View File

@ -1,155 +1,367 @@
from inspect import isfunction # Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum from torch import nn
from einops import rearrange, repeat
from ldm.modules.diffusionmodules.util import checkpoint
import psutil
def exists(val):
return val is not None
def uniq(arr): class AttentionBlock(nn.Module):
return{el: True for el in arr}.keys() """
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, 1)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def default(val, d): class SpatialTransformer(nn.Module):
if exists(val): """
return val Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
return d() if isfunction(d) else d standard transformer action. Finally, reshape to image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
def max_neg_value(t): class BasicTransformerBlock(nn.Module):
return -torch.finfo(t.dtype).max r"""
A basic Transformer block.
Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
"""
def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
dropout=0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def forward(self, hidden_states, context=None):
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
def init_(tensor): class CrossAttention(nn.Module):
dim = tensor.shape[-1] r"""
std = 1 / math.sqrt(dim) A cross attention layer.
tensor.uniform_(-std, std)
return tensor Parameters:
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
return self.to_out(hidden_states)
def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
project_in = GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, hidden_states):
return self.net(hidden_states)
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__() super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2) self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x): def forward(self, hidden_states):
x, gate = self.proj(x).chunk(2, dim=-1) hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return x * F.gelu(gate) return hidden_states * F.gelu(gate)
'''
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__() super().__init__()
@ -172,48 +384,45 @@ class CrossAttention(nn.Module):
self.cross_attention_callback = None self.cross_attention_callback = None
def einsum_op_compvis(self, q, k, v): def einsum_op_slice_dim0(self, q, k, v, slice_size, callback):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
def einsum_op_slice_0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size): for i in range(0, q.shape[0], slice_size):
end = i + slice_size end = i + slice_size
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end]) r[i:end] = callback(q[i:end], k[i:end], v[i:end], offset=i)
return r return r
def einsum_op_slice_1(self, q, k, v, slice_size): def einsum_op_slice_dim1(self, q, k, v, slice_size, callback):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v) r[:, i:end] = callback(q[:, i:end], k, v, offset=i)
return r return r
def einsum_op_mps_v1(self, q, k, v): def einsum_op_mps_v1(self, q, k, v, callback):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_op_compvis(q, k, v) return callback(q, k, v)
else: else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_1(q, k, v, slice_size) return self.einsum_op_slice_dim1(q, k, v, slice_size, callback)
def einsum_op_mps_v2(self, q, k, v): def einsum_op_mps_v2(self, q, k, v, callback):
if self.mem_total_gb > 8 and q.shape[1] <= 4096: if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_op_compvis(q, k, v) return callback(q, k, v, offset=0)
else: else:
return self.einsum_op_slice_0(q, k, v, 1) return self.einsum_op_slice_dim0(q, k, v, 1, callback)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb: if size_mb <= max_tensor_mb:
return self.einsum_op_compvis(q, k, v) return callback(q, k, v, offset=0)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]: if div <= q.shape[0]:
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div) print("warning: untested call to einsum_op_slice_dim0")
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div, callback)
print("warning: untested call to einsum_op_slice_dim1")
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1), callback)
def einsum_op_cuda(self, q, k, v): def einsum_op_cuda(self, q, k, v, callback):
stats = torch.cuda.memory_stats(q.device) stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
@ -221,20 +430,26 @@ class CrossAttention(nn.Module):
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation # Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback)
def get_attention_mem_efficient(self, q, k, v, callback):
"""
Calculate attention by slicing q, k, and v for memory efficiency then calling
callback(q, k, v, offset=offset)
multiple times if necessary. The offset argument is something
"""
def einsum_op(self, q, k, v):
if q.device.type == 'cuda': if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v) return self.einsum_op_cuda(q, k, v, callback)
if q.device.type == 'mps': if q.device.type == 'mps':
if self.mem_total_gb >= 32: if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v) return self.einsum_op_mps_v1(q, k, v, callback)
return self.einsum_op_mps_v2(q, k, v) return self.einsum_op_mps_v2(q, k, v, callback)
# Smaller slices are faster due to L2/L3/SLC caches. # Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache. # Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32) return self.einsum_op_tensor_mem(q, k, v, 32, callback)
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
@ -247,14 +462,24 @@ class CrossAttention(nn.Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
if self.cross_attention_callback is not None: def default_attention_calculator(q, k, v, **kwargs):
r = self.cross_attention_callback(q, k, v) # calculate attention scores
else: attention_scores = einsum('b i d, b j d -> b i j', q, k)
r = self.einsum_op(q, k, v) # calculate attenion slice by taking the best scores for each latent pixel
attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
return einsum('b i j, b j d -> b i d', attention_slice, v)
attention_calculator = \
self.custom_attention_calculator if self.custom_attention_calculator is not None \
else default_attention_calculator
r = self.get_attention_mem_efficient(q, k, v, attention_calculator)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states) return self.to_out(hidden_states)
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__() super().__init__()
@ -322,3 +547,4 @@ class SpatialTransformer(nn.Module):
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x) x = self.proj_out(x)
return x + x_in return x + x_in
'''

View File

@ -8,7 +8,7 @@ import numpy as np
from einops import rearrange from einops import rearrange
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention #from ldm.modules.attention import LinearAttention
import psutil import psutil
@ -151,10 +151,10 @@ class ResnetBlock(nn.Module):
return x + h return x + h
class LinAttnBlock(LinearAttention): #class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage""" # """to match AttnBlock usage"""
def __init__(self, in_channels): # def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) # super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module): class AttnBlock(nn.Module):

View File

@ -449,7 +449,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
tokens = batch_encoding['input_ids'].to(self.device) tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs) z = self.transformer(input_ids=tokens, **kwargs)
return z return z, tokens
def encode(self, text, **kwargs): def encode(self, text, **kwargs):
return self(text, **kwargs) return self(text, **kwargs)