From 1fc1f8bf05c01af0c07714df38f10745dd984103 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 17 Oct 2022 21:15:03 +0200 Subject: [PATCH] cross-attention working with placeholder {} syntax --- ldm/generate.py | 4 +- ldm/invoke/conditioning.py | 34 +- ldm/invoke/generator/txt2img.py | 3 +- ldm/models/diffusion/cross_attention.py | 141 ++++-- ldm/models/diffusion/ksampler.py | 23 +- ldm/modules/attention.py | 554 +++++++++++++++++------- ldm/modules/diffusionmodules/model.py | 10 +- ldm/modules/encoders/modules.py | 2 +- 8 files changed, 534 insertions(+), 237 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index b8945342b0..37df973291 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -400,7 +400,7 @@ class Generate: mask_image = None try: - uc, c, ec = get_uc_and_c_and_ec( + uc, c, ec, ec_index_map = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c, ec), + conditioning=(uc, c, ec, ec_index_map), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 1453d9ce8c..8c8f5eeb01 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -10,6 +10,8 @@ log_tokenization() print out colour-coded tokens and warn if trunca ''' import re +from difflib import SequenceMatcher + import torch def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): @@ -35,32 +37,46 @@ def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): clean_prompt = edited_regex_compile.sub(' ', prompt) prompt = re.sub(' +', ' ', clean_prompt) - uc = model.get_learned_conditioning([unconditioned_words]) - ec = None - if edited_words is not None: - ec = model.get_learned_conditioning([edited_words]) - # get weighted sub-prompts weighted_subprompts = split_weighted_subprompts( prompt, skip_normalize ) + ec = None + edit_opcodes = None + + uc, _ = model.get_learned_conditioning([unconditioned_words]) + if len(weighted_subprompts) > 1: # i dont know if this is correct.. 
but it works c = torch.zeros_like(uc) # normalize each "sub prompt" and add it for subprompt, weight in weighted_subprompts: log_tokenization(subprompt, model, log_tokens, weight) + subprompt_embeddings, _ = model.get_learned_conditioning([subprompt]) c = torch.add( c, - model.get_learned_conditioning([subprompt]), + subprompt_embeddings, alpha=weight, ) + if edited_words is not None: + print("can't do cross-attention control with blends just yet, ignoring edits") else: # just standard 1 prompt log_tokenization(prompt, model, log_tokens, 1) - c = model.get_learned_conditioning([prompt]) - uc = model.get_learned_conditioning([unconditioned_words]) - return (uc, c, ec) + c, c_tokens = model.get_learned_conditioning([prompt]) + if edited_words is not None: + ec, ec_tokens = model.get_learned_conditioning([edited_words]) + edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens) + + return (uc, c, ec, edit_opcodes) + +def build_token_edit_opcodes(c_tokens, ec_tokens): + tokens = c_tokens.cpu().numpy()[0] + tokens_edit = ec_tokens.cpu().numpy()[0] + + opcodes = SequenceMatcher(None, tokens, tokens_edit).get_opcodes() + return opcodes + def split_weighted_subprompts(text, skip_normalize=False)->list: """ diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 23f03f22db..9f066745f7 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -19,7 +19,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c, ec = conditioning + uc, c, ec, edit_index_map = conditioning @torch.no_grad() def make_image(x_T): @@ -44,6 +44,7 @@ class Txt2Img(Generator): unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, edited_conditioning = ec, + edit_token_index_map = edit_index_map, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index a440eb3e6a..d829162f35 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -2,89 +2,99 @@ from enum import Enum import torch +# adapted from bloc97's CrossAttentionControl colab +# https://github.com/bloc97/CrossAttentionControl + class CrossAttentionControl: class AttentionType(Enum): SELF = 1 TOKENS = 2 @classmethod - def get_attention_module(cls, model, which: AttentionType): - which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" - module = next(module for name, module in model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name) - return module - - @classmethod - def setup_attention_editing(cls, model, original_tokens_length: int, + def setup_attention_editing(cls, model, substitute_conditioning: torch.Tensor = None, - token_indices_to_edit: list = None): + edit_opcodes: list = None): + """ + :param model: The unet model to inject into. + :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768] + :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base + conditionings map to the "edited" conditionings. 
+ :return: + """ # adapted from init_attention_edit - self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) - tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) - if substitute_conditioning is not None: device = substitute_conditioning.device - # this is not very torch-y - mask = torch.zeros(original_tokens_length) - for i in token_indices_to_edit: - mask[i] = 1 + max_length = model.inner_model.cond_stage_model.max_length + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 - self_attention_module.last_attn_slice_mask = None - self_attention_module.last_attn_slice_indices = None - tokens_attention_module.last_attn_slice_mask = mask.to(device) - tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device) + for m in cls.get_attention_modules(model, cls.AttentionType.SELF): + m.last_attn_slice_mask = None + m.last_attn_slice_indices = None + + for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS): + m.last_attn_slice_mask = mask.to(device) + m.last_attn_slice_indices = indices.to(device) cls.inject_attention_functions(model) + + @classmethod + def get_attention_modules(cls, model, which: AttentionType): + which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" + return [module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + @classmethod def request_save_attention_maps(cls, model): - self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) - tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) - self_attention_module.save_last_attn_slice = True - tokens_attention_module.save_last_attn_slice = True + self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = True @classmethod def request_apply_saved_attention_maps(cls, model): - self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) - tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) - self_attention_module.use_last_attn_slice = True - tokens_attention_module.use_last_attn_slice = True + self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.use_last_attn_slice = True + @classmethod def inject_attention_functions(cls, unet): # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 def new_attention(self, query, key, value): # TODO: use baddbmm for better performance - print(f"entered new_attention") attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attn_slice = attention_scores.softmax(dim=-1) # compute attention output if self.use_last_attn_slice: 
if self.last_attn_slice_mask is not None: - print('using masked last_attn_slice') - - new_attn_slice = (torch.index_select(self.last_attn_slice, -1, - self.last_attn_slice_indices)) - attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) - + new_attn_slice * self.last_attn_slice_mask) + base_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + base_attn_slice_mask = self.last_attn_slice_mask + this_attn_slice_mask = 1 - self.last_attn_slice_mask + attn_slice = attn_slice * this_attn_slice_mask + base_attn_slice * base_attn_slice_mask else: - print('using unmasked last_attn_slice') attn_slice = self.last_attn_slice self.use_last_attn_slice = False - else: - print('not using last_attn_slice') if self.save_last_attn_slice: - print('saving last_attn_slice') self.last_attn_slice = attn_slice self.save_last_attn_slice = False - else: - print('not saving last_attn_slice') if self.use_last_attn_weights and self.last_attn_slice_weights is not None: attn_slice = attn_slice * self.last_attn_slice_weights @@ -92,16 +102,59 @@ class CrossAttentionControl: hidden_states = torch.matmul(attn_slice, value) # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states - for _, module in unet.named_modules(): + def new_sliced_attention(self, query, key, value, sequence_length, dim): + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + ) # TODO: use baddbmm for better performance + attn_slice = attn_slice.softmax(dim=-1) + + if self.use_last_attn_slice: + if self.last_attn_slice_mask is not None: + new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + attn_slice = attn_slice * ( + 1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask + else: + attn_slice = self.last_attn_slice + + self.use_last_attn_slice = False + + if self.save_last_attn_slice: + self.last_attn_slice = attn_slice + self.save_last_attn_slice = False + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + attn_slice = attn_slice * self.last_attn_slice_weights + self.use_last_attn_weights = False + + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + for name, module in unet.named_modules(): module_name = type(module).__name__ - if module_name == 'CrossAttention': + if module_name == "CrossAttention": module.last_attn_slice = None module.use_last_attn_slice = False module.use_last_attn_weights = False module.save_last_attn_slice = False - module.cross_attention_callback = new_attention.__get__(module, type(module)) + module._sliced_attention = new_sliced_attention.__get__(module, type(module)) + module._attention = new_attention.__get__(module, type(module)) # original code below diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 29949aff8d..e5d521f33f 100644 --- a/ldm/models/diffusion/ksampler.py +++ 
b/ldm/models/diffusion/ksampler.py @@ -30,7 +30,7 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None): + def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None): super().__init__() self.inner_model = model self.threshold = threshold @@ -39,24 +39,23 @@ class CFGDenoiser(nn.Module): self.edited_conditioning = edited_conditioning - if self.edited_conditioning is not None: - initial_tokens_count = 77 # ' a cat sitting on a car ' - token_indices_to_edit = [2] # 'cat' - CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit) + if edited_conditioning is not None: + # a cat sitting on a car + CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes) + else: + # pass through the attention func but don't act on it + CrossAttentionControl.setup_attention_editing(self.inner_model) def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - print('generating new unconditioned latents') + print('generating unconditioned latents') unconditioned_latents = self.inner_model(x, sigma, cond=uncond) # process x using the original prompt, saving the attention maps if required if self.edited_conditioning is not None: # this is automatically toggled off after the model forward() CrossAttentionControl.request_save_attention_maps(self.inner_model) - print('generating new conditioned latents') + print('generating conditioned latents') conditioned_latents = self.inner_model(x, sigma, cond=cond) if self.edited_conditioning is not None: @@ -192,6 +191,7 @@ class KSampler(Sampler): unconditional_guidance_scale=1.0, unconditional_conditioning=None, edited_conditioning=None, + edit_token_index_map=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... @@ -223,7 +223,8 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning) + model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), + edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index a9805e6c67..d00b95b1af 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,155 +1,367 @@ -from inspect import isfunction +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import math +from typing import Optional + import torch import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat - -from ldm.modules.diffusionmodules.util import checkpoint - -import psutil - -def exists(val): - return val is not None +from torch import nn -def uniq(arr): - return{el: True for el in arr}.keys() +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. + """ + + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + num_groups: int = 32, + context_dim: Optional[int] = None, + ): + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] + ) + + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, hidden_states, context=None): + # note: if no context is given, cross-attention defaults to self-attention + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=context) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + return hidden_states + residual -def max_neg_value(t): - return -torch.finfo(t.dtype).max +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. 
+ """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self._slice_size = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + dim = query.shape[-1] + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + # TODO(PVP) - mask is currently never used. 
Remember to re-implement when used + + # attention, what we cannot get enough of + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) + + return self.to_out(hidden_states) + + def _attention(self, query, key, value): + # TODO: use baddbmm for better performance + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_probs = attention_scores.softmax(dim=-1) + # compute attention output + hidden_states = torch.matmul(attention_probs, value) + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + ) # TODO: use baddbmm for better performance + attn_slice = attn_slice.softmax(dim=-1) + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + return self.net(hidden_states) # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. 
+ """ + + def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) - - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ - - + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) +''' class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() @@ -172,48 +384,45 @@ class CrossAttention(nn.Module): self.cross_attention_callback = None - def einsum_op_compvis(self, q, k, v): - s = einsum('b i d, b j d -> b i j', q, k) - s = s.softmax(dim=-1, dtype=s.dtype) - return einsum('b i j, b j d -> b i d', s, v) - - def einsum_op_slice_0(self, q, k, v, slice_size): + def einsum_op_slice_dim0(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for 
i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + r[i:end] = callback(q[i:end], k[i:end], v[i:end], offset=i) return r - def einsum_op_slice_1(self, q, k, v, slice_size): + def einsum_op_slice_dim1(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v) + r[:, i:end] = callback(q[:, i:end], k, v, offset=i) return r - def einsum_op_mps_v1(self, q, k, v): + def einsum_op_mps_v1(self, q, k, v, callback): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_op_compvis(q, k, v) + return callback(q, k, v) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_1(q, k, v, slice_size) + return self.einsum_op_slice_dim1(q, k, v, slice_size, callback) - def einsum_op_mps_v2(self, q, k, v): + def einsum_op_mps_v2(self, q, k, v, callback): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_op_compvis(q, k, v) + return callback(q, k, v, offset=0) else: - return self.einsum_op_slice_0(q, k, v, 1) + return self.einsum_op_slice_dim0(q, k, v, 1, callback) - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): + def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return self.einsum_op_compvis(q, k, v) + return callback(q, k, v, offset=0) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: - return self.einsum_op_slice_0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + print("warning: untested call to einsum_op_slice_dim0") + return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div, callback) + print("warning: untested call to einsum_op_slice_dim1") + return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1), callback) - def einsum_op_cuda(self, q, k, v): + def einsum_op_cuda(self, q, k, v, callback): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] @@ -221,20 +430,26 @@ class CrossAttention(nn.Module): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback) + + def get_attention_mem_efficient(self, q, k, v, callback): + """ + Calculate attention by slicing q, k, and v for memory efficiency then calling + callback(q, k, v, offset=offset) + multiple times if necessary. The offset argument is something + """ - def einsum_op(self, q, k, v): if q.device.type == 'cuda': - return self.einsum_op_cuda(q, k, v) + return self.einsum_op_cuda(q, k, v, callback) if q.device.type == 'mps': if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v) - return self.einsum_op_mps_v2(q, k, v) + return self.einsum_op_mps_v1(q, k, v, callback) + return self.einsum_op_mps_v2(q, k, v, callback) # Smaller slices are faster due to L2/L3/SLC caches. # Tested on i7 with 8MB L3 cache. 
- return self.einsum_op_tensor_mem(q, k, v, 32) + return self.einsum_op_tensor_mem(q, k, v, 32, callback) def forward(self, x, context=None, mask=None): h = self.heads @@ -247,14 +462,24 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - if self.cross_attention_callback is not None: - r = self.cross_attention_callback(q, k, v) - else: - r = self.einsum_op(q, k, v) + def default_attention_calculator(q, k, v, **kwargs): + # calculate attention scores + attention_scores = einsum('b i d, b j d -> b i j', q, k) + # calculate attenion slice by taking the best scores for each latent pixel + attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) + return einsum('b i j, b j d -> b i d', attention_slice, v) + + attention_calculator = \ + self.custom_attention_calculator if self.custom_attention_calculator is not None \ + else default_attention_calculator + + r = self.get_attention_mem_efficient(q, k, v, attention_calculator) + hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) return self.to_out(hidden_states) + class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() @@ -322,3 +547,4 @@ class SpatialTransformer(nn.Module): x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in +''' \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 739710d006..73218d36f8 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -8,7 +8,7 @@ import numpy as np from einops import rearrange from ldm.util import instantiate_from_config -from ldm.modules.attention import LinearAttention +#from ldm.modules.attention import LinearAttention import psutil @@ -151,10 +151,10 @@ class ResnetBlock(nn.Module): return x + h -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) +#class LinAttnBlock(LinearAttention): +# """to match AttnBlock usage""" +# def __init__(self, in_channels): +# super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced3..12ef737134 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -449,7 +449,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) - return z + return z, tokens def encode(self, text, **kwargs): return self(text, **kwargs)
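
The token bookkeeping introduced in ldm/invoke/conditioning.py and consumed by CrossAttentionControl.setup_attention_editing can be read in isolation. The sketch below is not part of the patch; `build_mask_and_indices`, `model`, and the example prompts are illustrative stand-ins. It shows how difflib.SequenceMatcher opcodes over the two tokenized prompts become the mask/indices pair that the patched attention modules use: mask==1 marks token positions reported as "equal" in both prompts (attention for them is copied from the original prompt), and indices records where in the original prompt each such token sits.

    from difflib import SequenceMatcher

    import torch

    def build_token_edit_opcodes(c_tokens: torch.Tensor, ec_tokens: torch.Tensor) -> list:
        # c_tokens / ec_tokens: [1 x 77] CLIP token ids for the original and edited prompts
        tokens = c_tokens.cpu().numpy()[0]
        tokens_edit = ec_tokens.cpu().numpy()[0]
        return SequenceMatcher(None, tokens, tokens_edit).get_opcodes()

    def build_mask_and_indices(edit_opcodes: list, max_length: int = 77):
        # mask == 1 -> reuse the attention slice saved for the original prompt
        # mask == 0 -> keep the attention computed for the edited prompt
        mask = torch.zeros(max_length)
        indices = torch.zeros(max_length, dtype=torch.long)
        indices_target = torch.arange(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if name == "equal" and b0 < max_length:
                # unedited tokens: remember their position in the original prompt
                indices[b0:b1] = indices_target[a0:a1]
                mask[b0:b1] = 1
        return mask, indices

    # hypothetical usage: prompts that differ by a single token
    # c, c_tokens   = model.get_learned_conditioning(["a cat sitting on a car"])
    # ec, ec_tokens = model.get_learned_conditioning(["a dog sitting on a car"])
    # mask, indices = build_mask_and_indices(build_token_edit_opcodes(c_tokens, ec_tokens))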
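
On the sampling side, CFGDenoiser calls CrossAttentionControl.request_save_attention_maps before the pass over the original prompt, so each patched CrossAttention module stores its softmaxed attention slice; request_apply_saved_attention_maps then flags the modules to substitute those saved slices on a subsequent forward pass with the edited conditioning (the bloc97 technique this patch adapts). The merge performed inside the injected new_attention / new_sliced_attention functions reduces to the few lines below; merge_attention_slices is an illustrative name for this sketch, not a function added by the patch.

    import torch

    def merge_attention_slices(edited_slice: torch.Tensor,
                               saved_slice: torch.Tensor,
                               mask: torch.Tensor,
                               indices: torch.Tensor) -> torch.Tensor:
        # edited_slice / saved_slice: [batch*heads, pixels, tokens] softmaxed attention
        # indices maps each token position in the edited prompt back to the position
        # of the matching ("equal") token in the original prompt; mask selects which
        # columns come from the saved slice (1) versus the freshly computed one (0)
        base = torch.index_select(saved_slice, -1, indices)
        return edited_slice * (1 - mask) + base * mask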