Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
DRAFT: Cross-Attention Control
Signed-off-by: Ben Alkov <ben.alkov@gmail.com>
parent 92d4dfaabf
commit 07a3df6001
177 c_a_c.py Normal file
@@ -0,0 +1,177 @@
# Functions supporting Cross-Attention Control

# Copied from https://github.com/bloc97/CrossAttentionControl

from difflib import SequenceMatcher

import torch


def prompt_token(prompt, index, clip_tokenizer):
    tokens = clip_tokenizer(prompt,
                            padding='max_length',
                            max_length=clip_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt',
                            return_overflowing_tokens=True
                            ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index+1])

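For reference, a minimal sketch of how prompt_token is typically used: print the decoded token at each position so the (index, weight) tuples for init_attention_weights can be chosen by index. The prompt and loop range are illustrative, and a CLIPTokenizer is assumed to be loaded as clip_tokenizer.

# Illustrative sketch, not part of the committed file.
prompt = 'a photo of a cat sitting on a car'
for i in range(1, 12):
    print(i, prompt_token(prompt, i, clip_tokenizer))
# An (index, weight) pair such as (5, 1.5) could then be passed to init_attention_weights.
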
def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
    tokens_length = clip_tokenizer.model_max_length
    weights = torch.ones(tokens_length)

    for i, w in weight_tuples:
        if i < tokens_length and i >= 0:
            weights[i] = w

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.last_attn_slice_weights = weights.to(device)
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.last_attn_slice_weights = None


def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
    tokens_length = clip_tokenizer.model_max_length
    mask = torch.zeros(tokens_length)
    indices_target = torch.arange(tokens_length, dtype=torch.long)
    indices = torch.zeros(tokens_length, dtype=torch.long)

    tokens = tokens.input_ids.numpy()[0]
    tokens_edit = tokens_edit.input_ids.numpy()[0]

    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
        if b0 < tokens_length:
            if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0):
                mask[b0:b1] = 1
                indices[b0:b1] = indices_target[a0:a1]

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.last_attn_slice_mask = mask.to(device)
            module.last_attn_slice_indices = indices.to(device)
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.last_attn_slice_mask = None
            module.last_attn_slice_indices = None

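To make the SequenceMatcher pass above concrete: positions that are unchanged between the two prompts, or replaced one-for-one, keep a pointer back to the corresponding position in the original prompt and are unmasked; everything else stays masked. A standalone sketch using the standard library only (the token IDs are made up, not real CLIP IDs):

# Illustrative sketch, not part of the committed file.
from difflib import SequenceMatcher

original = [49406, 320, 1125, 539, 320, 2368, 49407]  # hypothetical IDs for the original prompt
edited   = [49406, 320, 1125, 539, 320, 1929, 49407]  # hypothetical IDs, one token swapped
for op, a0, a1, b0, b1 in SequenceMatcher(None, original, edited).get_opcodes():
    print(op, (a0, a1), '->', (b0, b1))
# prints 'equal' for (0, 5) and (6, 7), and 'replace' for (5, 6); since the replace
# spans have equal length, all seven positions end up mapped 1:1 and unmasked.
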
def init_attention_func(unet):
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
    def new_attention(self, query, key, value):
        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attn_slice = attention_scores.softmax(dim=-1)
        # compute attention output

        if self.use_last_attn_slice:
            if self.last_attn_slice_mask is not None:
                new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
                                                     self.last_attn_slice_indices))
                attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
                              + new_attn_slice * self.last_attn_slice_mask)
            else:
                attn_slice = self.last_attn_slice

            self.use_last_attn_slice = False

        if self.save_last_attn_slice:
            self.last_attn_slice = attn_slice
            self.save_last_attn_slice = False

        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
            attn_slice = attn_slice * self.last_attn_slice_weights
            self.use_last_attn_weights = False

        hidden_states = torch.matmul(attn_slice, value)
        # reshape hidden_states
        return self.reshape_batch_dim_to_heads(hidden_states)

    def new_sliced_attention(self, query, key, value, sequence_length, dim):

        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads),
            device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.matmul(query[start_idx:end_idx],
                             key[start_idx:end_idx].transpose(1, 2)) * self.scale
            )  # TODO: use baddbmm for better performance
            attn_slice = attn_slice.softmax(dim=-1)

            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    new_attn_slice = (torch.index_select(self.last_attn_slice,
                                                         -1, self.last_attn_slice_indices))
                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
                                  + new_attn_slice * self.last_attn_slice_mask)
                else:
                    attn_slice = self.last_attn_slice

                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        return self.reshape_batch_dim_to_heads(hidden_states)  # reshape hidden_states

    for _, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention':
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.save_last_attn_slice = False
            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
            module._attention = new_attention.__get__(module, type(module))

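The patching loop above leans on the descriptor protocol: calling a plain function's __get__ produces a method bound to one specific module instance, which is then assigned over that module's _attention and _sliced_attention. A minimal, self-contained sketch of the idiom (class and function names are illustrative):

# Illustrative sketch, not part of the committed file.
class Greeter:
    def hello(self):
        return 'original'

def patched_hello(self):
    return 'patched: ' + type(self).__name__

g = Greeter()
g.hello = patched_hello.__get__(g, Greeter)  # bind to this one instance only
print(g.hello())          # -> 'patched: Greeter'
print(Greeter().hello())  # -> 'original'; other instances keep the class method
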
def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save
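Taken together, these helpers are meant to be driven once per denoising step, exactly as the loop in cross_attention_loop.py below does: record the attention maps on the original-prompt pass, then replay them (masked and re-weighted) on the edited-prompt pass. A compressed sketch of that per-step protocol; unet, latent, t and the embeddings are assumed to be already prepared, and the call order mirrors the loop file rather than adding anything new:

# Illustrative sketch of the per-step protocol, not part of the committed file.
save_last_tokens_attention(unet)          # record attn2 (token) maps on the next forward pass
save_last_self_attention(unet)            # record attn1 (self/spatial) maps as well
_ = unet(latent, t, encoder_hidden_states=embedding_cond).sample

use_last_tokens_attention(unet)           # replay saved token attention where the prompts agree
use_last_self_attention(unet)             # replay saved self-attention
use_last_tokens_attention_weights(unet)   # apply the per-token weights
noise_pred_cond = unet(latent, t, encoder_hidden_states=embedding_cond_edit).sample
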
185 cross_attention_loop.py Normal file
@@ -0,0 +1,185 @@
import random

import numpy as np
import torch

from diffusers import LMSDiscreteScheduler
from PIL import Image
from torch import autocast
from tqdm.auto import tqdm

import c_a_c

@torch.no_grad()
def stablediffusion(
    clip,
    clip_tokenizer,
    device,
    vae,
    unet,
    prompt='',
    prompt_edit=None,
    prompt_edit_token_weights=None,
    prompt_edit_tokens_start=0.0,
    prompt_edit_tokens_end=1.0,
    prompt_edit_spatial_start=0.0,
    prompt_edit_spatial_end=1.0,
    guidance_scale=7.5,
    steps=50,
    seed=None,
    width=512,
    height=512,
    init_image=None,
    init_image_strength=0.5,
):
    if prompt_edit_token_weights is None:
        prompt_edit_token_weights = []
    # Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64

    # If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    generator = torch.cuda.manual_seed(seed)

    # Set inference timesteps to scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085,
                                     beta_end=0.012,
                                     beta_schedule='scaled_linear',
                                     num_train_timesteps=1000,
                                     )
    scheduler.set_timesteps(steps)

    # Preprocess image if it exists (img2img)
    if init_image is not None:
        # Resize and transpose for numpy b h w c -> torch b c h w
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        # If there is alpha channel, composite alpha for white, as the diffusion
        # model does not support alpha channel
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        # Move image to GPU
        init_image = init_image.to(device)

        # Encode image
        with autocast(device):
            init_latent = (vae.encode(init_image)
                           .latent_dist
                           .sample(generator=generator)
                           * 0.18215)

        t_start = steps - int(steps * init_image_strength)

    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8),
                                  device=device)
        t_start = 0

    # Generate random normal noise
    noise = torch.randn(init_latent.shape, generator=generator, device=device)
    latent = scheduler.add_noise(init_latent,
                                 noise,
                                 torch.tensor([scheduler.timesteps[t_start]], device=device)
                                 ).to(device)

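A quick shape check for the 'b h w c -> torch b c h w' conversion above, as a standalone sketch (the sizes are illustrative):

# Illustrative sketch, not part of the committed file.
import numpy as np
import torch

img = np.zeros((512, 512, 4), dtype=np.float32)                  # h w c, e.g. RGBA from PIL
t = torch.from_numpy(img[np.newaxis, ...].transpose(0, 3, 1, 2))  # add batch dim, move channels first
print(t.shape)  # torch.Size([1, 4, 512, 512]); channel 4 triggers the alpha compositing branch
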
    # Process clip
    with autocast(device):
        tokens_uncond = clip_tokenizer('', padding='max_length',
                                       max_length=clip_tokenizer.model_max_length,
                                       truncation=True, return_tensors='pt',
                                       return_overflowing_tokens=True
                                       )
        embedding_uncond = clip(tokens_uncond.input_ids.to(device)).last_hidden_state

        tokens_cond = clip_tokenizer(prompt, padding='max_length',
                                     max_length=clip_tokenizer.model_max_length,
                                     truncation=True, return_tensors='pt',
                                     return_overflowing_tokens=True
                                     )
        embedding_cond = clip(tokens_cond.input_ids.to(device)).last_hidden_state

        # Process prompt editing
        if prompt_edit is not None:
            tokens_cond_edit = clip_tokenizer(prompt_edit, padding='max_length',
                                              max_length=clip_tokenizer.model_max_length,
                                              truncation=True, return_tensors='pt',
                                              return_overflowing_tokens=True
                                              )
            embedding_cond_edit = clip(tokens_cond_edit.input_ids.to(device)).last_hidden_state

            c_a_c.init_attention_edit(tokens_cond, tokens_cond_edit, clip_tokenizer, unet, device)

        c_a_c.init_attention_func(unet)
        c_a_c.init_attention_weights(prompt_edit_token_weights, clip_tokenizer, unet, device)

        timesteps = scheduler.timesteps[t_start:]

        for idx, timestep in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + idx

            latent_model_input = latent
            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

            # Predict the unconditional noise residual
            noise_pred_uncond = unet(latent_model_input,
                                     timestep,
                                     encoder_hidden_states=embedding_uncond
                                     ).sample

            # Prepare the Cross-Attention layers
            if prompt_edit is not None:
                c_a_c.save_last_tokens_attention(unet)
                c_a_c.save_last_self_attention(unet)
            else:
                # Use weights on non-edited prompt when edit is None
                c_a_c.use_last_tokens_attention_weights(unet)

            # Predict the conditional noise residual and save the
            # cross-attention layer activations
            noise_pred_cond = unet(latent_model_input,
                                   timestep,
                                   encoder_hidden_states=embedding_cond
                                   ).sample

            # Edit the Cross-Attention layer activations
            if prompt_edit is not None:
                t_scale = timestep / scheduler.num_train_timesteps
                if (t_scale >= prompt_edit_tokens_start
                        and t_scale <= prompt_edit_tokens_end):
                    c_a_c.use_last_tokens_attention(unet)
                if (t_scale >= prompt_edit_spatial_start
                        and t_scale <= prompt_edit_spatial_end):
                    c_a_c.use_last_self_attention(unet)

                # Use weights on edited prompt
                c_a_c.use_last_tokens_attention_weights(unet)

                # Predict the edited conditional noise residual using the
                # cross-attention masks
                noise_pred_cond = unet(latent_model_input,
                                       timestep,
                                       encoder_hidden_states=embedding_cond_edit
                                       ).sample

            # Perform guidance
            noise_pred = (noise_pred_uncond + guidance_scale
                          * (noise_pred_cond - noise_pred_uncond))

            latent = scheduler.step(noise_pred,
                                    t_index,
                                    latent
                                    ).prev_sample

        # scale and decode the image latents with vae
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).round().astype('uint8')
    return Image.fromarray(image)
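For context, a hypothetical driver sketch showing how the two files might be wired up. The model id, prompts, weights and seed are illustrative; it assumes a Stable Diffusion 1.x checkpoint in the diffusers layout and the diffusers/transformers versions this draft targets, and is not part of the commit.

# Illustrative sketch, not part of the committed files.
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from cross_attention_loop import stablediffusion

device = 'cuda'
model_id = 'CompVis/stable-diffusion-v1-4'  # assumed checkpoint, any SD 1.x diffusers layout should do
clip_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
clip = CLIPTextModel.from_pretrained(model_id, subfolder='text_encoder').to(device)
vae = AutoencoderKL.from_pretrained(model_id, subfolder='vae').to(device)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder='unet').to(device)

image = stablediffusion(clip, clip_tokenizer, device, vae, unet,
                        prompt='a fantasy landscape with a pine forest',
                        prompt_edit='a winter fantasy landscape with a pine forest',
                        prompt_edit_token_weights=[(2, 1.2)],  # illustrative (index, weight) pair
                        seed=42)
image.save('edited.png')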