diff --git a/c_a_c.py b/c_a_c.py
deleted file mode 100644
index 33243edaf7..0000000000
--- a/c_a_c.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Functions supporting Cross-Attention Control
-# Copied from https://github.com/bloc97/CrossAttentionControl
-
-from difflib import SequenceMatcher
-
-import torch
-
-
-def prompt_token(prompt, index, clip_tokenizer):
-    tokens = clip_tokenizer(prompt,
-                            padding='max_length',
-                            max_length=clip_tokenizer.model_max_length,
-                            truncation=True,
-                            return_tensors='pt',
-                            return_overflowing_tokens=True
-                            ).input_ids[0]
-    return clip_tokenizer.decode(tokens[index:index+1])
-
-
-def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
-    tokens_length = clip_tokenizer.model_max_length
-    weights = torch.ones(tokens_length)
-
-    for i, w in weight_tuples:
-        if i < tokens_length and i >= 0:
-            weights[i] = w
-
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.last_attn_slice_weights = weights.to(device)
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.last_attn_slice_weights = None
-
-
-def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
-    tokens_length = clip_tokenizer.model_max_length
-    mask = torch.zeros(tokens_length)
-    indices_target = torch.arange(tokens_length, dtype=torch.long)
-    indices = torch.zeros(tokens_length, dtype=torch.long)
-
-    tokens = tokens.input_ids.numpy()[0]
-    tokens_edit = tokens_edit.input_ids.numpy()[0]
-
-    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
-        if b0 < tokens_length:
-            if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0):
-                mask[b0:b1] = 1
-                indices[b0:b1] = indices_target[a0:a1]
-
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.last_attn_slice_mask = mask.to(device)
-            module.last_attn_slice_indices = indices.to(device)
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.last_attn_slice_mask = None
-            module.last_attn_slice_indices = None
-
-
-def init_attention_func(unet):
-    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
-    def new_attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
-        attn_slice = attention_scores.softmax(dim=-1)
-        # compute attention output
-
-        if self.use_last_attn_slice:
-            if self.last_attn_slice_mask is not None:
-                new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
-                                                     self.last_attn_slice_indices))
-                attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
-                              + new_attn_slice * self.last_attn_slice_mask)
-            else:
-                attn_slice = self.last_attn_slice
-
-            self.use_last_attn_slice = False
-
-        if self.save_last_attn_slice:
-            self.last_attn_slice = attn_slice
-            self.save_last_attn_slice = False
-
-        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
-            attn_slice = attn_slice * self.last_attn_slice_weights
-            self.use_last_attn_weights = False
-
-        hidden_states = torch.matmul(attn_slice, value)
-        # reshape hidden_states
-        return self.reshape_batch_dim_to_heads(hidden_states)
-
-    def new_sliced_attention(self, query, key, value, sequence_length, dim):
-
-        batch_size_attention = query.shape[0]
-        hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads),
-            device=query.device, dtype=query.dtype
-        )
-        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
-        for i in range(hidden_states.shape[0] // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.matmul(query[start_idx:end_idx],
-                             key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            )  # TODO: use baddbmm for better performance
-            attn_slice = attn_slice.softmax(dim=-1)
-
-            if self.use_last_attn_slice:
-                if self.last_attn_slice_mask is not None:
-                    new_attn_slice = (torch.index_select(self.last_attn_slice,
-                                                         -1, self.last_attn_slice_indices))
-                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
-                                  + new_attn_slice * self.last_attn_slice_mask)
-                else:
-                    attn_slice = self.last_attn_slice
-
-                self.use_last_attn_slice = False
-
-            if self.save_last_attn_slice:
-                self.last_attn_slice = attn_slice
-                self.save_last_attn_slice = False
-
-            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
-                attn_slice = attn_slice * self.last_attn_slice_weights
-                self.use_last_attn_weights = False
-
-            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
-
-            hidden_states[start_idx:end_idx] = attn_slice
-
-        return self.reshape_batch_dim_to_heads(hidden_states)  # reshape hidden_states
-
-    for _, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention':
-            module.last_attn_slice = None
-            module.use_last_attn_slice = False
-            module.use_last_attn_weights = False
-            module.save_last_attn_slice = False
-            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
-            module._attention = new_attention.__get__(module, type(module))
-
-
-def use_last_tokens_attention(unet, use=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.use_last_attn_slice = use
-
-
-def use_last_tokens_attention_weights(unet, use=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.use_last_attn_weights = use
-
-
-def use_last_self_attention(unet, use=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.use_last_attn_slice = use
-
-
-def save_last_tokens_attention(unet, save=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn2' in name:
-            module.save_last_attn_slice = save
-
-
-def save_last_self_attention(unet, save=True):
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == 'CrossAttention' and 'attn1' in name:
-            module.save_last_attn_slice = save
diff --git a/cross_attention_loop.py b/cross_attention_loop.py
index e2d6eb6201..ed3e3b0462 100644
--- a/cross_attention_loop.py
+++ b/cross_attention_loop.py
@@ -1,4 +1,5 @@
 import random
+import traceback
 
 import numpy as np
 import torch
@@ -8,7 +9,7 @@
 from PIL import Image
 from torch import autocast
 from tqdm.auto import tqdm
 
-import c_a_c
+import .ldm.models.diffusion.cross_attention
 
 @torch.no_grad()
@@ -41,7 +42,7 @@ def stablediffusion(
     # If seed is None, randomly select seed from 0 to 2^32-1
     if seed is None:
         seed = random.randrange(2**32 - 1)
-    generator = torch.cuda.manual_seed(seed)
+    generator = torch.manual_seed(seed)
 
     # Set inference timesteps to scheduler
     scheduler = LMSDiscreteScheduler(beta_start=0.00085,
diff --git a/ldm/generate.py b/ldm/generate.py
index 7fb68dec0a..b8945342b0 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -32,7 +32,7 @@ from ldm.invoke.pngwriter import PngWriter
 from ldm.invoke.args import metadata_from_png
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.devices import choose_torch_device, choose_precision
-from ldm.invoke.conditioning import get_uc_and_c
+from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.model_cache import ModelCache
 from ldm.invoke.seamless import configure_model_padding
 from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
@@ -400,7 +400,7 @@ class Generate:
             mask_image = None
 
         try:
-            uc, c = get_uc_and_c(
+            uc, c, ec = get_uc_and_c_and_ec(
                 prompt, model =self.model,
                 skip_normalize=skip_normalize,
                 log_tokens    =self.log_tokenization
@@ -438,7 +438,7 @@ class Generate:
                 sampler=self.sampler,
                 steps=steps,
                 cfg_scale=cfg_scale,
-                conditioning=(uc, c),
+                conditioning=(uc, c, ec),
                 ddim_eta=ddim_eta,
                 image_callback=image_callback,  # called after the final image is generated
                 step_callback=step_callback,    # called after each intermediate image is generated
@@ -469,14 +469,14 @@ class Generate:
                                          save_original = save_original,
                                          image_callback = image_callback)
 
-        except RuntimeError as e:
-            print(traceback.format_exc(), file=sys.stderr)
-            print('>> Could not generate image.')
         except KeyboardInterrupt:
             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
+        except (RuntimeError, Exception) as e:
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Could not generate image.')
 
         toc = time.time()
         print('>> Usage stats:')
diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index fedd965a2c..1453d9ce8c 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -12,7 +12,7 @@ log_tokenization() print out colour-coded tokens and warn if trunca
 import re
 import torch
 
-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
+def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False):
     # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
@@ -26,7 +26,19 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
         clean_prompt = unconditional_regex_compile.sub(' ', prompt)
         prompt = re.sub(' +', ' ', clean_prompt)
 
+    edited_words = None
+    edited_regex = r'\{(.*?)\}'
+    edited = re.findall(edited_regex, prompt)
+    if len(edited) > 0:
+        edited_words = ' '.join(edited)
+        edited_regex_compile = re.compile(edited_regex)
+        clean_prompt = edited_regex_compile.sub(' ', prompt)
+        prompt = re.sub(' +', ' ', clean_prompt)
+
     uc = model.get_learned_conditioning([unconditioned_words])
+    ec = None
+    if edited_words is not None:
+        ec = model.get_learned_conditioning([edited_words])
 
     # get weighted sub-prompts
     weighted_subprompts = split_weighted_subprompts(
@@ -48,7 +60,7 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
         log_tokenization(prompt, model, log_tokens, 1)
         c = model.get_learned_conditioning([prompt])
         uc = model.get_learned_conditioning([unconditioned_words])
-    return (uc, c)
+    return (uc, c, ec)
 
 def split_weighted_subprompts(text, skip_normalize=False)->list:
     """
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index edd12c948c..23f03f22db 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -19,7 +19,7 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
-        uc, c = conditioning
+        uc, c, ec = conditioning
 
         @torch.no_grad()
         def make_image(x_T):
@@ -43,6 +43,7 @@ class Txt2Img(Generator):
                 verbose = False,
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
+                edited_conditioning = ec,
                 eta = ddim_eta,
                 img_callback = step_callback,
                 threshold = threshold,
diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py
index c39d8d5959..a440eb3e6a 100644
--- a/ldm/models/diffusion/cross_attention.py
+++ b/ldm/models/diffusion/cross_attention.py
@@ -1,8 +1,8 @@
 from enum import Enum
 
+import torch
 
-class CrossAttention:
-
+class CrossAttentionControl:
     class AttentionType(Enum):
         SELF = 1
         TOKENS = 2
@@ -15,5 +15,146 @@ class CrossAttention:
         return module
 
     @classmethod
-    def inject_attention_mask_capture(cls, model, callback):
-        pass
\ No newline at end of file
+    def setup_attention_editing(cls, model, original_tokens_length: int,
+                                substitute_conditioning: torch.Tensor = None,
+                                token_indices_to_edit: list = None):
+
+        # adapted from init_attention_edit
+        self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
+        tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
+
+        if substitute_conditioning is not None:
+
+            device = substitute_conditioning.device
+
+            # this is not very torch-y
+            mask = torch.zeros(original_tokens_length)
+            for i in token_indices_to_edit:
+                mask[i] = 1
+
+            self_attention_module.last_attn_slice_mask = None
+            self_attention_module.last_attn_slice_indices = None
+            tokens_attention_module.last_attn_slice_mask = mask.to(device)
+            tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device)
+
+        cls.inject_attention_functions(model)
+
+    @classmethod
+    def request_save_attention_maps(cls, model):
+        self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
+        tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
+        self_attention_module.save_last_attn_slice = True
+        tokens_attention_module.save_last_attn_slice = True
+
+    @classmethod
+    def request_apply_saved_attention_maps(cls, model):
+        self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF)
+        tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS)
+        self_attention_module.use_last_attn_slice = True
+        tokens_attention_module.use_last_attn_slice = True
+
+    @classmethod
+    def inject_attention_functions(cls, unet):
+        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
+        def new_attention(self, query, key, value):
+            # TODO: use baddbmm for better performance
+            print(f"entered new_attention")
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+            attn_slice = attention_scores.softmax(dim=-1)
+            # compute attention output
+
+            if self.use_last_attn_slice:
+                if self.last_attn_slice_mask is not None:
+                    print('using masked last_attn_slice')
+
+                    new_attn_slice = (torch.index_select(self.last_attn_slice, -1,
+                                                         self.last_attn_slice_indices))
+                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
+                                  + new_attn_slice * self.last_attn_slice_mask)
+                else:
+                    print('using unmasked last_attn_slice')
+                    attn_slice = self.last_attn_slice
+
+                self.use_last_attn_slice = False
+            else:
+                print('not using last_attn_slice')
+
+            if self.save_last_attn_slice:
+                print('saving last_attn_slice')
+                self.last_attn_slice = attn_slice
+                self.save_last_attn_slice = False
+            else:
+                print('not saving last_attn_slice')
+
+            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
+                attn_slice = attn_slice * self.last_attn_slice_weights
+                self.use_last_attn_weights = False
+
+            hidden_states = torch.matmul(attn_slice, value)
+            # reshape hidden_states
+            return hidden_states
+
+        for _, module in unet.named_modules():
+            module_name = type(module).__name__
+            if module_name == 'CrossAttention':
+                module.last_attn_slice = None
+                module.use_last_attn_slice = False
+                module.use_last_attn_weights = False
+                module.save_last_attn_slice = False
+                module.cross_attention_callback = new_attention.__get__(module, type(module))
+
+
+# original code below
+
+# Functions supporting Cross-Attention Control
+# Copied from https://github.com/bloc97/CrossAttentionControl
+
+from difflib import SequenceMatcher
+
+import torch
+
+
+def prompt_token(prompt, index, clip_tokenizer):
+    tokens = clip_tokenizer(prompt,
+                            padding='max_length',
+                            max_length=clip_tokenizer.model_max_length,
+                            truncation=True,
+                            return_tensors='pt',
+                            return_overflowing_tokens=True
+                            ).input_ids[0]
+    return clip_tokenizer.decode(tokens[index:index + 1])
+
+
+def use_last_tokens_attention(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.use_last_attn_slice = use
+
+
+def use_last_tokens_attention_weights(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.use_last_attn_weights = use
+
+
+def use_last_self_attention(unet, use=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.use_last_attn_slice = use
+
+
+def save_last_tokens_attention(unet, save=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn2' in name:
+            module.save_last_attn_slice = save
+
+
+def save_last_self_attention(unet, save=True):
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == 'CrossAttention' and 'attn1' in name:
+            module.save_last_attn_slice = save
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 8010b44d1d..29949aff8d 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -13,6 +13,7 @@ from ldm.modules.diffusionmodules.util import (
     noise_like,
     extract_into_tensor,
 )
+from ldm.models.diffusion.cross_attention import CrossAttentionControl
 
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
@@ -29,21 +30,41 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
 
 
 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None):
         super().__init__()
         self.inner_model = model
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
+        self.edited_conditioning = edited_conditioning
+
+        if self.edited_conditioning is not None:
+            initial_tokens_count = 77 # ' a cat sitting on a car '
+            token_indices_to_edit = [2] # 'cat'
+            CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning,
+                                                          token_indices_to_edit)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-        module = self.get_attention_module(AttentionLayer.TOKENS)
+
+        print('generating new unconditioned latents')
+        unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
+
+        # process x using the original prompt, saving the attention maps if required
+        if self.edited_conditioning is not None:
+            # this is automatically toggled off after the model forward()
+            CrossAttentionControl.request_save_attention_maps(self.inner_model)
+        print('generating new conditioned latents')
+        conditioned_latents = self.inner_model(x, sigma, cond=cond)
+
+        if self.edited_conditioning is not None:
+            # process x again, using the saved attention maps but the new conditioning
+            # this is automatically toggled off after the model forward()
+            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
+            print('generating edited conditioned latents')
+            conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
 
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -52,7 +73,8 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
+        delta = (conditioned_latents - unconditioned_latents)
+        return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
 
 
 class KSampler(Sampler):
@@ -169,6 +191,7 @@ class KSampler(Sampler):
         log_every_t=100,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
+        edited_conditioning=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -200,7 +223,7 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index ef9c2d3e65..a9805e6c67 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -170,6 +170,8 @@ class CrossAttention(nn.Module):
 
         self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
 
+        self.cross_attention_callback = None
+
     def einsum_op_compvis(self, q, k, v):
         s = einsum('b i d, b j d -> b i j', q, k)
         s = s.softmax(dim=-1, dtype=s.dtype)
@@ -244,8 +246,13 @@ class CrossAttention(nn.Module):
         del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-        r = self.einsum_op(q, k, v)
-        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+        if self.cross_attention_callback is not None:
+            r = self.cross_attention_callback(q, k, v)
+        else:
+            r = self.einsum_op(q, k, v)
+        hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(hidden_states)
 
 
 class BasicTransformerBlock(nn.Module):
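
Note: as a standalone illustration of the prompt syntax this patch wires up, the sketch below mirrors the regex handling added to get_uc_and_c_and_ec() in ldm/invoke/conditioning.py, where square-bracketed words feed the unconditioned embedding (uc) and curly-braced words feed the new edited conditioning (ec). The helper name parse_prompt_spans is hypothetical, and the model.get_learned_conditioning() calls that turn each piece into a tensor are deliberately omitted.

# Illustrative sketch only -- not part of the patch.
import re

def parse_prompt_spans(prompt: str):
    unconditional_regex = r'\[(.*?)\]'  # [words] -> unconditioned embedding (uc)
    edited_regex = r'\{(.*?)\}'         # {words} -> edited conditioning (ec)

    # pull out the [unconditioned] words and remove them from the prompt
    unconditioned_words = ' '.join(re.findall(unconditional_regex, prompt))
    prompt = re.sub(' +', ' ', re.sub(unconditional_regex, ' ', prompt))

    # pull out the {edited} words and remove them from the prompt
    edited = re.findall(edited_regex, prompt)
    edited_words = ' '.join(edited) if len(edited) > 0 else None
    prompt = re.sub(' +', ' ', re.sub(edited_regex, ' ', prompt))

    return prompt.strip(), unconditioned_words, edited_words

print(parse_prompt_spans('a cat sitting on a car {forest} [blurry]'))
# -> ('a cat sitting on a car', 'blurry', 'forest')

In the diff as it stands, ec is then threaded through conditioning=(uc, c, ec), Txt2Img and KSampler into CFGDenoiser, which still hard-codes initial_tokens_count = 77 and token_indices_to_edit = [2] while the attention-map plumbing is being worked out.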