diff --git a/c_a_c.py b/c_a_c.py
new file mode 100644
index 0000000000..33243edaf7
--- /dev/null
+++ b/c_a_c.py
@@ -0,0 +1,177 @@
+# Functions supporting Cross-Attention Control
+# Copied from https://github.com/bloc97/CrossAttentionControl
+
+from difflib import SequenceMatcher
+
+import torch
+
+
+def prompt_token(prompt, index, clip_tokenizer):
+    tokens = clip_tokenizer(prompt,
+                            padding='max_length',
+                            max_length=clip_tokenizer.model_max_length,
+                            truncation=True,
+                            return_tensors='pt',
+                            return_overflowing_tokens=True
+                            ).input_ids[0]
+    return clip_tokenizer.decode(tokens[index:index+1])
+
+
+def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
+    tokens_length = clip_tokenizer.model_max_length
+    weights = torch.ones(tokens_length)
+
+    for i, w in weight_tuples:
+        if i < tokens_length and i >= 0:
+            weights[i] = w
+
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.last_attn_slice_weights = weights.to(device)
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.last_attn_slice_weights = None
+
+
+def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
+    tokens_length = clip_tokenizer.model_max_length
+    mask = torch.zeros(tokens_length)
+    indices_target = torch.arange(tokens_length, dtype=torch.long)
+    indices = torch.zeros(tokens_length, dtype=torch.long)
+
+    tokens = tokens.input_ids.numpy()[0]
+    tokens_edit = tokens_edit.input_ids.numpy()[0]
+
+    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
+        if b0 < tokens_length:
+            if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0):
+                mask[b0:b1] = 1
+                indices[b0:b1] = indices_target[a0:a1]
+
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.last_attn_slice_mask = mask.to(device)
+            module.last_attn_slice_indices = indices.to(device)
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.last_attn_slice_mask = None
+            module.last_attn_slice_indices = None
+
+
+def init_attention_func(unet):
+    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
+    def new_attention(self, query, key, value):
+        # TODO: use baddbmm for better performance
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attn_slice = attention_scores.softmax(dim=-1)
+        # compute attention output
+
+        if self.use_last_attn_slice:
+            if self.last_attn_slice_mask is not None:
+                new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
+                                                     self.last_attn_slice_indices))
+                attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+                              + new_attn_slice * self.last_attn_slice_mask)
+            else:
+                attn_slice = self.last_attn_slice
+
+            self.use_last_attn_slice = False
+
+        if self.save_last_attn_slice:
+            self.last_attn_slice = attn_slice
+            self.save_last_attn_slice = False
+
+        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
+            attn_slice = attn_slice * self.last_attn_slice_weights
+            self.use_last_attn_weights = False
+
+        hidden_states = torch.matmul(attn_slice, value)
+        # reshape hidden_states
+        return self.reshape_batch_dim_to_heads(hidden_states)
+
+    def new_sliced_attention(self, query, key, value, sequence_length, dim):
+
+        batch_size_attention = query.shape[0]
+        hidden_states = torch.zeros(
+            (batch_size_attention, sequence_length, dim // self.heads),
+            device=query.device, dtype=query.dtype
+        )
+        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+        for i in range(hidden_states.shape[0] // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
+            attn_slice = (
+                torch.matmul(query[start_idx:end_idx],
+                             key[start_idx:end_idx].transpose(1, 2)) * self.scale
+            )  # TODO: use baddbmm for better performance
+            attn_slice = attn_slice.softmax(dim=-1)
+
+            if self.use_last_attn_slice:
+                if self.last_attn_slice_mask is not None:
+                    new_attn_slice = (torch.index_select(self.last_attn_slice,
+                                                         -1, self.last_attn_slice_indices))
+                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+                                  + new_attn_slice * self.last_attn_slice_mask)
+                else:
+                    attn_slice = self.last_attn_slice
+
+                self.use_last_attn_slice = False
+
+            if self.save_last_attn_slice:
+                self.last_attn_slice = attn_slice
+                self.save_last_attn_slice = False
+
+            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
+                attn_slice = attn_slice * self.last_attn_slice_weights
+                self.use_last_attn_weights = False
+
+            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        return self.reshape_batch_dim_to_heads(hidden_states)  # reshape hidden_states
+
+    for _, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention':
+            module.last_attn_slice = None
+            module.use_last_attn_slice = False
+            module.use_last_attn_weights = False
+            module.save_last_attn_slice = False
+            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
+            module._attention = new_attention.__get__(module, type(module))
+
+
+def use_last_tokens_attention(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.use_last_attn_slice = use
+
+
+def use_last_tokens_attention_weights(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.use_last_attn_weights = use
+
+
+def use_last_self_attention(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.use_last_attn_slice = use
+
+
+def save_last_tokens_attention(unet, save=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.save_last_attn_slice = save
+
+
+def save_last_self_attention(unet, save=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.save_last_attn_slice = save
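The helpers above are meant to be wired around two UNet passes per denoising step: record the cross-attention (attn2) and self-attention (attn1) maps while denoising with the original prompt, then replay them while denoising with the edited prompt. Below is a minimal sketch of that call pattern, not part of the patch itself; it assumes `unet`, `clip_tokenizer`, `device`, `latent`, `t`, `embedding_cond`, and `embedding_cond_edit` are already prepared as in cross_attention_loop.py further down, and that the UNet's attention layers are the `CrossAttention` class from the pinned diffusers revision referenced in init_attention_func.

    import c_a_c

    # Monkey-patch the CrossAttention modules and (optionally) per-token weights.
    c_a_c.init_attention_func(unet)
    c_a_c.init_attention_weights([(2, 1.5)], clip_tokenizer, unet, device)  # e.g. boost token index 2

    # Pass 1: denoise with the original prompt and record the attention maps.
    c_a_c.save_last_tokens_attention(unet)
    c_a_c.save_last_self_attention(unet)
    noise_pred_cond = unet(latent, t, encoder_hidden_states=embedding_cond).sample

    # Pass 2: denoise with the edited prompt, replaying the saved maps.
    c_a_c.use_last_tokens_attention(unet)
    c_a_c.use_last_self_attention(unet)
    c_a_c.use_last_tokens_attention_weights(unet)
    noise_pred_edit = unet(latent, t, encoder_hidden_states=embedding_cond_edit).sample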
diff --git a/cross_attention_loop.py b/cross_attention_loop.py
new file mode 100644
index 0000000000..e2d6eb6201
--- /dev/null
+++ b/cross_attention_loop.py
@@ -0,0 +1,185 @@
+import random
+
+import numpy as np
+import torch
+
+from diffusers import (LMSDiscreteScheduler)
+from PIL import Image
+from torch import autocast
+from tqdm.auto import tqdm
+
+import c_a_c
+
+
+@torch.no_grad()
+def stablediffusion(
+        clip,
+        clip_tokenizer,
+        device,
+        vae,
+        unet,
+        prompt='',
+        prompt_edit=None,
+        prompt_edit_token_weights=None,
+        prompt_edit_tokens_start=0.0,
+        prompt_edit_tokens_end=1.0,
+        prompt_edit_spatial_start=0.0,
+        prompt_edit_spatial_end=1.0,
+        guidance_scale=7.5,
+        steps=50,
+        seed=None,
+        width=512,
+        height=512,
+        init_image=None,
+        init_image_strength=0.5,
+        ):
+    if prompt_edit_token_weights is None:
+        prompt_edit_token_weights = []
+    # Change size to multiple of 64 to prevent size mismatches inside model
+    width = width - width % 64
+    height = height - height % 64
+
+    # If seed is None, randomly select seed from 0 to 2^32-1
+    if seed is None: seed = random.randrange(2**32 - 1)
+    generator = torch.cuda.manual_seed(seed)
+
+    # Set inference timesteps to scheduler
+    scheduler = LMSDiscreteScheduler(beta_start=0.00085,
+                                     beta_end=0.012,
+                                     beta_schedule='scaled_linear',
+                                     num_train_timesteps=1000,
+                                     )
+    scheduler.set_timesteps(steps)
+
+    # Preprocess image if it exists (img2img)
+    if init_image is not None:
+        # Resize and transpose for numpy b h w c -> torch b c h w
+        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
+        init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))
+
+        # If there is alpha channel, composite alpha for white, as the diffusion
+        # model does not support alpha channel
+        if init_image.shape[1] > 3:
+            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])
+
+        # Move image to GPU
+        init_image = init_image.to(device)
+
+        # Encode image
+        with autocast(device):
+            init_latent = (vae.encode(init_image)
+                           .latent_dist
+                           .sample(generator=generator)
+                           * 0.18215)
+
+        t_start = steps - int(steps * init_image_strength)
+
+    else:
+        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8),
+                                  device=device)
+        t_start = 0
+
+    # Generate random normal noise
+    noise = torch.randn(init_latent.shape, generator=generator, device=device)
+    latent = scheduler.add_noise(init_latent,
+                                 noise,
+                                 torch.tensor([scheduler.timesteps[t_start]], device=device)
+                                 ).to(device)
+
+    # Process clip
+    with autocast(device):
+        tokens_uncond = clip_tokenizer('', padding='max_length',
+                                       max_length=clip_tokenizer.model_max_length,
+                                       truncation=True, return_tensors='pt',
+                                       return_overflowing_tokens=True
+                                       )
+        embedding_uncond = clip(tokens_uncond.input_ids.to(device)).last_hidden_state
+
+        tokens_cond = clip_tokenizer(prompt, padding='max_length',
+                                     max_length=clip_tokenizer.model_max_length,
+                                     truncation=True, return_tensors='pt',
+                                     return_overflowing_tokens=True
+                                     )
+        embedding_cond = clip(tokens_cond.input_ids.to(device)).last_hidden_state
+
+        # Process prompt editing
+        if prompt_edit is not None:
+            tokens_cond_edit = clip_tokenizer(prompt_edit, padding='max_length',
+                                              max_length=clip_tokenizer.model_max_length,
+                                              truncation=True, return_tensors='pt',
+                                              return_overflowing_tokens=True
+                                              )
+            embedding_cond_edit = clip(tokens_cond_edit.input_ids.to(device)).last_hidden_state
+
+            c_a_c.init_attention_edit(tokens_cond, tokens_cond_edit,
+                                      clip_tokenizer, unet, device)
+
+        c_a_c.init_attention_func(unet)
+        c_a_c.init_attention_weights(prompt_edit_token_weights,
+                                     clip_tokenizer, unet, device)
+
+        timesteps = scheduler.timesteps[t_start:]
+
+        for idx, timestep in tqdm(enumerate(timesteps), total=len(timesteps)):
+            t_index = t_start + idx
+
+            latent_model_input = latent
+            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
+
+            # Predict the unconditional noise residual
+            noise_pred_uncond = unet(latent_model_input,
+                                     timestep,
+                                     encoder_hidden_states=embedding_uncond
+                                     ).sample
+
+            # Prepare the Cross-Attention layers
+            if prompt_edit is not None:
+                c_a_c.save_last_tokens_attention(unet)
+                c_a_c.save_last_self_attention(unet)
+            else:
+                # Use weights on non-edited prompt when edit is None
+                c_a_c.use_last_tokens_attention_weights(unet)
+
+            # Predict the conditional noise residual and save the
+            # cross-attention layer activations
+            noise_pred_cond = unet(latent_model_input,
+                                   timestep,
+                                   encoder_hidden_states=embedding_cond
+                                   ).sample
+
+            # Edit the Cross-Attention layer activations
+            if prompt_edit is not None:
+                t_scale = timestep / scheduler.num_train_timesteps
+                if (t_scale >= prompt_edit_tokens_start
+                        and t_scale <= prompt_edit_tokens_end):
+                    c_a_c.use_last_tokens_attention(unet)
+                if (t_scale >= prompt_edit_spatial_start
+                        and t_scale <= prompt_edit_spatial_end):
+                    c_a_c.use_last_self_attention(unet)
+
+                # Use weights on edited prompt
+                c_a_c.use_last_tokens_attention_weights(unet)
+
+                # Predict the edited conditional noise residual using the
+                # cross-attention masks
+                noise_pred_cond = unet(latent_model_input,
+                                       timestep,
+                                       encoder_hidden_states=embedding_cond_edit
+                                       ).sample
+
+            # Perform guidance
+            noise_pred = (noise_pred_uncond + guidance_scale
+                          * (noise_pred_cond - noise_pred_uncond))
+
+            latent = scheduler.step(noise_pred,
+                                    t_index,
+                                    latent
+                                    ).prev_sample
+
+        # scale and decode the image latents with vae
+        latent = latent / 0.18215
+        image = vae.decode(latent.to(vae.dtype)).sample
+
+    image = (image / 2 + 0.5).clamp(0, 1)
+    image = image.cpu().permute(0, 2, 3, 1).numpy()
+    image = (image[0] * 255).round().astype('uint8')
+    return Image.fromarray(image)
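A minimal end-to-end sketch of how the new entry point might be driven, not part of the patch itself. It assumes a Stable Diffusion v1.x checkpoint in the standard diffusers layout; the model id, prompts, and output path are placeholders.

    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import AutoencoderKL, UNet2DConditionModel

    from cross_attention_loop import stablediffusion

    device = 'cuda'
    model_id = 'CompVis/stable-diffusion-v1-4'  # placeholder checkpoint

    # Load the four components stablediffusion() expects.
    clip_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
    clip = CLIPTextModel.from_pretrained(model_id, subfolder='text_encoder').to(device)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder='vae').to(device)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder='unet').to(device)

    # Generate with a base prompt, then steer the same layout toward the edited prompt.
    image = stablediffusion(
        clip, clip_tokenizer, device, vae, unet,
        prompt='a photo of a cat sitting on a bench',
        prompt_edit='a photo of a dog sitting on a bench',
        prompt_edit_spatial_end=0.8,
        seed=42,
    )
    image.save('edited.png')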